diff --git a/src/java.base/share/classes/sun/security/pkcs12/PKCS12KeyStore.java b/src/java.base/share/classes/sun/security/pkcs12/PKCS12KeyStore.java index 04e040b611b..52393f0466a 100644 --- a/src/java.base/share/classes/sun/security/pkcs12/PKCS12KeyStore.java +++ b/src/java.base/share/classes/sun/security/pkcs12/PKCS12KeyStore.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 1999, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1999, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -295,9 +295,13 @@ public final class PKCS12KeyStore extends KeyStoreSpi { * (e.g., the given password is wrong). */ public Key engineGetKey(String alias, char[] password) - throws NoSuchAlgorithmException, UnrecoverableKeyException - { + throws NoSuchAlgorithmException, UnrecoverableKeyException { Entry entry = entries.get(alias.toLowerCase(Locale.ENGLISH)); + return internalGetKey(entry, password); + } + + private Key internalGetKey(Entry entry, char[] password) + throws NoSuchAlgorithmException, UnrecoverableKeyException { Key key; if (!(entry instanceof KeyEntry)) { @@ -321,7 +325,7 @@ public final class PKCS12KeyStore extends KeyStoreSpi { try { // get the encrypted private key EncryptedPrivateKeyInfo encrInfo = - new EncryptedPrivateKeyInfo(encrBytes); + new EncryptedPrivateKeyInfo(encrBytes); encryptedKey = encrInfo.getEncryptedData(); // parse Algorithm parameters @@ -332,20 +336,20 @@ public final class PKCS12KeyStore extends KeyStoreSpi { } catch (IOException ioe) { UnrecoverableKeyException uke = - new UnrecoverableKeyException("Private key not stored as " - + "PKCS#8 EncryptedPrivateKeyInfo: " + ioe); + new UnrecoverableKeyException("Private key not stored as " + + "PKCS#8 EncryptedPrivateKeyInfo: " + ioe); uke.initCause(ioe); throw uke; } - try { + try { PBEParameterSpec pbeSpec; int ic; if (algParams != null) { try { pbeSpec = - algParams.getParameterSpec(PBEParameterSpec.class); + algParams.getParameterSpec(PBEParameterSpec.class); } catch (InvalidParameterSpecException ipse) { throw new IOException("Invalid PBE algorithm parameters"); } @@ -392,7 +396,7 @@ public final class PKCS12KeyStore extends KeyStoreSpi { if (debug != null) { debug.println("Retrieved a protected private key at alias" + - " '" + alias + "' (" + + " '" + entry.alias + "' (" + aid.getName() + " iterations: " + ic + ")"); } @@ -433,7 +437,7 @@ public final class PKCS12KeyStore extends KeyStoreSpi { if (debug != null) { debug.println("Retrieved a protected secret key at alias " + - "'" + alias + "' (" + + "'" + entry.alias + "' (" + aid.getName() + " iterations: " + ic + ")"); } @@ -450,8 +454,8 @@ public final class PKCS12KeyStore extends KeyStoreSpi { } catch (Exception e) { UnrecoverableKeyException uke = - new UnrecoverableKeyException("Get Key failed: " + - e.getMessage()); + new UnrecoverableKeyException("Get Key failed: " + + e.getMessage()); uke.initCause(e); throw uke; } @@ -471,6 +475,10 @@ public final class PKCS12KeyStore extends KeyStoreSpi { */ public Certificate[] engineGetCertificateChain(String alias) { Entry entry = entries.get(alias.toLowerCase(Locale.ENGLISH)); + return internalGetCertificateChain(entry); + } + + private Certificate[] internalGetCertificateChain(Entry entry) { if (entry instanceof PrivateKeyEntry privateKeyEntry) { if (privateKeyEntry.chain == null) { return null; @@ -478,8 +486,8 @@ public final class PKCS12KeyStore extends KeyStoreSpi { if (debug != null) { debug.println("Retrieved a " + - privateKeyEntry.chain.length + - "-certificate chain at alias '" + alias + "'"); + privateKeyEntry.chain.length + + "-certificate chain at alias '" + entry.alias + "'"); } return privateKeyEntry.chain.clone(); @@ -1013,18 +1021,19 @@ public final class PKCS12KeyStore extends KeyStoreSpi { debug.println("Removing entry at alias '" + alias + "'"); } - Entry entry = entries.get(alias.toLowerCase(Locale.ENGLISH)); - if (entry instanceof PrivateKeyEntry keyEntry) { - if (keyEntry.chain != null) { - certificateCount -= keyEntry.chain.length; + Entry entry = entries.remove(alias.toLowerCase(Locale.ENGLISH)); + if (entry != null) { + if (entry instanceof PrivateKeyEntry keyEntry) { + if (keyEntry.chain != null) { + certificateCount -= keyEntry.chain.length; + } + privateKeyCount--; + } else if (entry instanceof CertEntry) { + certificateCount--; + } else if (entry instanceof SecretKeyEntry) { + secretKeyCount--; } - privateKeyCount--; - } else if (entry instanceof CertEntry) { - certificateCount--; - } else if (entry instanceof SecretKeyEntry) { - secretKeyCount--; } - entries.remove(alias.toLowerCase(Locale.ENGLISH)); } /** @@ -1065,6 +1074,10 @@ public final class PKCS12KeyStore extends KeyStoreSpi { */ public boolean engineIsKeyEntry(String alias) { Entry entry = entries.get(alias.toLowerCase(Locale.ENGLISH)); + return internalIsKeyEntry(entry); + } + + private boolean internalIsKeyEntry(Entry entry) { return entry instanceof KeyEntry; } @@ -1075,8 +1088,13 @@ public final class PKCS12KeyStore extends KeyStoreSpi { * @return true if the entry identified by the given alias is a * trusted certificate entry, false otherwise. */ + public boolean engineIsCertificateEntry(String alias) { Entry entry = entries.get(alias.toLowerCase(Locale.ENGLISH)); + return internalIsCertificateEntry(entry); + } + + private boolean internalIsCertificateEntry(Entry entry) { return entry instanceof CertEntry certEntry && certEntry.trustedKeyUsage != null; } @@ -1306,18 +1324,14 @@ public final class PKCS12KeyStore extends KeyStoreSpi { Entry entry = entries.get(alias.toLowerCase(Locale.ENGLISH)); if (protParam == null) { - if (engineIsCertificateEntry(alias)) { - if (entry instanceof CertEntry && - ((CertEntry) entry).trustedKeyUsage != null) { - - if (debug != null) { - debug.println("Retrieved a trusted certificate at " + + if (internalIsCertificateEntry(entry)) { + if (debug != null) { + debug.println("Retrieved a trusted certificate at " + "alias '" + alias + "'"); - } - - return new KeyStore.TrustedCertificateEntry( - ((CertEntry)entry).cert, entry.attributes); } + + return new KeyStore.TrustedCertificateEntry( + ((CertEntry)entry).cert, entry.attributes); } else { throw new UnrecoverableKeyException ("requested entry requires a password"); @@ -1325,17 +1339,17 @@ public final class PKCS12KeyStore extends KeyStoreSpi { } if (protParam instanceof KeyStore.PasswordProtection) { - if (engineIsCertificateEntry(alias)) { + if (internalIsCertificateEntry(entry)) { throw new UnsupportedOperationException ("trusted certificate entries are not password-protected"); - } else if (engineIsKeyEntry(alias)) { + } else if (internalIsKeyEntry(entry)) { KeyStore.PasswordProtection pp = (KeyStore.PasswordProtection)protParam; char[] password = pp.getPassword(); - Key key = engineGetKey(alias, password); + Key key = internalGetKey(entry, password); if (key instanceof PrivateKey) { - Certificate[] chain = engineGetCertificateChain(alias); + Certificate[] chain = internalGetCertificateChain(entry); return new KeyStore.PrivateKeyEntry((PrivateKey)key, chain, entry.attributes); @@ -1345,7 +1359,7 @@ public final class PKCS12KeyStore extends KeyStoreSpi { return new KeyStore.SecretKeyEntry((SecretKey)key, entry.attributes); } - } else if (!engineIsKeyEntry(alias)) { + } else { throw new UnsupportedOperationException ("untrusted certificate entries are not " + "password-protected"); diff --git a/test/jdk/sun/security/pkcs12/GetSetEntryTest.java b/test/jdk/sun/security/pkcs12/GetSetEntryTest.java new file mode 100644 index 00000000000..2bfb92aae2f --- /dev/null +++ b/test/jdk/sun/security/pkcs12/GetSetEntryTest.java @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8327461 + * @summary engineGetEntry in PKCS12KeyStore should be thread-safe + * @library /test/lib ../../../java/security/testlibrary + * @modules java.base/sun.security.x509 + * java.base/sun.security.util + * @build CertificateBuilder + * @run main GetSetEntryTest + */ + +import java.math.BigInteger; +import java.security.cert.X509Certificate; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.KeyStore; +import java.security.spec.ECGenParameterSpec; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeUnit; +import java.util.Date; + +import sun.security.testlibrary.CertificateBuilder; + +public class GetSetEntryTest { + + public static final String TEST = "test"; + + public static void main(String[] args) throws Exception { + KeyStore ks = KeyStore.getInstance("PKCS12"); + char[] password = "password".toCharArray(); + KeyStore.PasswordProtection protParam = new KeyStore.PasswordProtection(password); + ks.load(null, null); + + CertificateBuilder cbld = new CertificateBuilder(); + KeyPairGenerator keyPairGen1 = KeyPairGenerator.getInstance("EC"); + keyPairGen1.initialize(new ECGenParameterSpec("secp256r1")); + KeyPair ecKeyPair = keyPairGen1.genKeyPair(); + + long start = System.currentTimeMillis() - TimeUnit.DAYS.toMillis(60); + long end = start + TimeUnit.DAYS.toMillis(1085); + boolean[] kuBitSettings = {true, false, false, false, false, true, + true, false, false}; + + // Set up the EC Cert + cbld.setSubjectName("CN=EC Test Cert, O=SomeCompany"). + setPublicKey(ecKeyPair.getPublic()). + setSerialNumber(new BigInteger("1")). + setValidity(new Date(start), new Date(end)). + addSubjectKeyIdExt(ecKeyPair.getPublic()). + addAuthorityKeyIdExt(ecKeyPair.getPublic()). + addBasicConstraintsExt(true, true, -1). + addKeyUsageExt(kuBitSettings); + + X509Certificate ecCert = cbld.build(null, ecKeyPair.getPrivate(), "SHA256withECDSA"); + + KeyPairGenerator keyPairGen2 = KeyPairGenerator.getInstance("RSA"); + keyPairGen2.initialize(4096); + KeyPair rsaKeyPair = keyPairGen2.genKeyPair(); + + cbld.reset(); + // Set up the RSA Cert + cbld.setSubjectName("CN=RSA Test Cert, O=SomeCompany"). + setPublicKey(rsaKeyPair.getPublic()). + setSerialNumber(new BigInteger("1")). + setValidity(new Date(start), new Date(end)). + addSubjectKeyIdExt(rsaKeyPair.getPublic()). + addAuthorityKeyIdExt(rsaKeyPair.getPublic()). + addBasicConstraintsExt(true, true, -1). + addKeyUsageExt(kuBitSettings); + + X509Certificate rsaCert = cbld.build(null, rsaKeyPair.getPrivate(), "SHA256withRSA"); + + KeyStore.PrivateKeyEntry ecEntry = new KeyStore.PrivateKeyEntry(ecKeyPair.getPrivate(), + new X509Certificate[]{ecCert}); + KeyStore.PrivateKeyEntry rsaEntry = new KeyStore.PrivateKeyEntry(rsaKeyPair.getPrivate(), + new X509Certificate[]{rsaCert}); + + test(ks, ecEntry, rsaEntry, protParam); + } + + private static final int MAX_ITERATIONS = 100; + + private static void test(KeyStore ks, KeyStore.PrivateKeyEntry ec, + KeyStore.PrivateKeyEntry rsa, + KeyStore.PasswordProtection protParam) + throws Exception { + ks.setEntry(TEST, ec, protParam); + + AtomicBoolean syncIssue = new AtomicBoolean(false); + + Thread thread = new Thread(() -> { + int iterations = 0; + while (!syncIssue.get() && iterations < MAX_ITERATIONS) { + try { + ks.setEntry(TEST, ec, protParam); + ks.setEntry(TEST, rsa, protParam); + } catch (Exception ex) { + syncIssue.set(true); + ex.printStackTrace(); + System.out.println("Test failed"); + System.exit(1); + } + iterations++; + } + }); + thread.start(); + + int iterations = 0; + while (!syncIssue.get() && iterations < MAX_ITERATIONS) { + try { + ks.getEntry(TEST, protParam); + } catch (Exception ex) { + syncIssue.set(true); + ex.printStackTrace(); + System.out.println("Test failed"); + System.exit(1); + } + iterations++; + } + + thread.join(); + + if (!syncIssue.get()) { + System.out.println("Test completed successfully"); + } + } +}