8310813: Simplify and modernize equals, hashCode, and compareTo for BigInteger

Reviewed-by: rriggs, redestad, rgiulietti
This commit is contained in:
Pavel Rappo 2024-01-11 21:48:58 +00:00
parent 4ea7b36447
commit 49e6121347
6 changed files with 511 additions and 27 deletions

View file

@ -0,0 +1,89 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.math;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread)
@Warmup(iterations = 3, time = 5)
@Measurement(iterations = 3, time = 5)
@Fork(value = 3)
public class BigIntegerCompareTo {
public enum Group {S, M, L}
@Param({"S", "M", "L"})
private Group group;
private static final int MAX_LENGTH = Arrays.stream(Group.values())
.mapToInt(p -> getNumbersOfBits(p).length)
.max()
.getAsInt();
private BigInteger[] numbers;
@Setup
public void setup() {
int[] nBits = getNumbersOfBits(group);
numbers = new BigInteger[2 * MAX_LENGTH];
for (int i = 0; i < MAX_LENGTH; i++) {
var p = Shared.createPair(nBits[i % nBits.length]);
numbers[2 * i] = p.x();
numbers[2 * i + 1] = p.y();
}
}
private static int[] getNumbersOfBits(Group p) {
// the below arrays were derived from stats gathered from running tests in
// the security area, which is the biggest client of BigInteger in JDK
return switch (p) {
case S -> new int[]{0, 1, 2, 11, 12, 13, 14, 15, 16, 17, 18, 19};
case M -> new int[]{255, 256, 512};
case L -> new int[]{1023, 1024, 1534, 1535, 1536};
};
}
@Benchmark
public void testCompareTo(Blackhole bh) {
for (int i = 0; i < numbers.length; i += 2)
bh.consume(numbers[i].compareTo(numbers[i + 1]));
}
}

View file

@ -0,0 +1,89 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.math;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread)
@Warmup(iterations = 3, time = 5)
@Measurement(iterations = 3, time = 5)
@Fork(value = 3)
public class BigIntegerEquals {
public enum Group {S, M, L}
@Param({"S", "M", "L"})
private Group group;
private static final int MAX_LENGTH = Arrays.stream(Group.values())
.mapToInt(p -> getNumbersOfBits(p).length)
.max()
.getAsInt();
private BigInteger[] numbers;
@Setup
public void setup() {
int[] nBits = getNumbersOfBits(group);
numbers = new BigInteger[2 * MAX_LENGTH];
for (int i = 0; i < MAX_LENGTH; i++) {
var p = Shared.createPair(nBits[i % nBits.length]);
numbers[2 * i] = p.x();
numbers[2 * i + 1] = p.y();
}
}
private static int[] getNumbersOfBits(Group p) {
// the below arrays were derived from stats gathered from running tests in
// the security area, which is the biggest client of BigInteger in JDK
return switch (p) {
case S -> new int[]{1, 46};
case M -> new int[]{129, 130, 251, 252, 253, 254, 255, 256};
case L -> new int[]{382, 383, 384, 445, 446, 447, 448, 519, 520, 521};
};
}
@Benchmark
public void testEquals(Blackhole bh) {
for (int i = 0; i < numbers.length; i += 2)
bh.consume(numbers[i].equals(numbers[i + 1]));
}
}

View file

@ -0,0 +1,87 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.math;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread)
@Warmup(iterations = 3, time = 5)
@Measurement(iterations = 3, time = 5)
@Fork(value = 3)
public class BigIntegerHashCode {
public enum Group {S, M, L}
@Param({"S", "M", "L"})
private Group group;
private static final int MAX_LENGTH = Arrays.stream(Group.values())
.mapToInt(p -> getNumbersOfBits(p).length)
.max()
.getAsInt();
private BigInteger[] numbers;
@Setup
public void setup() {
int[] nBits = getNumbersOfBits(group);
numbers = new BigInteger[MAX_LENGTH];
for (int i = 0; i < MAX_LENGTH; i++) {
numbers[i] = Shared.createSingle(nBits[i % nBits.length]);
}
}
private static int[] getNumbersOfBits(Group p) {
// the below arrays were derived from stats gathered from running tests in
// the security area, which is the biggest client of BigInteger in JDK
return switch (p) {
case S -> new int[]{2, 7, 13, 64};
case M -> new int[]{256, 384, 511, 512, 521, 767, 768};
case L -> new int[]{1024, 1025, 2047, 2048, 2049, 3072, 4096, 5120, 6144};
};
}
@Benchmark
public void testHashCode(Blackhole bh) {
for (var n : numbers)
bh.consume(n.hashCode());
}
}

View file

@ -0,0 +1,157 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.openjdk.bench.java.math;
import java.math.BigInteger;
import java.util.Random;
///////////////////////////////////////////////////////////////////////////////
// THIS IS NOT A BENCHMARK
///////////////////////////////////////////////////////////////////////////////
public final class Shared {
// General note
// ============
//
// Isn't there a simple way to get a BigInteger of the specified number
// of bits of magnitude? It does not seem like it.
//
// We cannot create a BigInteger of the specified number of bytes,
// directly and *cheaply*. This constructor does not do what you
// might think it does:
//
// BigInteger(int numBits, Random rnd)
//
// The only real direct option we have is this constructor:
//
// BigInteger(int bitLength, int certainty, Random rnd)
//
// But even with certainty == 0, it is not cheap. So, create the
// number with the closest number of bytes and then shift right
// the excess bits.
private Shared() {
throw new AssertionError("This is a utility class");
}
//
// Creates a pair of same sign numbers x and y that minimally differ in
// magnitude.
//
// More formally: x.bitLength() == nBits and x.signum() == y.signum()
// and either
//
// * y.bitLength() == nBits, and
// * x.testBit(0) != y.testBit(0)
//
// or
//
// * y.bitLength() == nBits + 1
//
// By construction, such numbers are unequal to each other, but the
// difference in magnitude is minimal. That way, the comparison
// methods, such as equals and compareTo, are forced to examine
// the _complete_ number representation.
//
// Assumptions on BigInteger mechanics
// ===================================
//
// 1. bigLength() is not consulted with for short-circuiting; if it is,
// then we have a problem with nBits={0,1}
// 2. new BigInteger(0, new byte[]{0}) and new BigInteger(1, new byte[]{1})
// are not canonicalized to BigInteger.ZERO and BigInteger.ONE,
// respectively; if they are, then internal optimizations might be
// possible (BigInteger is not exactly a value-based class).
// 3. Comparison and equality are checked from the most significant bit
// to the least significant bit, not the other way around (for
// comparison it seems natural, but not for equality). If any
// of those are checked in the opposite direction, then the check
// might short-circuit.
//
public static Pair createPair(int nBits) {
if (nBits < 0) {
throw new IllegalArgumentException(String.valueOf(nBits));
} else if (nBits == 0) {
var zero = new BigInteger(nBits, new byte[0]);
var one = new BigInteger(/* positive */ 1, new byte[]{1});
return new Pair(zero, one);
} else if (nBits == 1) {
var one = new BigInteger(/* positive */ 1, new byte[]{1});
var two = new BigInteger(/* positive */ 1, new byte[]{2});
return new Pair(one, two);
}
int nBytes = (nBits + 7) / 8;
var r = new Random();
var bytes = new byte[nBytes];
r.nextBytes(bytes);
// Create a BigInteger of the exact bit length by:
// 1. ensuring that the most significant bit is set so that
// no leading zeros are truncated, and
// 2. explicitly specifying signum, so it's not calculated from
// the passed bytes, which must represent magnitude only
bytes[0] |= (byte) 0b1000_0000;
var x = new BigInteger(/* positive */ 1, bytes)
.shiftRight(nBytes * 8 - nBits);
var y = x.flipBit(0);
// do not rely on the assert statement in benchmark
if (x.bitLength() != nBits)
throw new AssertionError(x.bitLength() + ", " + nBits);
return new Pair(x, y);
}
public record Pair(BigInteger x, BigInteger y) {
public Pair {
if (x.signum() == -y.signum()) // if the pair comprises positive and negative
throw new IllegalArgumentException("x.signum()=" + x.signum()
+ ", y=signum()=" + y.signum());
if (y.bitLength() - x.bitLength() > 1)
throw new IllegalArgumentException("x.bitLength()=" + x.bitLength()
+ ", y.bitLength()=" + y.bitLength());
}
}
public static BigInteger createSingle(int nBits) {
if (nBits < 0) {
throw new IllegalArgumentException(String.valueOf(nBits));
}
if (nBits == 0) {
return new BigInteger(nBits, new byte[0]);
}
int nBytes = (nBits + 7) / 8;
var r = new Random();
var bytes = new byte[nBytes];
r.nextBytes(bytes);
// Create a BigInteger of the exact bit length by:
// 1. ensuring that the most significant bit is set so that
// no leading zeros are truncated, and
// 2. explicitly specifying signum, so it's not calculated from
// the passed bytes, which must represent magnitude only
bytes[0] |= (byte) 0b1000_0000;
var x = new BigInteger(/* positive */ 1, bytes)
.shiftRight(nBytes * 8 - nBits);
if (x.bitLength() != nBits)
throw new AssertionError(x.bitLength() + ", " + nBits);
return x;
}
}