8277175: Add a parallel multiply method to BigInteger

Reviewed-by: psandoz
This commit is contained in:
Dr Heinz M. Kabutz 2022-02-11 18:49:04 +00:00 committed by Paul Sandoz
parent 0786ddb471
commit 83ffbd2e7a
4 changed files with 601 additions and 19 deletions

View file

@ -36,6 +36,9 @@ import java.io.ObjectStreamField;
import java.util.Arrays;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ThreadLocalRandom;
import jdk.internal.math.DoubleConsts;
@ -1581,7 +1584,30 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
* @return {@code this * val}
*/
public BigInteger multiply(BigInteger val) {
return multiply(val, false);
return multiply(val, false, false, 0);
}
/**
 * Returns a BigInteger whose value is {@code (this * val)}, possibly
 * computing the product in parallel. Parallel multiplication might be
 * used when both {@code this} and {@code val} are large, typically in
 * the thousands of bits. The mathematical result is exactly the same
 * as the one returned by {@link #multiply}.
 *
 * @implNote This implementation may offer better algorithmic
 * performance when {@code val == this}.
 *
 * @implNote Relative to {@link #multiply}, an implementation's
 * parallel multiplication algorithm would typically spend more
 * CPU resources in order to produce the result faster, and may
 * consume slightly more memory while doing so.
 *
 * @param val value to be multiplied by this BigInteger.
 * @return {@code this * val}
 * @see #multiply
 */
public BigInteger parallelMultiply(BigInteger val) {
    // Entry point from user code: not a recursive invocation, parallel
    // execution requested, and the recursion depth starts at zero.
    final boolean isRecursion = false;
    final boolean parallel = true;
    final int depth = 0;
    return multiply(val, isRecursion, parallel, depth);
}
/**
@ -1590,16 +1616,17 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
*
* @param val value to be multiplied by this BigInteger.
* @param isRecursion whether this is a recursive invocation
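* @param parallel whether the multiply should be done in parallel
* @param depth the current recursion depth, passed on to the recursive
*        square and Toom-Cook 3 operations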
* @return {@code this * val}
*/
private BigInteger multiply(BigInteger val, boolean isRecursion) {
private BigInteger multiply(BigInteger val, boolean isRecursion, boolean parallel, int depth) {
if (val.signum == 0 || signum == 0)
return ZERO;
int xlen = mag.length;
if (val == this && xlen > MULTIPLY_SQUARE_THRESHOLD) {
return square();
return square(true, parallel, depth);
}
int ylen = val.mag.length;
@ -1677,7 +1704,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
}
}
return multiplyToomCook3(this, val);
return multiplyToomCook3(this, val, parallel, depth);
}
}
}
@ -1844,6 +1871,88 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
}
}
/**
 * Base class for the recursive multiply and square sub-operations.
 * An operation is either forked as a task or invoked directly,
 * depending on the {@code parallel} flag and on the recursion depth
 * at which it was created.
 */
@SuppressWarnings("serial")
private abstract static sealed class RecursiveOp extends RecursiveTask<BigInteger> {
    /**
     * Fork-depth limit computed from the parallelism of the common pool.
     * It is the limit applied when the current thread is not a
     * {@code ForkJoinWorkerThread}. Operations deeper than the limit are
     * invoked directly rather than forked. Only relevant for Toom Cook 3
     * multiply and square.
     */
    private static final int PARALLEL_FORK_DEPTH_THRESHOLD =
            calculateMaximumDepth(ForkJoinPool.getCommonPoolParallelism());

    /** Whether this operation is allowed to fork. */
    final boolean parallel;

    /**
     * The recursion depth at which this operation was created. Stored as
     * a byte: the algorithm is logarithmic, so an int is not needed to
     * hold the value.
     */
    final byte depth;

    private RecursiveOp(boolean parallel, int depth) {
        this.depth = (byte) depth;
        this.parallel = parallel;
    }

    /**
     * Returns the bit length of {@code parallelism}, that is, 32 minus
     * the number of leading zero bits.
     */
    private static final int calculateMaximumDepth(int parallelism) {
        return Integer.SIZE - Integer.numberOfLeadingZeros(parallelism);
    }

    /**
     * Returns the fork-depth limit for the calling thread: computed from
     * the parallelism of the thread's own pool when it is a
     * {@code ForkJoinWorkerThread}, otherwise the common-pool value.
     */
    private static int getParallelForkDepthThreshold() {
        return (Thread.currentThread() instanceof ForkJoinWorkerThread worker)
                ? calculateMaximumDepth(worker.getPool().getParallelism())
                : PARALLEL_FORK_DEPTH_THRESHOLD;
    }

    /**
     * Forks this task when parallel execution was requested and the depth
     * does not exceed the limit; otherwise invokes it directly. Either
     * way this task is returned so the caller can {@code join()} it.
     */
    protected RecursiveTask<BigInteger> forkOrInvoke() {
        // The limit is only looked up when parallel execution is enabled.
        boolean shouldFork = parallel && depth <= getParallelForkDepthThreshold();
        if (shouldFork) {
            fork();
        } else {
            invoke();
        }
        return this;
    }

    /** Recursive operation computing the product of two BigIntegers. */
    @SuppressWarnings("serial")
    private static final class RecursiveMultiply extends RecursiveOp {
        private final BigInteger left;
        private final BigInteger right;

        public RecursiveMultiply(BigInteger a, BigInteger b, boolean parallel, int depth) {
            super(parallel, depth);
            this.left = a;
            this.right = b;
        }

        @Override
        public BigInteger compute() {
            // Always a recursive invocation (isRecursion == true).
            return left.multiply(right, true, parallel, depth);
        }
    }

    /** Recursive operation computing the square of a BigInteger. */
    @SuppressWarnings("serial")
    private static final class RecursiveSquare extends RecursiveOp {
        private final BigInteger operand;

        public RecursiveSquare(BigInteger a, boolean parallel, int depth) {
            super(parallel, depth);
            this.operand = a;
        }

        @Override
        public BigInteger compute() {
            // Always a recursive invocation (isRecursion == true).
            return operand.square(true, parallel, depth);
        }
    }

    /**
     * Creates a multiply operation for {@code a * b} and forks or invokes
     * it; the returned task yields the product.
     */
    private static RecursiveTask<BigInteger> multiply(BigInteger a, BigInteger b, boolean parallel, int depth) {
        RecursiveOp task = new RecursiveMultiply(a, b, parallel, depth);
        return task.forkOrInvoke();
    }

    /**
     * Creates a square operation for {@code a} and forks or invokes it;
     * the returned task yields the square.
     */
    private static RecursiveTask<BigInteger> square(BigInteger a, boolean parallel, int depth) {
        RecursiveOp task = new RecursiveSquare(a, parallel, depth);
        return task.forkOrInvoke();
    }
}
/**
* Multiplies two BigIntegers using a 3-way Toom-Cook multiplication
* algorithm. This is a recursive divide-and-conquer algorithm which is
@ -1872,7 +1981,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
* LNCS #4547. Springer, Madrid, Spain, June 21-22, 2007.
*
*/
private static BigInteger multiplyToomCook3(BigInteger a, BigInteger b) {
private static BigInteger multiplyToomCook3(BigInteger a, BigInteger b, boolean parallel, int depth) {
int alen = a.mag.length;
int blen = b.mag.length;
@ -1896,16 +2005,20 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
BigInteger v0, v1, v2, vm1, vinf, t1, t2, tm1, da1, db1;
v0 = a0.multiply(b0, true);
depth++;
var v0_task = RecursiveOp.multiply(a0, b0, parallel, depth);
da1 = a2.add(a0);
db1 = b2.add(b0);
vm1 = da1.subtract(a1).multiply(db1.subtract(b1), true);
var vm1_task = RecursiveOp.multiply(da1.subtract(a1), db1.subtract(b1), parallel, depth);
da1 = da1.add(a1);
db1 = db1.add(b1);
v1 = da1.multiply(db1, true);
var v1_task = RecursiveOp.multiply(da1, db1, parallel, depth);
v2 = da1.add(a2).shiftLeft(1).subtract(a0).multiply(
db1.add(b2).shiftLeft(1).subtract(b0), true);
vinf = a2.multiply(b2, true);
db1.add(b2).shiftLeft(1).subtract(b0), true, parallel, depth);
vinf = a2.multiply(b2, true, parallel, depth);
v0 = v0_task.join();
vm1 = vm1_task.join();
v1 = v1_task.join();
// The algorithm requires two divisions by 2 and one by 3.
// All divisions are known to be exact, that is, they do not produce
@ -2071,7 +2184,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
* @return <code>this<sup>2</sup></code>
*/
private BigInteger square() {
return square(false);
return square(false, false, 0);
}
/**
@ -2081,7 +2194,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
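* @param isRecursion whether this is a recursive invocation
* @param parallel whether the square should be done in parallel
* @param depth the current recursion depth, passed on to the
*        Toom-Cook 3 squaring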
* @return <code>this<sup>2</sup></code>
*/
private BigInteger square(boolean isRecursion) {
private BigInteger square(boolean isRecursion, boolean parallel, int depth) {
if (signum == 0) {
return ZERO;
}
@ -2103,7 +2216,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
}
}
return squareToomCook3();
return squareToomCook3(parallel, depth);
}
}
}
@ -2237,7 +2350,7 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
* that has better asymptotic performance than the algorithm used in
* squareToLen or squareKaratsuba.
*/
private BigInteger squareToomCook3() {
private BigInteger squareToomCook3(boolean parallel, int depth) {
int len = mag.length;
// k is the size (in ints) of the lower-order slices.
@ -2254,13 +2367,17 @@ public class BigInteger extends Number implements Comparable<BigInteger> {
a0 = getToomSlice(k, r, 2, len);
BigInteger v0, v1, v2, vm1, vinf, t1, t2, tm1, da1;
v0 = a0.square(true);
depth++;
var v0_fork = RecursiveOp.square(a0, parallel, depth);
da1 = a2.add(a0);
vm1 = da1.subtract(a1).square(true);
var vm1_fork = RecursiveOp.square(da1.subtract(a1), parallel, depth);
da1 = da1.add(a1);
v1 = da1.square(true);
vinf = a2.square(true);
v2 = da1.add(a2).shiftLeft(1).subtract(a0).square(true);
var v1_fork = RecursiveOp.square(da1, parallel, depth);
vinf = a2.square(true, parallel, depth);
v2 = da1.add(a2).shiftLeft(1).subtract(a0).square(true, parallel, depth);
v0 = v0_fork.join();
vm1 = vm1_fork.join();
v1 = v1_fork.join();
// The algorithm requires two divisions by 2 and one by 3.
// All divisions are known to be exact, that is, they do not produce