8271745: Correct block size for KW,KWP mode and use fixed IV for KWP mode for SunJCE

Reviewed-by: xuelei, mullan
This commit is contained in:
Valerie Peng 2021-09-01 22:17:49 +00:00
parent 2f01a6f8b6
commit 1a5a2b6b15
3 changed files with 194 additions and 131 deletions

View file

@ -26,6 +26,7 @@
package com.sun.crypto.provider; package com.sun.crypto.provider;
import java.util.Arrays; import java.util.Arrays;
import java.util.HexFormat;
import java.security.*; import java.security.*;
import java.security.spec.*; import java.security.spec.*;
import javax.crypto.*; import javax.crypto.*;
@ -132,12 +133,14 @@ class AESKeyWrapPadded extends FeedbackCipher {
if (key == null) { if (key == null) {
throw new InvalidKeyException("Invalid null key"); throw new InvalidKeyException("Invalid null key");
} }
if (iv != null && iv.length != ICV2.length) { // allow setting an iv but if non-null, must equal to ICV2
throw new InvalidAlgorithmParameterException("Invalid IV length"); if (iv != null && !Arrays.equals(iv, ICV2)) {
HexFormat hf = HexFormat.of().withUpperCase();
throw new InvalidAlgorithmParameterException("Invalid IV, got 0x" +
hf.formatHex(iv) + " instead of 0x" + hf.formatHex(ICV2));
} }
embeddedCipher.init(decrypting, algorithm, key); embeddedCipher.init(decrypting, algorithm, key);
// iv is retrieved from IvParameterSpec.getIV() which is already cloned this.iv = ICV2;
this.iv = (iv == null? ICV2 : iv);
} }
/** /**

View file

@ -70,28 +70,28 @@ abstract class KeyWrapCipher extends CipherSpi {
// for AES/KW/NoPadding // for AES/KW/NoPadding
public static final class AES_KW_PKCS5Padding extends KeyWrapCipher { public static final class AES_KW_PKCS5Padding extends KeyWrapCipher {
public AES_KW_PKCS5Padding() { public AES_KW_PKCS5Padding() {
super(new AESKeyWrap(), new PKCS5Padding(16), -1); super(new AESKeyWrap(), new PKCS5Padding(8), -1);
} }
} }
// for AES_128/KW/NoPadding // for AES_128/KW/NoPadding
public static final class AES128_KW_PKCS5Padding extends KeyWrapCipher { public static final class AES128_KW_PKCS5Padding extends KeyWrapCipher {
public AES128_KW_PKCS5Padding() { public AES128_KW_PKCS5Padding() {
super(new AESKeyWrap(), new PKCS5Padding(16), 16); super(new AESKeyWrap(), new PKCS5Padding(8), 16);
} }
} }
// for AES_192/KW/NoPadding // for AES_192/KW/NoPadding
public static final class AES192_KW_PKCS5Padding extends KeyWrapCipher { public static final class AES192_KW_PKCS5Padding extends KeyWrapCipher {
public AES192_KW_PKCS5Padding() { public AES192_KW_PKCS5Padding() {
super(new AESKeyWrap(), new PKCS5Padding(16), 24); super(new AESKeyWrap(), new PKCS5Padding(8), 24);
} }
} }
// for AES_256/KW/NoPadding // for AES_256/KW/NoPadding
public static final class AES256_KW_PKCS5Padding extends KeyWrapCipher { public static final class AES256_KW_PKCS5Padding extends KeyWrapCipher {
public AES256_KW_PKCS5Padding() { public AES256_KW_PKCS5Padding() {
super(new AESKeyWrap(), new PKCS5Padding(16), 32); super(new AESKeyWrap(), new PKCS5Padding(8), 32);
} }
} }
@ -230,13 +230,11 @@ abstract class KeyWrapCipher extends CipherSpi {
} }
/** /**
* Returns the block size (in bytes). i.e. 16 bytes. * @return the block size (in bytes)
*
* @return the block size (in bytes), i.e. 16 bytes.
*/ */
@Override @Override
protected int engineGetBlockSize() { protected int engineGetBlockSize() {
return cipher.getBlockSize(); return 8;
} }
/** /**

View file

@ -23,188 +23,250 @@
/* /*
* @test * @test
* @bug 8248268 8268621 * @bug 8248268 8268621 8271745
* @summary Verify general properties of the AES/KW/NoPadding, * @summary Verify general properties of the AES/KW/NoPadding,
* AES/KW/PKCS5Padding, and AES/KWP/NoPadding. * AES/KW/PKCS5Padding, and AES/KWP/NoPadding impls of SunJCE provider.
* @run main TestGeneral * @run main TestGeneral
*/ */
import java.nio.ByteBuffer;
import java.util.Arrays; import java.util.Arrays;
import java.security.Key; import java.security.Key;
import java.security.KeyPairGenerator;
import java.security.PrivateKey;
import java.security.InvalidAlgorithmParameterException; import java.security.InvalidAlgorithmParameterException;
import java.security.AlgorithmParameters;
import javax.crypto.*; import javax.crypto.*;
import javax.crypto.spec.*; import javax.crypto.spec.*;
public class TestGeneral { public class TestGeneral {
private static final byte[] DATA_128 = private static final byte[] DATA_32 =
Arrays.copyOf("1234567890123456789012345678901234".getBytes(), 128); Arrays.copyOf("1234567890123456789012345678901234".getBytes(), 32);
private static final SecretKey KEY = private static final SecretKey KEY =
new SecretKeySpec(DATA_128, 0, 16, "AES"); new SecretKeySpec(DATA_32, 0, 16, "AES");
private static final int KW_IV_LEN = 8; private static final int KW_IV_LEN = 8;
private static final int KWP_IV_LEN = 4; private static final int KWP_IV_LEN = 4;
private static final int MAX_KW_PKCS5PAD_LEN = 16; // 1-16 private static final int MAX_KW_PKCS5PAD_LEN = 8; // 1-8
private static final int MAX_KWP_PAD_LEN = 7; // 0...7 private static final int MAX_KWP_PAD_LEN = 7; // 0-7
public static void testEnc(Cipher c, byte[] in, int inLen, int ivLen, public static void testEnc(Cipher c, byte[] in, int startLen, int inc,
int maxPadLen) throws Exception { IvParameterSpec[] ivs, int maxPadLen) throws Exception {
System.out.println("input len: " + inLen); System.out.println("testEnc, input len=" + startLen + " w/ inc=" +
c.init(Cipher.ENCRYPT_MODE, KEY, new IvParameterSpec(in, 0, ivLen)); inc);
int estOutLen = c.getOutputSize(inLen); for (IvParameterSpec iv : ivs) {
System.out.print("\t=> w/ iv=" + iv);
byte[] out = c.doFinal(in, 0, inLen); for (int inLen = startLen; inLen < in.length; inLen+=inc) {
c.init(Cipher.ENCRYPT_MODE, KEY, iv);
// for encryption output, the estimate should match the actual int estOutLen = c.getOutputSize(inLen);
if (estOutLen != out.length) { System.out.println(", inLen=" + inLen);
System.out.println("=> estimated: " + estOutLen); byte[] out = c.doFinal(in, 0, inLen);
System.out.println("=> actual enc out length: " + out.length);
throw new RuntimeException("Failed enc output len check");
}
// encryption outout should always be multiple of 8 and at least 8-byte // check the length of encryption output
// longer than input if (estOutLen != out.length || (out.length % 8 != 0) ||
if ((out.length % 8 != 0) || (out.length - inLen < 8)) { (out.length - inLen < 8)) {
throw new RuntimeException("Invalid length of encrypted data: " + System.out.println("=> estimated: " + estOutLen);
out.length); System.out.println("=> actual: " + out.length);
} throw new RuntimeException("Failed enc output len check");
}
c.init(Cipher.DECRYPT_MODE, KEY, new IvParameterSpec(in, 0, ivLen)); c.init(Cipher.DECRYPT_MODE, KEY, iv);
estOutLen = c.getOutputSize(out.length); estOutLen = c.getOutputSize(out.length);
byte[] recovered = new byte[estOutLen];
byte[] in2 = c.doFinal(out); // do decryption using ByteBuffer and multi-part
ByteBuffer outBB = ByteBuffer.wrap(out);
ByteBuffer recoveredBB = ByteBuffer.wrap(recovered);
int len = c.update(outBB, recoveredBB);
len += c.doFinal(outBB, recoveredBB);
// for decryption output, the estimate should match the actual for // check the length of decryption output
// AES/KW/NoPadding and slightly larger than the actual for the rest if (estOutLen < len || (estOutLen - len) > maxPadLen) {
if (estOutLen < in2.length || (estOutLen - in2.length) > maxPadLen) { System.out.println("=> estimated: " + estOutLen);
System.out.println("=> estimated: " + estOutLen); System.out.println("=> actual: " + len);
System.out.println("=> actual dec out length: " + in2.length); throw new RuntimeException("Failed dec output len check");
throw new RuntimeException("Failed dec output len check"); }
}
if (!Arrays.equals(in, 0, inLen, in2, 0, inLen)) { if (!Arrays.equals(in, 0, inLen, recovered, 0, len)) {
throw new RuntimeException("Failed decrypted data check"); throw new RuntimeException("Failed decrypted data check");
}
}
} }
} }
public static void testWrap(Cipher c, byte[] in, int inLen, int ivLen, public static void testWrap(Cipher c, Key[] inKeys, IvParameterSpec[] ivs,
int maxPadLen) throws Exception { int maxPadLen) throws Exception {
System.out.println("key len: " + inLen); for (Key inKey : inKeys) {
c.init(Cipher.WRAP_MODE, KEY, new IvParameterSpec(in, 0, ivLen)); System.out.println("testWrap, key: " + inKey);
for (IvParameterSpec iv : ivs) {
System.out.println("\t=> w/ iv " + iv);
int estOutLen = c.getOutputSize(inLen); c.init(Cipher.WRAP_MODE, KEY, iv);
byte[] out = c.wrap(new SecretKeySpec(in, 0, inLen, "Any")); byte[] out = c.wrap(inKey);
// for encryption output, the estimate should match the actual // output should always be multiple of cipher block size
if (estOutLen != out.length) { if (out.length % c.getBlockSize() != 0) {
System.out.println("=> estimated: " + estOutLen); throw new RuntimeException("Invalid wrap len: " +
System.out.println("=> actual wrap out length: " + out.length); out.length);
throw new RuntimeException("Failed wrap output len check"); }
}
// encryption outout should always be multiple of 8 and at least 8-byte c.init(Cipher.UNWRAP_MODE, KEY, iv);
// longer than input
if ((out.length % 8 != 0) || (out.length - inLen < 8)) {
throw new RuntimeException("Invalid length of encrypted data: " +
out.length);
}
c.init(Cipher.UNWRAP_MODE, KEY, new IvParameterSpec(in, 0, ivLen));
estOutLen = c.getOutputSize(out.length);
Key key2 = c.unwrap(out, "Any", Cipher.SECRET_KEY); // SecretKey or PrivateKey
int keyType = (inKey instanceof SecretKey? Cipher.SECRET_KEY :
Cipher.PRIVATE_KEY);
if (!(key2 instanceof SecretKey)) { int estOutLen = c.getOutputSize(out.length);
throw new RuntimeException("Failed unwrap output type check"); Key key2 = c.unwrap(out, inKey.getAlgorithm(), keyType);
}
byte[] in2 = key2.getEncoded(); if ((keyType == Cipher.SECRET_KEY &&
// for decryption output, the estimate should match the actual for !(key2 instanceof SecretKey)) ||
// AES/KW/NoPadding and slightly larger than the actual for the rest (keyType == Cipher.PRIVATE_KEY &&
if (estOutLen < in2.length || (estOutLen - in2.length) > maxPadLen) { !(key2 instanceof PrivateKey))) {
System.out.println("=> estimated: " + estOutLen); throw new RuntimeException("Failed unwrap type check");
System.out.println("=> actual unwrap out length: " + in2.length); }
throw new RuntimeException("Failed unwrap output len check");
}
if (inLen != in2.length || byte[] in2 = key2.getEncoded();
!Arrays.equals(in, 0, inLen, in2, 0, inLen)) { // check decryption output length
throw new RuntimeException("Failed unwrap data check"); if (estOutLen < in2.length ||
(estOutLen - in2.length) > maxPadLen) {
System.out.println("=> estimated: " + estOutLen);
System.out.println("=> actual: " + in2.length);
throw new RuntimeException("Failed unwrap len check");
}
if (!Arrays.equals(inKey.getEncoded(), in2) ||
!(inKey.getAlgorithm().equalsIgnoreCase
(key2.getAlgorithm()))) {
throw new RuntimeException("Failed unwrap key check");
}
}
} }
} }
public static void testIv(Cipher c) throws Exception { public static void testIv(Cipher c, int defIvLen, boolean allowCustomIv)
throws Exception {
System.out.println("testIv: defIvLen = " + defIvLen +
" allowCustomIv = " + allowCustomIv);
// get a fresh Cipher instance so we can test iv with pre-init state // get a fresh Cipher instance so we can test iv with pre-init state
Cipher c2 = Cipher.getInstance(c.getAlgorithm(), c.getProvider()); c = Cipher.getInstance(c.getAlgorithm(), c.getProvider());
if (c2.getIV() != null) { if (c.getIV() != null) {
throw new RuntimeException("Expects null iv"); throw new RuntimeException("Expects null iv");
} }
if (c2.getParameters() == null) {
AlgorithmParameters ivParams = c.getParameters();
if (ivParams == null) {
throw new RuntimeException("Expects non-null default parameters"); throw new RuntimeException("Expects non-null default parameters");
} }
IvParameterSpec ivSpec =
c2.init(Cipher.ENCRYPT_MODE, KEY); ivParams.getParameterSpec(IvParameterSpec.class);
byte[] defIv2 = c2.getIV(); byte[] iv = ivSpec.getIV();
// try through all opmodes
c.init(Cipher.ENCRYPT_MODE, KEY); c.init(Cipher.ENCRYPT_MODE, KEY);
c.init(Cipher.DECRYPT_MODE, KEY);
c.init(Cipher.WRAP_MODE, KEY);
c.init(Cipher.UNWRAP_MODE, KEY);
byte[] defIv = c.getIV(); byte[] defIv = c.getIV();
if (!Arrays.equals(defIv, defIv2)) {
// try again through all opmodes
c.init(Cipher.ENCRYPT_MODE, KEY);
c.init(Cipher.DECRYPT_MODE, KEY);
c.init(Cipher.WRAP_MODE, KEY);
c.init(Cipher.UNWRAP_MODE, KEY);
byte[] defIv2 = c.getIV();
if (iv.length != defIvLen || !Arrays.equals(iv, defIv) ||
!Arrays.equals(defIv, defIv2)) {
throw new RuntimeException("Failed default iv check"); throw new RuntimeException("Failed default iv check");
} }
if (defIv == defIv2) {
throw new RuntimeException("Failed getIV copy check");
}
// try init w/ an iv w/ invalid length // try init w/ an iv w/ invalid length
try { try {
c.init(Cipher.ENCRYPT_MODE, KEY, new IvParameterSpec(defIv, 0, c.init(Cipher.ENCRYPT_MODE, KEY, new IvParameterSpec(defIv, 0,
defIv.length/2)); defIv.length/2));
throw new RuntimeException("Invalid iv accepted"); throw new RuntimeException("Invalid iv accepted");
} catch (InvalidAlgorithmParameterException iape) { } catch (InvalidAlgorithmParameterException iape) {
System.out.println("Invalid IV rejected as expected"); System.out.println("Invalid IV rejected as expected");
} }
Arrays.fill(defIv, (byte) 0xFF);
c.init(Cipher.ENCRYPT_MODE, KEY, new IvParameterSpec(defIv)); if (allowCustomIv) {
byte[] newIv = c.getIV(); Arrays.fill(defIv, (byte) 0xFF);
if (!Arrays.equals(newIv, defIv)) { // try through all opmodes
throw new RuntimeException("Failed set iv check"); c.init(Cipher.ENCRYPT_MODE, KEY, new IvParameterSpec(defIv));
} c.init(Cipher.DECRYPT_MODE, KEY, new IvParameterSpec(defIv));
byte[] newIv2 = c.getIV(); c.init(Cipher.WRAP_MODE, KEY, new IvParameterSpec(defIv));
if (newIv == newIv2) { c.init(Cipher.UNWRAP_MODE, KEY, new IvParameterSpec(defIv));
throw new RuntimeException("Failed getIV copy check");
if (!Arrays.equals(defIv, c.getIV())) {
throw new RuntimeException("Failed set iv check");
}
} }
} }
public static void main(String[] argv) throws Exception { public static void main(String[] argv) throws Exception {
byte[] data = DATA_128; byte[] data = DATA_32;
String ALGO = "AES/KW/PKCS5Padding"; SecretKey aes256 = new SecretKeySpec(DATA_32, "AES");
System.out.println("Testing " + ALGO); SecretKey any256 = new SecretKeySpec(DATA_32, "ANY");
Cipher c = Cipher.getInstance(ALGO, "SunJCE"); PrivateKey priv = KeyPairGenerator.getInstance
("RSA", "SunRsaSign").generateKeyPair().getPrivate();
// test all possible pad lengths, i.e. 1 - 16 String[] algos = {
for (int i = 1; i <= MAX_KW_PKCS5PAD_LEN; i++) { "AES/KW/PKCS5Padding", "AES/KW/NoPadding", "AES/KWP/NoPadding"
testEnc(c, data, data.length - i, KW_IV_LEN, MAX_KW_PKCS5PAD_LEN); };
testWrap(c, data, data.length - i, KW_IV_LEN, MAX_KW_PKCS5PAD_LEN);
for (String a : algos) {
System.out.println("Testing " + a);
Cipher c = Cipher.getInstance(a, "SunJCE");
int blkSize = c.getBlockSize();
// set the default based on AES/KWP/NoPadding, the other two
// override as needed
int startLen = data.length - blkSize;
int inc = 1;
IvParameterSpec[] ivs = new IvParameterSpec[] { null };
int padLen = MAX_KWP_PAD_LEN;
Key[] keys = new Key[] { aes256, any256, priv };
int ivLen = KWP_IV_LEN;
boolean allowCustomIv = false;
switch (a) {
case "AES/KW/PKCS5Padding":
ivs = new IvParameterSpec[] {
null, new IvParameterSpec(DATA_32, 0, KW_IV_LEN) };
padLen = MAX_KW_PKCS5PAD_LEN;
ivLen = KW_IV_LEN;
allowCustomIv = true;
break;
case "AES/KW/NoPadding":
startLen = data.length >> 1;
inc = blkSize;
ivs = new IvParameterSpec[] {
null, new IvParameterSpec(DATA_32, 0, KW_IV_LEN) };
padLen = 0;
keys = new Key[] { aes256, any256 };
ivLen = KW_IV_LEN;
allowCustomIv = true;
break;
}
// now test based on the configured arguments
testEnc(c, data, startLen, inc, ivs, padLen);
testWrap(c, keys, ivs, padLen);
testIv(c, ivLen, allowCustomIv);
} }
testIv(c);
ALGO = "AES/KW/NoPadding";
System.out.println("Testing " + ALGO);
c = Cipher.getInstance(ALGO, "SunJCE");
testEnc(c, data, data.length, KW_IV_LEN, 0);
testEnc(c, data, data.length >> 1, KW_IV_LEN, 0);
testWrap(c, data, data.length, KW_IV_LEN, 0);
testWrap(c, data, data.length >> 1, KW_IV_LEN, 0);
testIv(c);
ALGO = "AES/KWP/NoPadding";
System.out.println("Testing " + ALGO);
c = Cipher.getInstance(ALGO, "SunJCE");
// test all possible pad lengths, i.e. 0 - 7
for (int i = 0; i <= MAX_KWP_PAD_LEN; i++) {
testEnc(c, data, data.length - i, KWP_IV_LEN, MAX_KWP_PAD_LEN);
testWrap(c, data, data.length - i, KWP_IV_LEN, MAX_KWP_PAD_LEN);
}
testIv(c);
System.out.println("All Tests Passed"); System.out.println("All Tests Passed");
} }
} }