8345512: Remove wrapper functions for intrinsics in PQC algorithms

Reviewed-by: weijun
This commit is contained in:
Ben Perez 2024-12-04 22:01:10 +00:00
parent 8d19a560d0
commit f904480a49
3 changed files with 48 additions and 427 deletions

View file

@ -25,7 +25,6 @@
package sun.security.provider;
import jdk.internal.vm.annotation.IntrinsicCandidate;
import sun.security.provider.SHA3.SHAKE128;
import sun.security.provider.SHA3.SHAKE256;
@ -527,13 +526,13 @@ public class ML_DSA {
int tOffset = j*4;
int vOffset = (i*320) + (j*5);
t1[i][tOffset] = (v[vOffset] & 0xFF) +
((v[vOffset+1] << 8) & 0x3FF);
((v[vOffset+1] << 8) & 0x3FF);
t1[i][tOffset+1] = ((v[vOffset+1] >> 2) & 0x3F) +
((v[vOffset+2] << 6) & 0x3FF);
((v[vOffset+2] << 6) & 0x3FF);
t1[i][tOffset+2] = ((v[vOffset+2] >> 4) & 0xF) +
((v[vOffset+3] << 4) & 0x3FF);
((v[vOffset+3] << 4) & 0x3FF);
t1[i][tOffset+3] = ((v[vOffset+3] >> 6) & 0x3) +
((v[vOffset+4] << 2) & 0x3FF);
((v[vOffset+4] << 2) & 0x3FF);
}
}
return t1;
@ -875,8 +874,8 @@ public class ML_DSA {
rawOfs = 0;
}
tmp = (rawAij[rawOfs] & 0xFF) +
((rawAij[rawOfs + 1] & 0xFF) << 8) +
((rawAij[rawOfs + 2] & 0x7F) << 16);
((rawAij[rawOfs + 1] & 0xFF) << 8) +
((rawAij[rawOfs + 2] & 0x7F) << 16);
rawOfs += 3;
if (tmp < ML_DSA_Q) {
aij[ofs] = tmp;
@ -981,7 +980,7 @@ public class ML_DSA {
int multiplier = (gamma2 == 95232 ? 22 : 8);
for (int i = 0; i < mlDsa_k; i++) {
ML_DSA.mlDsaDecomposePoly(input[i], lowPart[i],
highPart[i], gamma2 * 2, multiplier);
highPart[i], gamma2 * 2, multiplier);
}
}
@ -1032,12 +1031,6 @@ public class ML_DSA {
*/
public static int[] mlDsaNtt(int[] coeffs) {
implMlDsaAlmostNttJava(coeffs);
implMlDsaMontMulByConstantJava(coeffs, MONT_R_MOD_Q);
return coeffs;
}
static void implMlDsaAlmostNttJava(int[] coeffs) {
int dimension = ML_DSA_N;
int m = 0;
for (int l = dimension / 2; l > 0; l /= 2) {
@ -1050,15 +1043,11 @@ public class ML_DSA {
m++;
}
}
}
public static int[] mlDsaInverseNtt(int[] coeffs) {
implMlDsaAlmostInverseNttJava(coeffs);
implMlDsaMontMulByConstantJava(coeffs, MONT_DIM_INVERSE);
montMulByConstant(coeffs, MONT_R_MOD_Q);
return coeffs;
}
static void implMlDsaAlmostInverseNttJava(int[] coeffs) {
public static int[] mlDsaInverseNtt(int[] coeffs) {
int dimension = ML_DSA_N;
int m = 0;
for (int l = 1; l < dimension; l *= 2) {
@ -1067,11 +1056,13 @@ public class ML_DSA {
int tmp = coeffs[j];
coeffs[j] = (tmp + coeffs[j + l]);
coeffs[j + l] = montMul(tmp - coeffs[j + l],
MONT_ZETAS_FOR_INVERSE_NTT[m]);
MONT_ZETAS_FOR_INVERSE_NTT[m]);
}
m++;
}
}
montMulByConstant(coeffs, MONT_DIM_INVERSE);
return coeffs;
}
void mlDsaVectorNtt(int[][] vector) {
@ -1086,22 +1077,13 @@ public class ML_DSA {
}
}
//Todo
public static void mlDsaNttMultiply(int[] res, int[] coeffs1, int[] coeffs2) {
implMlDsaNttMultJava(res, coeffs1, coeffs2);
}
static void implMlDsaNttMultJava(int[] product, int[] coeffs1, int[] coeffs2) {
public static void mlDsaNttMultiply(int[] product, int[] coeffs1, int[] coeffs2) {
for (int i = 0; i < ML_DSA_N; i++) {
product[i] = montMul(coeffs1[i], toMont(coeffs2[i]));
}
}
public static void montMulByConstant(int[] coeffs, int constant) {
implMlDsaMontMulByConstantJava(coeffs, constant);
}
static void implMlDsaMontMulByConstantJava(int[] coeffs, int constant) {
for (int i = 0; i < ML_DSA_N; i++) {
coeffs[i] = montMul((coeffs[i]), constant);
}
@ -1109,17 +1091,6 @@ public class ML_DSA {
public static void mlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart,
int twoGamma2, int multiplier) {
implMlDsaDecomposePoly(input, lowPart, highPart, twoGamma2, multiplier);
}
@IntrinsicCandidate
static void implMlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart,
int twoGamma2, int multiplier) {
decomposePolyJava(input, lowPart, highPart, twoGamma2, multiplier);
}
static void decomposePolyJava(int[] input, int[] lowPart, int[] highPart,
int twoGamma2, int multiplier) {
for (int m = 0; m < ML_DSA_N; m++) {
int rplus = input[m];
rplus = rplus - ((rplus + 5373807) >> 23) * ML_DSA_Q;