8345512: Remove wrapper functions for intrinsics in PQC algorithms

Reviewed-by: weijun
This commit is contained in:
Ben Perez 2024-12-04 22:01:10 +00:00
parent 8d19a560d0
commit f904480a49
3 changed files with 48 additions and 427 deletions

View file

@ -71,249 +71,6 @@ public final class ML_KEM {
-1599, -709, -789, -1317, -57, 1049, -584 -1599, -709, -789, -1317, -57, 1049, -584
}; };
private static final short[] MONT_ZETAS_FOR_VECTOR_NTT_ARR = new short[]{
// level 0
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
-758, -758, -758, -758, -758, -758, -758, -758,
// level 1
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-359, -359, -359, -359, -359, -359, -359, -359,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
// level 2
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
287, 287, 287, 287, 287, 287, 287, 287,
287, 287, 287, 287, 287, 287, 287, 287,
287, 287, 287, 287, 287, 287, 287, 287,
287, 287, 287, 287, 287, 287, 287, 287,
202, 202, 202, 202, 202, 202, 202, 202,
202, 202, 202, 202, 202, 202, 202, 202,
202, 202, 202, 202, 202, 202, 202, 202,
202, 202, 202, 202, 202, 202, 202, 202,
// level 3
-171, -171, -171, -171, -171, -171, -171, -171,
-171, -171, -171, -171, -171, -171, -171, -171,
622, 622, 622, 622, 622, 622, 622, 622,
622, 622, 622, 622, 622, 622, 622, 622,
1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577,
1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577,
182, 182, 182, 182, 182, 182, 182, 182,
182, 182, 182, 182, 182, 182, 182, 182,
962, 962, 962, 962, 962, 962, 962, 962,
962, 962, 962, 962, 962, 962, 962, 962,
-1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202,
-1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202,
-1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474,
-1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474,
1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468,
1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468,
// level 4
573, 573, 573, 573, 573, 573, 573, 573,
-1325, -1325, -1325, -1325, -1325, -1325, -1325, -1325,
264, 264, 264, 264, 264, 264, 264, 264,
383, 383, 383, 383, 383, 383, 383, 383,
-829, -829, -829, -829, -829, -829, -829, -829,
1458, 1458, 1458, 1458, 1458, 1458, 1458, 1458,
-1602, -1602, -1602, -1602, -1602, -1602, -1602, -1602,
-130, -130, -130, -130, -130, -130, -130, -130,
-681, -681, -681, -681, -681, -681, -681, -681,
1017, 1017, 1017, 1017, 1017, 1017, 1017, 1017,
732, 732, 732, 732, 732, 732, 732, 732,
608, 608, 608, 608, 608, 608, 608, 608,
-1542, -1542, -1542, -1542, -1542, -1542, -1542, -1542,
411, 411, 411, 411, 411, 411, 411, 411,
-205, -205, -205, -205, -205, -205, -205, -205,
-1571, -1571, -1571, -1571, -1571, -1571, -1571, -1571,
// level 5
1223, 1223, 1223, 1223, 652, 652, 652, 652,
-552, -552, -552, -552, 1015, 1015, 1015, 1015,
-1293, -1293, -1293, -1293, 1491, 1491, 1491, 1491,
-282, -282, -282, -282, -1544, -1544, -1544, -1544,
516, 516, 516, 516, -8, -8, -8, -8,
-320, -320, -320, -320, -666, -666, -666, -666,
1711, 1711, 1711, 1711, -1162, -1162, -1162, -1162,
126, 126, 126, 126, 1469, 1469, 1469, 1469,
-853, -853, -853, -853, -90, -90, -90, -90,
-271, -271, -271, -271, 830, 830, 830, 830,
107, 107, 107, 107, -1421, -1421, -1421, -1421,
-247, -247, -247, -247, -951, -951, -951, -951,
-398, -398, -398, -398, 961, 961, 961, 961,
-1508, -1508, -1508, -1508, -725, -725, -725, -725,
448, 448, 448, 448, -1065, -1065, -1065, -1065,
677, 677, 677, 677, -1275, -1275, -1275, -1275,
// level 6
-1103, -1103, 430, 430, 555, 555, 843, 843,
-1251, -1251, 871, 871, 1550, 1550, 105, 105,
422, 422, 587, 587, 177, 177, -235, -235,
-291, -291, -460, -460, 1574, 1574, 1653, 1653,
-246, -246, 778, 778, 1159, 1159, -147, -147,
-777, -777, 1483, 1483, -602, -602, 1119, 1119,
-1590, -1590, 644, 644, -872, -872, 349, 349,
418, 418, 329, 329, -156, -156, -75, -75,
817, 817, 1097, 1097, 603, 603, 610, 610,
1322, 1322, -1285, -1285, -1465, -1465, 384, 384,
-1215, -1215, -136, -136, 1218, 1218, -1335, -1335,
-874, -874, 220, 220, -1187, -1187, 1670, 1670,
-1185, -1185, -1530, -1530, -1278, -1278, 794, 794,
-1510, -1510, -854, -854, -870, -870, 478, 478,
-108, -108, -308, -308, 996, 996, 991, 991,
958, 958, -1460, -1460, 1522, 1522, 1628, 1628
};
private static final short[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT_ARR = new short[]{
// level 0
-1628, -1628, -1522, -1522, 1460, 1460, -958, -958,
-991, -991, -996, -996, 308, 308, 108, 108,
-478, -478, 870, 870, 854, 854, 1510, 1510,
-794, -794, 1278, 1278, 1530, 1530, 1185, 1185,
1659, 1659, 1187, 1187, -220, -220, 874, 874,
1335, 1335, -1218, -1218, 136, 136, 1215, 1215,
-384, -384, 1465, 1465, 1285, 1285, -1322, -1322,
-610, -610, -603, -603, -1097, -1097, -817, -817,
75, 75, 156, 156, -329, -329, -418, -418,
-349, -349, 872, 872, -644, -644, 1590, 1590,
-1119, -1119, 602, 602, -1483, -1483, 777, 777,
147, 147, -1159, -1159, -778, -778, 246, 246,
-1653, -1653, -1574, -1574, 460, 460, 291, 291,
235, 235, -177, -177, -587, -587, -422, -422,
-105, -105, -1550, -1550, -871, -871, 1251, 1251,
-843, -843, -555, -555, -430, -430, 1103, 1103,
// level 1
1275, 1275, 1275, 1275, -677, -677, -677, -677,
1065, 1065, 1065, 1065, -448, -448, -448, -448,
725, 725, 725, 725, 1508, 1508, 1508, 1508,
-961, -961, -961, -961, 398, 398, 398, 398,
951, 951, 951, 951, 247, 247, 247, 247,
1421, 1421, 1421, 1421, -107, -107, -107, -107,
-830, -830, -830, -830, 271, 271, 271, 271,
90, 90, 90, 90, 853, 853, 853, 853,
-1469, -1469, -1469, -1469, -126, -126, -126, -126,
1162, 1162, 1162, 1162, 1618, 1618, 1618, 1618,
666, 666, 666, 666, 320, 320, 320, 320,
8, 8, 8, 8, -516, -516, -516, -516,
1544, 1544, 1544, 1544, 282, 282, 282, 282,
-1491, -1491, -1491, -1491, 1293, 1293, 1293, 1293,
-1015, -1015, -1015, -1015, 552, 552, 552, 552,
-652, -652, -652, -652, -1223, -1223, -1223, -1223,
// level 2
1571, 1571, 1571, 1571, 1571, 1571, 1571, 1571,
205, 205, 205, 205, 205, 205, 205, 205,
-411, -411, -411, -411, -411, -411, -411, -411,
1542, 1542, 1542, 1542, 1542, 1542, 1542, 1542,
-608, -608, -608, -608, -608, -608, -608, -608,
-732, -732, -732, -732, -732, -732, -732, -732,
-1017, -1017, -1017, -1017, -1017, -1017, -1017, -1017,
681, 681, 681, 681, 681, 681, 681, 681,
130, 130, 130, 130, 130, 130, 130, 130,
1602, 1602, 1602, 1602, 1602, 1602, 1602, 1602,
-1458, -1458, -1458, -1458, -1458, -1458, -1458, -1458,
829, 829, 829, 829, 829, 829, 829, 829,
-383, -383, -383, -383, -383, -383, -383, -383,
-264, -264, -264, -264, -264, -264, -264, -264,
1325, 1325, 1325, 1325, 1325, 1325, 1325, 1325,
-573, -573, -573, -573, -573, -573, -573, -573,
// level 3
-1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468,
-1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468,
1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474,
1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474,
1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202,
1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202,
-962, -962, -962, -962, -962, -962, -962, -962,
-962, -962, -962, -962, -962, -962, -962, -962,
-182, -182, -182, -182, -182, -182, -182, -182,
-182, -182, -182, -182, -182, -182, -182, -182,
-1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577,
-1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577,
-622, -622, -622, -622, -622, -622, -622, -622,
-622, -622, -622, -622, -622, -622, -622, -622,
171, 171, 171, 171, 171, 171, 171, 171,
171, 171, 171, 171, 171, 171, 171, 171,
// level 4
-202, -202, -202, -202, -202, -202, -202, -202,
-202, -202, -202, -202, -202, -202, -202, -202,
-202, -202, -202, -202, -202, -202, -202, -202,
-202, -202, -202, -202, -202, -202, -202, -202,
-287, -287, -287, -287, -287, -287, -287, -287,
-287, -287, -287, -287, -287, -287, -287, -287,
-287, -287, -287, -287, -287, -287, -287, -287,
-287, -287, -287, -287, -287, -287, -287, -287,
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
// level 5
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
359, 359, 359, 359, 359, 359, 359, 359,
// level 6
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758,
758, 758, 758, 758, 758, 758, 758, 758
};
private static final int[] MONT_ZETAS_FOR_NTT_MULT = new int[]{ private static final int[] MONT_ZETAS_FOR_NTT_MULT = new int[]{
-1003, 1003, 222, -222, -1107, 1107, 172, -172, -1003, 1003, 222, -222, -1107, 1107, 172, -172,
-42, 42, 620, -620, 1497, -1497, -1649, 1649, -42, 42, 620, -620, 1497, -1497, -1649, 1649,
@ -333,25 +90,6 @@ public final class ML_KEM {
-1317, 1317, -57, 57, 1049, -1049, -584, 584 -1317, 1317, -57, 57, 1049, -1049, -584, 584
}; };
private static final short[] MONT_ZETAS_FOR_VECTOR_NTT_MULT_ARR = new short[]{
-1103, 1103, 430, -430, 555, -555, 843, -843,
-1251, 1251, 871, -871, 1550, -1550, 105, -105,
422, -422, 587, -587, 177, -177, -235, 235,
-291, 291, -460, 460, 1574, -1574, 1653, -1653,
-246, 246, 778, -778, 1159, -1159, -147, 147,
-777, 777, 1483, -1483, -602, 602, 1119, -1119,
-1590, 1590, 644, -644, -872, 872, 349, -349,
418, -418, 329, -329, -156, 156, -75, 75,
817, -817, 1097, -1097, 603, -603, 610, -610,
1322, -1322, -1285, 1285, -1465, 1465, 384, -384,
-1215, 1215, -136, 136, 1218, -1218, -1335, 1335,
-874, 874, 220, -220, -1187, 1187, 1670, 1659,
-1185, 1185, -1530, 1530, -1278, 1278, 794, -794,
-1510, 1510, -854, 854, -870, 870, 478, -478,
-108, 108, -308, 308, 996, -996, 991, -991,
958, -958, -1460, 1460, 1522, -1522, 1628, -1628
};
private final int mlKem_k; private final int mlKem_k;
private final int mlKem_eta1; private final int mlKem_eta1;
private final int mlKem_eta2; private final int mlKem_eta2;
@ -969,11 +707,9 @@ public final class ML_KEM {
return vector; return vector;
} }
static void implMlKemNtt(short[] poly, short[] ntt_zetas) { // The elements of poly should be in the range [-ML_KEM_Q, ML_KEM_Q]
implMlKemNttJava(poly); // The elements of poly at return will be in the range of [0, ML_KEM_Q]
} private void mlKemNTT(short[] poly) {
private static void implMlKemNttJava(short[] poly) {
int[] coeffs = new int[ML_KEM_N]; int[] coeffs = new int[ML_KEM_N];
for (int m = 0; m < ML_KEM_N; m++) { for (int m = 0; m < ML_KEM_N; m++) {
coeffs[m] = poly[m]; coeffs[m] = poly[m];
@ -982,20 +718,12 @@ public final class ML_KEM {
for (int m = 0; m < ML_KEM_N; m++) { for (int m = 0; m < ML_KEM_N; m++) {
poly[m] = (short) coeffs[m]; poly[m] = (short) coeffs[m];
} }
}
// The elements of poly should be in the range [-ML_KEM_Q, ML_KEM_Q]
// The elements of poly at return will be in the range of [0, ML_KEM_Q]
private void mlKemNTT(short[] poly) {
implMlKemNtt(poly, MONT_ZETAS_FOR_VECTOR_NTT_ARR);
mlKemBarrettReduce(poly); mlKemBarrettReduce(poly);
} }
static void implMlKemInverseNtt(short[] poly, short[] zetas) { // Works in place, but also returns its (modified) input so that it can
implMlKemInverseNttJava(poly); // be used in expressions
} private short[] mlKemInverseNTT(short[] poly) {
private static void implMlKemInverseNttJava(short[] poly) {
int[] coeffs = new int[ML_KEM_N]; int[] coeffs = new int[ML_KEM_N];
for (int m = 0; m < ML_KEM_N; m++) { for (int m = 0; m < ML_KEM_N; m++) {
coeffs[m] = poly[m]; coeffs[m] = poly[m];
@ -1004,12 +732,6 @@ public final class ML_KEM {
for (int m = 0; m < ML_KEM_N; m++) { for (int m = 0; m < ML_KEM_N; m++) {
poly[m] = (short) coeffs[m]; poly[m] = (short) coeffs[m];
} }
}
// Works in place, but also returns its (modified) input so that it can
// be used in expressions
private short[] mlKemInverseNTT(short[] poly) {
implMlKemInverseNtt(poly, MONT_ZETAS_FOR_VECTOR_INVERSE_NTT_ARR);
return poly; return poly;
} }
@ -1100,14 +822,10 @@ public final class ML_KEM {
return result; return result;
} }
static void implMlKemNttMult(short[] result, short[] ntta, short[] nttb, // Multiplies two polynomials represented in the NTT domain.
short[] zetas) { // The result is a representation of the product still in the NTT domain.
implMlKemNttMultJava(result, ntta, nttb); // The coefficients in the result are in the range (-ML_KEM_Q, ML_KEM_Q).
} private void nttMult(short[] result, short[] ntta, short[] nttb) {
private static void implMlKemNttMultJava(short[] result,
short[] ntta, short[] nttb) {
for (int m = 0; m < ML_KEM_N / 2; m++) { for (int m = 0; m < ML_KEM_N / 2; m++) {
int a0 = ntta[2 * m]; int a0 = ntta[2 * m];
int a1 = ntta[2 * m + 1]; int a1 = ntta[2 * m + 1];
@ -1121,13 +839,6 @@ public final class ML_KEM {
} }
} }
// Multiplies two polynomials represented in the NTT domain.
// The result is a representation of the product still in the NTT domain.
// The coefficients in the result are in the range (-ML_KEM_Q, ML_KEM_Q).
private void nttMult(short[] result, short[] ntta, short[] nttb) {
implMlKemNttMult(result, ntta, nttb, MONT_ZETAS_FOR_VECTOR_NTT_MULT_ARR);
}
// Adds the vector of polynomials b to a in place, i.e. a will hold // Adds the vector of polynomials b to a in place, i.e. a will hold
// the result. It also returns (the modified) a so that it can be used // the result. It also returns (the modified) a so that it can be used
// in an expression. // in an expression.
@ -1142,36 +853,15 @@ public final class ML_KEM {
return a; return a;
} }
static void implMlKemAddPoly(short[] result, short[] a, short[] b) {
implMlKemAddPolyJava(result, a, b);
}
private static void implMlKemAddPolyJava(short[] result, short[] a, short[] b) {
for (int m = 0; m < ML_KEM_N; m++) {
int r = a[m] + b[m] + ML_KEM_Q; // This makes r > -ML_KEM_Q
result[m] = (short) r;
}
}
// Adds the polynomial b to a in place, i.e. (the modified) a will hold // Adds the polynomial b to a in place, i.e. (the modified) a will hold
// the result. // the result.
// The coefficients are supposed be greater than -ML_KEM_Q in a and // The coefficients are supposed be greater than -ML_KEM_Q in a and
// greater than -ML_KEM_Q and less than ML_KEM_Q in b. // greater than -ML_KEM_Q and less than ML_KEM_Q in b.
// The coefficients in the result are greater than -ML_KEM_Q. // The coefficients in the result are greater than -ML_KEM_Q.
private void mlKemAddPoly(short[] a, short[] b) { private void mlKemAddPoly(short[] a, short[] b) {
implMlKemAddPoly(a, a, b);
}
static void implMlKemAddPoly(short[] result, short[] a, short[] b, short[] c) {
implMlKemAddPolyJava(result, a, b, c);
}
private static void implMlKemAddPolyJava(short[] result, short[] a,
short[] b, short[] c) {
for (int m = 0; m < ML_KEM_N; m++) { for (int m = 0; m < ML_KEM_N; m++) {
int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q int r = a[m] + b[m] + ML_KEM_Q; // This makes r > -ML_KEM_Q
result[m] = (short) r; a[m] = (short) r;
} }
} }
@ -1181,7 +871,10 @@ public final class ML_KEM {
// greater than -ML_KEM_Q and less than ML_KEM_Q. // greater than -ML_KEM_Q and less than ML_KEM_Q.
// The coefficients in the result are nonnegative and less than ML_KEM_Q. // The coefficients in the result are nonnegative and less than ML_KEM_Q.
private short[] mlKemAddPoly(short[] a, short[] b, short[] c) { private short[] mlKemAddPoly(short[] a, short[] b, short[] c) {
implMlKemAddPoly(a, a, b, c); for (int m = 0; m < ML_KEM_N; m++) {
int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q
a[m] = (short) r;
}
mlKemBarrettReduce(a); mlKemBarrettReduce(a);
return a; return a;
} }
@ -1304,23 +997,6 @@ public final class ML_KEM {
return result; return result;
} }
private static void implMlKem12To16(byte[] condensed, int index,
short[] parsed, int parsedLength) {
implMlKem12To16Java(condensed, index, parsed, parsedLength);
}
private static void implMlKem12To16Java(byte[] condensed, int index,
short[] parsed, int parsedLength) {
for (int i = 0; i < parsedLength * 3 / 2; i += 3) {
parsed[(i / 3) * 2] = (short) ((condensed[i + index] & 0xff) +
256 * (condensed[i + index + 1] & 0xf));
parsed[(i / 3) * 2 + 1] = (short) (((condensed[i + index + 1] >>> 4) & 0xf) +
16 * (condensed[i + index + 2] & 0xff));
}
}
// The intrinsic implementations assume that the input and output buffers // The intrinsic implementations assume that the input and output buffers
// are such that condensed can be read in 192-byte chunks and // are such that condensed can be read in 192-byte chunks and
// parsed can be written in 128 shorts chunks. In other words, // parsed can be written in 128 shorts chunks. In other words,
@ -1330,7 +1006,12 @@ public final class ML_KEM {
private void twelve2Sixteen(byte[] condensed, int index, private void twelve2Sixteen(byte[] condensed, int index,
short[] parsed, int parsedLength) { short[] parsed, int parsedLength) {
implMlKem12To16(condensed, index, parsed, parsedLength); for (int i = 0; i < parsedLength * 3 / 2; i += 3) {
parsed[(i / 3) * 2] = (short) ((condensed[i + index] & 0xff) +
256 * (condensed[i + index + 1] & 0xf));
parsed[(i / 3) * 2 + 1] = (short) (((condensed[i + index + 1] >>> 4) & 0xf) +
16 * (condensed[i + index + 2] & 0xff));
}
} }
private static void decodePoly5(byte[] condensed, int index, short[] parsed) { private static void decodePoly5(byte[] condensed, int index, short[] parsed) {
@ -1471,18 +1152,6 @@ public final class ML_KEM {
return result; return result;
} }
static void implMlKemBarrettReduce(short[] coeffs) {
implMlKemBarrettReduceJava(coeffs);
}
private static void implMlKemBarrettReduceJava(short[] coeffs) {
for (int m = 0; m < ML_KEM_N; m++) {
int tmp = ((int) coeffs[m] * BARRETT_MULTIPLIER) >>
BARRETT_SHIFT;
coeffs[m] = (short) (coeffs[m] - tmp * ML_KEM_Q);
}
}
// The input elements can have any short value. // The input elements can have any short value.
// Modifies poly such that upon return poly[i] will be // Modifies poly such that upon return poly[i] will be
// in the range [0, ML_KEM_Q] and will be congruent with the original // in the range [0, ML_KEM_Q] and will be congruent with the original
@ -1493,7 +1162,10 @@ public final class ML_KEM {
// will be in the range [0, ML_KEM_Q), i.e. it will be the canonical // will be in the range [0, ML_KEM_Q), i.e. it will be the canonical
// representative of its residue class. // representative of its residue class.
private void mlKemBarrettReduce(short[] poly) { private void mlKemBarrettReduce(short[] poly) {
implMlKemBarrettReduce(poly); for (int m = 0; m < ML_KEM_N; m++) {
int tmp = ((int) poly[m] * BARRETT_MULTIPLIER) >> BARRETT_SHIFT;
poly[m] = (short) (poly[m] - tmp * ML_KEM_Q);
}
} }
// Precondition: -(2^MONT_R_BITS -1) * MONT_Q <= b * c < (2^MONT_R_BITS - 1) * MONT_Q // Precondition: -(2^MONT_R_BITS -1) * MONT_Q <= b * c < (2^MONT_R_BITS - 1) * MONT_Q
@ -1503,8 +1175,8 @@ public final class ML_KEM {
int a = b * c; int a = b * c;
int aHigh = a >> MONT_R_BITS; int aHigh = a >> MONT_R_BITS;
int aLow = a & ((1 << MONT_R_BITS) - 1); int aLow = a & ((1 << MONT_R_BITS) - 1);
int m = ((MONT_Q_INV_MOD_R * aLow) << (32 - MONT_R_BITS)) >> // signed low product
(32 - MONT_R_BITS); // signed low product int m = ((MONT_Q_INV_MOD_R * aLow) << (32 - MONT_R_BITS)) >> (32 - MONT_R_BITS);
return (aHigh - ((m * MONT_Q) >> MONT_R_BITS)); // subtract signed high product return (aHigh - ((m * MONT_Q) >> MONT_R_BITS)); // subtract signed high product
} }

View file

@ -37,19 +37,6 @@ import javax.crypto.DecapsulateException;
public final class ML_KEM_Impls { public final class ML_KEM_Impls {
static int name2int(String name) {
if (name.endsWith("512")) {
return 512;
} else if (name.endsWith("768")) {
return 768;
} else if (name.endsWith("1024")) {
return 1024;
} else {
// should not happen
throw new ProviderException("Unknown name " + name);
}
}
public sealed static class KPG public sealed static class KPG
extends NamedKeyPairGenerator permits KPG2, KPG3, KPG5 { extends NamedKeyPairGenerator permits KPG2, KPG3, KPG5 {
@ -164,17 +151,8 @@ public final class ML_KEM_Impls {
ML_KEM mlKem = new ML_KEM(name); ML_KEM mlKem = new ML_KEM(name);
var kpkeCipherText = new ML_KEM.K_PKE_CipherText(cipherText); var kpkeCipherText = new ML_KEM.K_PKE_CipherText(cipherText);
return mlKem.decapsulate(new ML_KEM.ML_KEM_DecapsulationKey(
byte[] decapsulateResult;
try {
decapsulateResult = mlKem.decapsulate(
new ML_KEM.ML_KEM_DecapsulationKey(
decapsulationKey), kpkeCipherText); decapsulationKey), kpkeCipherText);
} catch (DecapsulateException e) {
throw new DecapsulateException("Decapsulate error", e) ;
}
return decapsulateResult;
} }
@Override @Override

View file

@ -25,7 +25,6 @@
package sun.security.provider; package sun.security.provider;
import jdk.internal.vm.annotation.IntrinsicCandidate;
import sun.security.provider.SHA3.SHAKE128; import sun.security.provider.SHA3.SHAKE128;
import sun.security.provider.SHA3.SHAKE256; import sun.security.provider.SHA3.SHAKE256;
@ -1032,12 +1031,6 @@ public class ML_DSA {
*/ */
public static int[] mlDsaNtt(int[] coeffs) { public static int[] mlDsaNtt(int[] coeffs) {
implMlDsaAlmostNttJava(coeffs);
implMlDsaMontMulByConstantJava(coeffs, MONT_R_MOD_Q);
return coeffs;
}
static void implMlDsaAlmostNttJava(int[] coeffs) {
int dimension = ML_DSA_N; int dimension = ML_DSA_N;
int m = 0; int m = 0;
for (int l = dimension / 2; l > 0; l /= 2) { for (int l = dimension / 2; l > 0; l /= 2) {
@ -1050,15 +1043,11 @@ public class ML_DSA {
m++; m++;
} }
} }
} montMulByConstant(coeffs, MONT_R_MOD_Q);
public static int[] mlDsaInverseNtt(int[] coeffs) {
implMlDsaAlmostInverseNttJava(coeffs);
implMlDsaMontMulByConstantJava(coeffs, MONT_DIM_INVERSE);
return coeffs; return coeffs;
} }
static void implMlDsaAlmostInverseNttJava(int[] coeffs) { public static int[] mlDsaInverseNtt(int[] coeffs) {
int dimension = ML_DSA_N; int dimension = ML_DSA_N;
int m = 0; int m = 0;
for (int l = 1; l < dimension; l *= 2) { for (int l = 1; l < dimension; l *= 2) {
@ -1072,6 +1061,8 @@ public class ML_DSA {
m++; m++;
} }
} }
montMulByConstant(coeffs, MONT_DIM_INVERSE);
return coeffs;
} }
void mlDsaVectorNtt(int[][] vector) { void mlDsaVectorNtt(int[][] vector) {
@ -1086,22 +1077,13 @@ public class ML_DSA {
} }
} }
//Todo public static void mlDsaNttMultiply(int[] product, int[] coeffs1, int[] coeffs2) {
public static void mlDsaNttMultiply(int[] res, int[] coeffs1, int[] coeffs2) {
implMlDsaNttMultJava(res, coeffs1, coeffs2);
}
static void implMlDsaNttMultJava(int[] product, int[] coeffs1, int[] coeffs2) {
for (int i = 0; i < ML_DSA_N; i++) { for (int i = 0; i < ML_DSA_N; i++) {
product[i] = montMul(coeffs1[i], toMont(coeffs2[i])); product[i] = montMul(coeffs1[i], toMont(coeffs2[i]));
} }
} }
public static void montMulByConstant(int[] coeffs, int constant) { public static void montMulByConstant(int[] coeffs, int constant) {
implMlDsaMontMulByConstantJava(coeffs, constant);
}
static void implMlDsaMontMulByConstantJava(int[] coeffs, int constant) {
for (int i = 0; i < ML_DSA_N; i++) { for (int i = 0; i < ML_DSA_N; i++) {
coeffs[i] = montMul((coeffs[i]), constant); coeffs[i] = montMul((coeffs[i]), constant);
} }
@ -1109,17 +1091,6 @@ public class ML_DSA {
public static void mlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart, public static void mlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart,
int twoGamma2, int multiplier) { int twoGamma2, int multiplier) {
implMlDsaDecomposePoly(input, lowPart, highPart, twoGamma2, multiplier);
}
@IntrinsicCandidate
static void implMlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart,
int twoGamma2, int multiplier) {
decomposePolyJava(input, lowPart, highPart, twoGamma2, multiplier);
}
static void decomposePolyJava(int[] input, int[] lowPart, int[] highPart,
int twoGamma2, int multiplier) {
for (int m = 0; m < ML_DSA_N; m++) { for (int m = 0; m < ML_DSA_N; m++) {
int rplus = input[m]; int rplus = input[m];
rplus = rplus - ((rplus + 5373807) >> 23) * ML_DSA_Q; rplus = rplus - ((rplus + 5373807) >> 23) * ML_DSA_Q;