diff --git a/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java b/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java index 40bbc6c56cc..9808a013303 100644 --- a/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java +++ b/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java @@ -71,249 +71,6 @@ public final class ML_KEM { -1599, -709, -789, -1317, -57, 1049, -584 }; - private static final short[] MONT_ZETAS_FOR_VECTOR_NTT_ARR = new short[]{ - // level 0 - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - -758, -758, -758, -758, -758, -758, -758, -758, - // level 1 - -359, -359, -359, -359, -359, -359, -359, -359, - -359, -359, -359, -359, -359, -359, -359, -359, - -359, -359, -359, -359, -359, -359, -359, -359, - -359, -359, -359, -359, -359, -359, -359, -359, - -359, -359, -359, -359, -359, -359, -359, -359, - -359, -359, -359, -359, -359, -359, -359, -359, - -359, -359, -359, -359, -359, -359, -359, -359, - -359, -359, -359, -359, -359, -359, -359, -359, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - // level 2 - 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, - 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, - 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, - 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, - 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, - 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, - 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, - 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, - 287, 287, 287, 287, 287, 287, 287, 287, - 287, 287, 287, 287, 287, 287, 287, 287, - 287, 287, 287, 287, 287, 287, 287, 287, - 287, 287, 287, 287, 287, 287, 287, 287, - 202, 202, 202, 202, 202, 202, 202, 202, - 202, 202, 202, 202, 202, 202, 202, 202, - 202, 202, 202, 202, 202, 202, 202, 202, - 202, 202, 202, 202, 202, 202, 202, 202, - // level 3 - -171, -171, -171, -171, -171, -171, -171, -171, - -171, -171, -171, -171, -171, -171, -171, -171, - 622, 622, 622, 622, 622, 622, 622, 622, - 622, 622, 622, 622, 622, 622, 622, 622, - 1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577, - 1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577, - 182, 182, 182, 182, 182, 182, 182, 182, - 182, 182, 182, 182, 182, 182, 182, 182, - 962, 962, 962, 962, 962, 962, 962, 962, - 962, 962, 962, 962, 962, 962, 962, 962, - -1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202, - -1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202, - -1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474, - -1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474, - 1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468, - 1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468, - // level 4 - 573, 573, 573, 573, 573, 573, 573, 573, - -1325, -1325, -1325, -1325, -1325, -1325, -1325, -1325, - 264, 264, 264, 264, 264, 264, 264, 264, - 383, 383, 383, 383, 383, 383, 383, 383, - -829, -829, -829, -829, -829, -829, -829, -829, - 1458, 1458, 1458, 1458, 1458, 1458, 1458, 1458, - -1602, -1602, -1602, -1602, -1602, -1602, -1602, -1602, - -130, -130, -130, -130, -130, -130, -130, -130, - -681, -681, -681, -681, -681, -681, -681, -681, - 1017, 1017, 1017, 1017, 1017, 1017, 1017, 1017, - 732, 732, 732, 732, 732, 732, 732, 732, - 608, 608, 608, 608, 608, 608, 608, 608, - -1542, -1542, -1542, -1542, -1542, -1542, -1542, -1542, - 411, 411, 411, 411, 411, 411, 411, 411, - -205, -205, -205, -205, -205, -205, -205, -205, - -1571, -1571, -1571, -1571, -1571, -1571, -1571, -1571, - // level 5 - 1223, 1223, 1223, 1223, 652, 652, 652, 652, - -552, -552, -552, -552, 1015, 1015, 1015, 1015, - -1293, -1293, -1293, -1293, 1491, 1491, 1491, 1491, - -282, -282, -282, -282, -1544, -1544, -1544, -1544, - 516, 516, 516, 516, -8, -8, -8, -8, - -320, -320, -320, -320, -666, -666, -666, -666, - 1711, 1711, 1711, 1711, -1162, -1162, -1162, -1162, - 126, 126, 126, 126, 1469, 1469, 1469, 1469, - -853, -853, -853, -853, -90, -90, -90, -90, - -271, -271, -271, -271, 830, 830, 830, 830, - 107, 107, 107, 107, -1421, -1421, -1421, -1421, - -247, -247, -247, -247, -951, -951, -951, -951, - -398, -398, -398, -398, 961, 961, 961, 961, - -1508, -1508, -1508, -1508, -725, -725, -725, -725, - 448, 448, 448, 448, -1065, -1065, -1065, -1065, - 677, 677, 677, 677, -1275, -1275, -1275, -1275, - // level 6 - -1103, -1103, 430, 430, 555, 555, 843, 843, - -1251, -1251, 871, 871, 1550, 1550, 105, 105, - 422, 422, 587, 587, 177, 177, -235, -235, - -291, -291, -460, -460, 1574, 1574, 1653, 1653, - -246, -246, 778, 778, 1159, 1159, -147, -147, - -777, -777, 1483, 1483, -602, -602, 1119, 1119, - -1590, -1590, 644, 644, -872, -872, 349, 349, - 418, 418, 329, 329, -156, -156, -75, -75, - 817, 817, 1097, 1097, 603, 603, 610, 610, - 1322, 1322, -1285, -1285, -1465, -1465, 384, 384, - -1215, -1215, -136, -136, 1218, 1218, -1335, -1335, - -874, -874, 220, 220, -1187, -1187, 1670, 1670, - -1185, -1185, -1530, -1530, -1278, -1278, 794, 794, - -1510, -1510, -854, -854, -870, -870, 478, 478, - -108, -108, -308, -308, 996, 996, 991, 991, - 958, 958, -1460, -1460, 1522, 1522, 1628, 1628 - }; - private static final short[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT_ARR = new short[]{ - // level 0 - -1628, -1628, -1522, -1522, 1460, 1460, -958, -958, - -991, -991, -996, -996, 308, 308, 108, 108, - -478, -478, 870, 870, 854, 854, 1510, 1510, - -794, -794, 1278, 1278, 1530, 1530, 1185, 1185, - 1659, 1659, 1187, 1187, -220, -220, 874, 874, - 1335, 1335, -1218, -1218, 136, 136, 1215, 1215, - -384, -384, 1465, 1465, 1285, 1285, -1322, -1322, - -610, -610, -603, -603, -1097, -1097, -817, -817, - 75, 75, 156, 156, -329, -329, -418, -418, - -349, -349, 872, 872, -644, -644, 1590, 1590, - -1119, -1119, 602, 602, -1483, -1483, 777, 777, - 147, 147, -1159, -1159, -778, -778, 246, 246, - -1653, -1653, -1574, -1574, 460, 460, 291, 291, - 235, 235, -177, -177, -587, -587, -422, -422, - -105, -105, -1550, -1550, -871, -871, 1251, 1251, - -843, -843, -555, -555, -430, -430, 1103, 1103, - // level 1 - 1275, 1275, 1275, 1275, -677, -677, -677, -677, - 1065, 1065, 1065, 1065, -448, -448, -448, -448, - 725, 725, 725, 725, 1508, 1508, 1508, 1508, - -961, -961, -961, -961, 398, 398, 398, 398, - 951, 951, 951, 951, 247, 247, 247, 247, - 1421, 1421, 1421, 1421, -107, -107, -107, -107, - -830, -830, -830, -830, 271, 271, 271, 271, - 90, 90, 90, 90, 853, 853, 853, 853, - -1469, -1469, -1469, -1469, -126, -126, -126, -126, - 1162, 1162, 1162, 1162, 1618, 1618, 1618, 1618, - 666, 666, 666, 666, 320, 320, 320, 320, - 8, 8, 8, 8, -516, -516, -516, -516, - 1544, 1544, 1544, 1544, 282, 282, 282, 282, - -1491, -1491, -1491, -1491, 1293, 1293, 1293, 1293, - -1015, -1015, -1015, -1015, 552, 552, 552, 552, - -652, -652, -652, -652, -1223, -1223, -1223, -1223, - // level 2 - 1571, 1571, 1571, 1571, 1571, 1571, 1571, 1571, - 205, 205, 205, 205, 205, 205, 205, 205, - -411, -411, -411, -411, -411, -411, -411, -411, - 1542, 1542, 1542, 1542, 1542, 1542, 1542, 1542, - -608, -608, -608, -608, -608, -608, -608, -608, - -732, -732, -732, -732, -732, -732, -732, -732, - -1017, -1017, -1017, -1017, -1017, -1017, -1017, -1017, - 681, 681, 681, 681, 681, 681, 681, 681, - 130, 130, 130, 130, 130, 130, 130, 130, - 1602, 1602, 1602, 1602, 1602, 1602, 1602, 1602, - -1458, -1458, -1458, -1458, -1458, -1458, -1458, -1458, - 829, 829, 829, 829, 829, 829, 829, 829, - -383, -383, -383, -383, -383, -383, -383, -383, - -264, -264, -264, -264, -264, -264, -264, -264, - 1325, 1325, 1325, 1325, 1325, 1325, 1325, 1325, - -573, -573, -573, -573, -573, -573, -573, -573, - // level 3 - -1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468, - -1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468, - 1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474, - 1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474, - 1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202, - 1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202, - -962, -962, -962, -962, -962, -962, -962, -962, - -962, -962, -962, -962, -962, -962, -962, -962, - -182, -182, -182, -182, -182, -182, -182, -182, - -182, -182, -182, -182, -182, -182, -182, -182, - -1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577, - -1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577, - -622, -622, -622, -622, -622, -622, -622, -622, - -622, -622, -622, -622, -622, -622, -622, -622, - 171, 171, 171, 171, 171, 171, 171, 171, - 171, 171, 171, 171, 171, 171, 171, 171, - // level 4 - -202, -202, -202, -202, -202, -202, -202, -202, - -202, -202, -202, -202, -202, -202, -202, -202, - -202, -202, -202, -202, -202, -202, -202, -202, - -202, -202, -202, -202, -202, -202, -202, -202, - -287, -287, -287, -287, -287, -287, -287, -287, - -287, -287, -287, -287, -287, -287, -287, -287, - -287, -287, -287, -287, -287, -287, -287, -287, - -287, -287, -287, -287, -287, -287, -287, -287, - -1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422, - -1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422, - -1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422, - -1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422, - -1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493, - -1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493, - -1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493, - -1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493, - // level 5 - 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, - 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, - 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, - 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, - 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, - 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, - 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, - 1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517, - 359, 359, 359, 359, 359, 359, 359, 359, - 359, 359, 359, 359, 359, 359, 359, 359, - 359, 359, 359, 359, 359, 359, 359, 359, - 359, 359, 359, 359, 359, 359, 359, 359, - 359, 359, 359, 359, 359, 359, 359, 359, - 359, 359, 359, 359, 359, 359, 359, 359, - 359, 359, 359, 359, 359, 359, 359, 359, - 359, 359, 359, 359, 359, 359, 359, 359, - // level 6 - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758, - 758, 758, 758, 758, 758, 758, 758, 758 - }; - private static final int[] MONT_ZETAS_FOR_NTT_MULT = new int[]{ -1003, 1003, 222, -222, -1107, 1107, 172, -172, -42, 42, 620, -620, 1497, -1497, -1649, 1649, @@ -333,25 +90,6 @@ public final class ML_KEM { -1317, 1317, -57, 57, 1049, -1049, -584, 584 }; - private static final short[] MONT_ZETAS_FOR_VECTOR_NTT_MULT_ARR = new short[]{ - -1103, 1103, 430, -430, 555, -555, 843, -843, - -1251, 1251, 871, -871, 1550, -1550, 105, -105, - 422, -422, 587, -587, 177, -177, -235, 235, - -291, 291, -460, 460, 1574, -1574, 1653, -1653, - -246, 246, 778, -778, 1159, -1159, -147, 147, - -777, 777, 1483, -1483, -602, 602, 1119, -1119, - -1590, 1590, 644, -644, -872, 872, 349, -349, - 418, -418, 329, -329, -156, 156, -75, 75, - 817, -817, 1097, -1097, 603, -603, 610, -610, - 1322, -1322, -1285, 1285, -1465, 1465, 384, -384, - -1215, 1215, -136, 136, 1218, -1218, -1335, 1335, - -874, 874, 220, -220, -1187, 1187, 1670, 1659, - -1185, 1185, -1530, 1530, -1278, 1278, 794, -794, - -1510, 1510, -854, 854, -870, 870, 478, -478, - -108, 108, -308, 308, 996, -996, 991, -991, - 958, -958, -1460, 1460, 1522, -1522, 1628, -1628 - }; - private final int mlKem_k; private final int mlKem_eta1; private final int mlKem_eta2; @@ -499,7 +237,7 @@ public final class ML_KEM { System.arraycopy(kPkePrivateKey, 0, decapsKey, 0, kPkePrivateKey.length); Arrays.fill(kPkePrivateKey, (byte)0); System.arraycopy(encapsKey, 0, decapsKey, - kPkePrivateKey.length, encapsKey.length); + kPkePrivateKey.length, encapsKey.length); mlKemH.update(encapsKey); try { @@ -534,7 +272,7 @@ public final class ML_KEM { var kHatAndRandomCoins = mlKemG.digest(); var randomCoins = Arrays.copyOfRange(kHatAndRandomCoins, 32, 64); var cipherText = kPkeEncrypt(new K_PKE_EncryptionKey(encapsulationKey.keyBytes), - randomMessage, randomCoins); + randomMessage, randomCoins); Arrays.fill(randomCoins, (byte) 0); byte[] sharedSecret = Arrays.copyOfRange(kHatAndRandomCoins, 0, 32); Arrays.fill(kHatAndRandomCoins, (byte) 0); @@ -564,7 +302,7 @@ public final class ML_KEM { byte[] kPkePrivateKeyBytes = new byte[mlKem_k * encode12PolyLen]; System.arraycopy(decapsKeyBytes, 0, kPkePrivateKeyBytes, 0, - kPkePrivateKeyBytes.length); + kPkePrivateKeyBytes.length); byte[] encapsKeyBytes = new byte[mlKem_k * encode12PolyLen + 32]; System.arraycopy(decapsKeyBytes, mlKem_k * encode12PolyLen, @@ -678,8 +416,8 @@ public final class ML_KEM { pkEncoded, (mlKem_k * ML_KEM_N * 12) / 8, rho.length); return new K_PKE_KeyPair( - new K_PKE_EncryptionKey(pkEncoded), - new K_PKE_DecryptionKey(skEncoded)); + new K_PKE_EncryptionKey(pkEncoded), + new K_PKE_DecryptionKey(skEncoded)); } private K_PKE_CipherText kPkeEncrypt( @@ -969,11 +707,9 @@ public final class ML_KEM { return vector; } - static void implMlKemNtt(short[] poly, short[] ntt_zetas) { - implMlKemNttJava(poly); - } - - private static void implMlKemNttJava(short[] poly) { + // The elements of poly should be in the range [-ML_KEM_Q, ML_KEM_Q] + // The elements of poly at return will be in the range of [0, ML_KEM_Q] + private void mlKemNTT(short[] poly) { int[] coeffs = new int[ML_KEM_N]; for (int m = 0; m < ML_KEM_N; m++) { coeffs[m] = poly[m]; @@ -982,20 +718,12 @@ public final class ML_KEM { for (int m = 0; m < ML_KEM_N; m++) { poly[m] = (short) coeffs[m]; } - } - - // The elements of poly should be in the range [-ML_KEM_Q, ML_KEM_Q] - // The elements of poly at return will be in the range of [0, ML_KEM_Q] - private void mlKemNTT(short[] poly) { - implMlKemNtt(poly, MONT_ZETAS_FOR_VECTOR_NTT_ARR); mlKemBarrettReduce(poly); } - static void implMlKemInverseNtt(short[] poly, short[] zetas) { - implMlKemInverseNttJava(poly); - } - - private static void implMlKemInverseNttJava(short[] poly) { + // Works in place, but also returns its (modified) input so that it can + // be used in expressions + private short[] mlKemInverseNTT(short[] poly) { int[] coeffs = new int[ML_KEM_N]; for (int m = 0; m < ML_KEM_N; m++) { coeffs[m] = poly[m]; @@ -1004,12 +732,6 @@ public final class ML_KEM { for (int m = 0; m < ML_KEM_N; m++) { poly[m] = (short) coeffs[m]; } - } - - // Works in place, but also returns its (modified) input so that it can - // be used in expressions - private short[] mlKemInverseNTT(short[] poly) { - implMlKemInverseNtt(poly, MONT_ZETAS_FOR_VECTOR_INVERSE_NTT_ARR); return poly; } @@ -1100,14 +822,10 @@ public final class ML_KEM { return result; } - static void implMlKemNttMult(short[] result, short[] ntta, short[] nttb, - short[] zetas) { - implMlKemNttMultJava(result, ntta, nttb); - } - - private static void implMlKemNttMultJava(short[] result, - short[] ntta, short[] nttb) { - + // Multiplies two polynomials represented in the NTT domain. + // The result is a representation of the product still in the NTT domain. + // The coefficients in the result are in the range (-ML_KEM_Q, ML_KEM_Q). + private void nttMult(short[] result, short[] ntta, short[] nttb) { for (int m = 0; m < ML_KEM_N / 2; m++) { int a0 = ntta[2 * m]; int a1 = ntta[2 * m + 1]; @@ -1121,13 +839,6 @@ public final class ML_KEM { } } - // Multiplies two polynomials represented in the NTT domain. - // The result is a representation of the product still in the NTT domain. - // The coefficients in the result are in the range (-ML_KEM_Q, ML_KEM_Q). - private void nttMult(short[] result, short[] ntta, short[] nttb) { - implMlKemNttMult(result, ntta, nttb, MONT_ZETAS_FOR_VECTOR_NTT_MULT_ARR); - } - // Adds the vector of polynomials b to a in place, i.e. a will hold // the result. It also returns (the modified) a so that it can be used // in an expression. @@ -1142,36 +853,15 @@ public final class ML_KEM { return a; } - static void implMlKemAddPoly(short[] result, short[] a, short[] b) { - implMlKemAddPolyJava(result, a, b); - } - - private static void implMlKemAddPolyJava(short[] result, short[] a, short[] b) { - for (int m = 0; m < ML_KEM_N; m++) { - int r = a[m] + b[m] + ML_KEM_Q; // This makes r > -ML_KEM_Q - result[m] = (short) r; - } - } - // Adds the polynomial b to a in place, i.e. (the modified) a will hold // the result. // The coefficients are supposed be greater than -ML_KEM_Q in a and // greater than -ML_KEM_Q and less than ML_KEM_Q in b. // The coefficients in the result are greater than -ML_KEM_Q. private void mlKemAddPoly(short[] a, short[] b) { - implMlKemAddPoly(a, a, b); - } - - static void implMlKemAddPoly(short[] result, short[] a, short[] b, short[] c) { - implMlKemAddPolyJava(result, a, b, c); - } - - private static void implMlKemAddPolyJava(short[] result, short[] a, - short[] b, short[] c) { - for (int m = 0; m < ML_KEM_N; m++) { - int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q - result[m] = (short) r; + int r = a[m] + b[m] + ML_KEM_Q; // This makes r > -ML_KEM_Q + a[m] = (short) r; } } @@ -1181,7 +871,10 @@ public final class ML_KEM { // greater than -ML_KEM_Q and less than ML_KEM_Q. // The coefficients in the result are nonnegative and less than ML_KEM_Q. private short[] mlKemAddPoly(short[] a, short[] b, short[] c) { - implMlKemAddPoly(a, a, b, c); + for (int m = 0; m < ML_KEM_N; m++) { + int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q + a[m] = (short) r; + } mlKemBarrettReduce(a); return a; } @@ -1304,23 +997,6 @@ public final class ML_KEM { return result; } - private static void implMlKem12To16(byte[] condensed, int index, - short[] parsed, int parsedLength) { - - implMlKem12To16Java(condensed, index, parsed, parsedLength); - } - - private static void implMlKem12To16Java(byte[] condensed, int index, - short[] parsed, int parsedLength) { - - for (int i = 0; i < parsedLength * 3 / 2; i += 3) { - parsed[(i / 3) * 2] = (short) ((condensed[i + index] & 0xff) + - 256 * (condensed[i + index + 1] & 0xf)); - parsed[(i / 3) * 2 + 1] = (short) (((condensed[i + index + 1] >>> 4) & 0xf) + - 16 * (condensed[i + index + 2] & 0xff)); - } - } - // The intrinsic implementations assume that the input and output buffers // are such that condensed can be read in 192-byte chunks and // parsed can be written in 128 shorts chunks. In other words, @@ -1330,7 +1006,12 @@ public final class ML_KEM { private void twelve2Sixteen(byte[] condensed, int index, short[] parsed, int parsedLength) { - implMlKem12To16(condensed, index, parsed, parsedLength); + for (int i = 0; i < parsedLength * 3 / 2; i += 3) { + parsed[(i / 3) * 2] = (short) ((condensed[i + index] & 0xff) + + 256 * (condensed[i + index + 1] & 0xf)); + parsed[(i / 3) * 2 + 1] = (short) (((condensed[i + index + 1] >>> 4) & 0xf) + + 16 * (condensed[i + index + 2] & 0xff)); + } } private static void decodePoly5(byte[] condensed, int index, short[] parsed) { @@ -1471,18 +1152,6 @@ public final class ML_KEM { return result; } - static void implMlKemBarrettReduce(short[] coeffs) { - implMlKemBarrettReduceJava(coeffs); - } - - private static void implMlKemBarrettReduceJava(short[] coeffs) { - for (int m = 0; m < ML_KEM_N; m++) { - int tmp = ((int) coeffs[m] * BARRETT_MULTIPLIER) >> - BARRETT_SHIFT; - coeffs[m] = (short) (coeffs[m] - tmp * ML_KEM_Q); - } - } - // The input elements can have any short value. // Modifies poly such that upon return poly[i] will be // in the range [0, ML_KEM_Q] and will be congruent with the original @@ -1493,7 +1162,10 @@ public final class ML_KEM { // will be in the range [0, ML_KEM_Q), i.e. it will be the canonical // representative of its residue class. private void mlKemBarrettReduce(short[] poly) { - implMlKemBarrettReduce(poly); + for (int m = 0; m < ML_KEM_N; m++) { + int tmp = ((int) poly[m] * BARRETT_MULTIPLIER) >> BARRETT_SHIFT; + poly[m] = (short) (poly[m] - tmp * ML_KEM_Q); + } } // Precondition: -(2^MONT_R_BITS -1) * MONT_Q <= b * c < (2^MONT_R_BITS - 1) * MONT_Q @@ -1503,8 +1175,8 @@ public final class ML_KEM { int a = b * c; int aHigh = a >> MONT_R_BITS; int aLow = a & ((1 << MONT_R_BITS) - 1); - int m = ((MONT_Q_INV_MOD_R * aLow) << (32 - MONT_R_BITS)) >> - (32 - MONT_R_BITS); // signed low product + // signed low product + int m = ((MONT_Q_INV_MOD_R * aLow) << (32 - MONT_R_BITS)) >> (32 - MONT_R_BITS); return (aHigh - ((m * MONT_Q) >> MONT_R_BITS)); // subtract signed high product } diff --git a/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM_Impls.java b/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM_Impls.java index f59883a410e..2ce5b3324e7 100644 --- a/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM_Impls.java +++ b/src/java.base/share/classes/com/sun/crypto/provider/ML_KEM_Impls.java @@ -37,19 +37,6 @@ import javax.crypto.DecapsulateException; public final class ML_KEM_Impls { - static int name2int(String name) { - if (name.endsWith("512")) { - return 512; - } else if (name.endsWith("768")) { - return 768; - } else if (name.endsWith("1024")) { - return 1024; - } else { - // should not happen - throw new ProviderException("Unknown name " + name); - } - } - public sealed static class KPG extends NamedKeyPairGenerator permits KPG2, KPG3, KPG5 { @@ -164,17 +151,8 @@ public final class ML_KEM_Impls { ML_KEM mlKem = new ML_KEM(name); var kpkeCipherText = new ML_KEM.K_PKE_CipherText(cipherText); - - byte[] decapsulateResult; - try { - decapsulateResult = mlKem.decapsulate( - new ML_KEM.ML_KEM_DecapsulationKey( - decapsulationKey), kpkeCipherText); - } catch (DecapsulateException e) { - throw new DecapsulateException("Decapsulate error", e) ; - } - - return decapsulateResult; + return mlKem.decapsulate(new ML_KEM.ML_KEM_DecapsulationKey( + decapsulationKey), kpkeCipherText); } @Override diff --git a/src/java.base/share/classes/sun/security/provider/ML_DSA.java b/src/java.base/share/classes/sun/security/provider/ML_DSA.java index 33ce8d5f52d..f3b72f53a32 100644 --- a/src/java.base/share/classes/sun/security/provider/ML_DSA.java +++ b/src/java.base/share/classes/sun/security/provider/ML_DSA.java @@ -25,7 +25,6 @@ package sun.security.provider; -import jdk.internal.vm.annotation.IntrinsicCandidate; import sun.security.provider.SHA3.SHAKE128; import sun.security.provider.SHA3.SHAKE256; @@ -527,13 +526,13 @@ public class ML_DSA { int tOffset = j*4; int vOffset = (i*320) + (j*5); t1[i][tOffset] = (v[vOffset] & 0xFF) + - ((v[vOffset+1] << 8) & 0x3FF); + ((v[vOffset+1] << 8) & 0x3FF); t1[i][tOffset+1] = ((v[vOffset+1] >> 2) & 0x3F) + - ((v[vOffset+2] << 6) & 0x3FF); + ((v[vOffset+2] << 6) & 0x3FF); t1[i][tOffset+2] = ((v[vOffset+2] >> 4) & 0xF) + - ((v[vOffset+3] << 4) & 0x3FF); + ((v[vOffset+3] << 4) & 0x3FF); t1[i][tOffset+3] = ((v[vOffset+3] >> 6) & 0x3) + - ((v[vOffset+4] << 2) & 0x3FF); + ((v[vOffset+4] << 2) & 0x3FF); } } return t1; @@ -875,8 +874,8 @@ public class ML_DSA { rawOfs = 0; } tmp = (rawAij[rawOfs] & 0xFF) + - ((rawAij[rawOfs + 1] & 0xFF) << 8) + - ((rawAij[rawOfs + 2] & 0x7F) << 16); + ((rawAij[rawOfs + 1] & 0xFF) << 8) + + ((rawAij[rawOfs + 2] & 0x7F) << 16); rawOfs += 3; if (tmp < ML_DSA_Q) { aij[ofs] = tmp; @@ -981,7 +980,7 @@ public class ML_DSA { int multiplier = (gamma2 == 95232 ? 22 : 8); for (int i = 0; i < mlDsa_k; i++) { ML_DSA.mlDsaDecomposePoly(input[i], lowPart[i], - highPart[i], gamma2 * 2, multiplier); + highPart[i], gamma2 * 2, multiplier); } } @@ -1032,12 +1031,6 @@ public class ML_DSA { */ public static int[] mlDsaNtt(int[] coeffs) { - implMlDsaAlmostNttJava(coeffs); - implMlDsaMontMulByConstantJava(coeffs, MONT_R_MOD_Q); - return coeffs; - } - - static void implMlDsaAlmostNttJava(int[] coeffs) { int dimension = ML_DSA_N; int m = 0; for (int l = dimension / 2; l > 0; l /= 2) { @@ -1050,15 +1043,11 @@ public class ML_DSA { m++; } } - } - - public static int[] mlDsaInverseNtt(int[] coeffs) { - implMlDsaAlmostInverseNttJava(coeffs); - implMlDsaMontMulByConstantJava(coeffs, MONT_DIM_INVERSE); + montMulByConstant(coeffs, MONT_R_MOD_Q); return coeffs; } - static void implMlDsaAlmostInverseNttJava(int[] coeffs) { + public static int[] mlDsaInverseNtt(int[] coeffs) { int dimension = ML_DSA_N; int m = 0; for (int l = 1; l < dimension; l *= 2) { @@ -1067,11 +1056,13 @@ public class ML_DSA { int tmp = coeffs[j]; coeffs[j] = (tmp + coeffs[j + l]); coeffs[j + l] = montMul(tmp - coeffs[j + l], - MONT_ZETAS_FOR_INVERSE_NTT[m]); + MONT_ZETAS_FOR_INVERSE_NTT[m]); } m++; } } + montMulByConstant(coeffs, MONT_DIM_INVERSE); + return coeffs; } void mlDsaVectorNtt(int[][] vector) { @@ -1086,22 +1077,13 @@ public class ML_DSA { } } - //Todo - public static void mlDsaNttMultiply(int[] res, int[] coeffs1, int[] coeffs2) { - implMlDsaNttMultJava(res, coeffs1, coeffs2); - } - - static void implMlDsaNttMultJava(int[] product, int[] coeffs1, int[] coeffs2) { + public static void mlDsaNttMultiply(int[] product, int[] coeffs1, int[] coeffs2) { for (int i = 0; i < ML_DSA_N; i++) { product[i] = montMul(coeffs1[i], toMont(coeffs2[i])); } } public static void montMulByConstant(int[] coeffs, int constant) { - implMlDsaMontMulByConstantJava(coeffs, constant); - } - - static void implMlDsaMontMulByConstantJava(int[] coeffs, int constant) { for (int i = 0; i < ML_DSA_N; i++) { coeffs[i] = montMul((coeffs[i]), constant); } @@ -1109,17 +1091,6 @@ public class ML_DSA { public static void mlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart, int twoGamma2, int multiplier) { - implMlDsaDecomposePoly(input, lowPart, highPart, twoGamma2, multiplier); - } - - @IntrinsicCandidate - static void implMlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart, - int twoGamma2, int multiplier) { - decomposePolyJava(input, lowPart, highPart, twoGamma2, multiplier); - } - - static void decomposePolyJava(int[] input, int[] lowPart, int[] highPart, - int twoGamma2, int multiplier) { for (int m = 0; m < ML_DSA_N; m++) { int rplus = input[m]; rplus = rplus - ((rplus + 5373807) >> 23) * ML_DSA_Q;