8328608: Multiple NewSessionTicket support for TLS

Reviewed-by: djelinski
This commit is contained in:
Anthony Scarpino 2024-08-28 17:24:33 +00:00
parent 379f3db001
commit 0c2b175898
16 changed files with 1161 additions and 300 deletions

View file

@ -1139,7 +1139,9 @@ final class Finished {
//
// produce
NewSessionTicket.t13PosthandshakeProducer.produce(shc);
if (SSLConfiguration.serverNewSessionTicketCount > 0) {
NewSessionTicket.t13PosthandshakeProducer.produce(shc);
}
}
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -25,11 +25,11 @@
package sun.security.ssl;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Locale;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLHandshakeException;
@ -118,11 +118,6 @@ final class NewSessionTicket {
this.ticket = Record.getBytes16(m);
}
@Override
public SSLHandshake handshakeType() {
return NEW_SESSION_TICKET;
}
@Override
public int messageLength() {
return 4 + // ticketLifetime
@ -221,11 +216,6 @@ final class NewSessionTicket {
this.extensions = new SSLExtensions(this, m, supportedExtensions);
}
@Override
public SSLHandshake handshakeType() {
return NEW_SESSION_TICKET;
}
int getTicketAgeAdd() {
return ticketAgeAdd;
}
@ -301,7 +291,7 @@ final class NewSessionTicket {
"tls13 resumption".getBytes(), nonce, hashAlg.hashLength);
return hkdf.expand(resumptionMasterSecret, hkdfInfo,
hashAlg.hashLength, "TlsPreSharedKey");
} catch (GeneralSecurityException gse) {
} catch (GeneralSecurityException gse) {
throw new SSLHandshakeException("Could not derive PSK", gse);
}
}
@ -332,8 +322,7 @@ final class NewSessionTicket {
// Is this session resumable?
if (!hc.handshakeSession.isRejoinable()) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"No session ticket produced: " +
SSLLogger.fine("No session ticket produced: " +
"session is not resumable");
}
@ -351,8 +340,7 @@ final class NewSessionTicket {
if (pkemSpec == null ||
!pkemSpec.contains(PskKeyExchangeMode.PSK_DHE_KE)) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"No session ticket produced: " +
SSLLogger.fine("No session ticket produced: " +
"client does not support psk_dhe_ke");
}
@ -363,8 +351,7 @@ final class NewSessionTicket {
// using an allowable PSK exchange key mode.
if (!hc.handshakeSession.isPSKable()) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"No session ticket produced: " +
SSLLogger.fine("No session ticket produced: " +
"No session ticket allowed in this session");
}
@ -375,76 +362,113 @@ final class NewSessionTicket {
// get a new session ID
SSLSessionContextImpl sessionCache = (SSLSessionContextImpl)
hc.sslContext.engineGetServerSessionContext();
SessionId newId = new SessionId(true,
hc.sslContext.getSecureRandom());
SecretKey resumptionMasterSecret =
hc.handshakeSession.getResumptionMasterSecret();
if (resumptionMasterSecret == null) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"No session ticket produced: " +
"no resumption secret");
}
return null;
}
// construct the PSK and handshake message
BigInteger nonce = hc.handshakeSession.incrTicketNonceCounter();
byte[] nonceArr = nonce.toByteArray();
SecretKey psk = derivePreSharedKey(
hc.negotiatedCipherSuite.hashAlg,
resumptionMasterSecret, nonceArr);
int sessionTimeoutSeconds = sessionCache.getSessionTimeout();
if (sessionTimeoutSeconds > MAX_TICKET_LIFETIME) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"No session ticket produced: " +
"session timeout");
SSLLogger.fine("No session ticket produced: " +
"session timeout is too long");
}
return null;
}
NewSessionTicketMessage nstm = null;
// Send NewSessionTickets to the client based
if (SSLConfiguration.serverNewSessionTicketCount > 0) {
int i = 0;
NewSessionTicketMessage nstm;
while (i < SSLConfiguration.serverNewSessionTicketCount) {
nstm = generateNST(hc, sessionCache);
if (nstm == null) {
break;
}
nstm.write(hc.handshakeOutput);
i++;
}
hc.handshakeOutput.flush();
}
/*
* With large NST counts, a client that quickly closes after
* TLS Finished completes can cause SocketExceptions such as:
* Windows servers read-side throwing SocketException:
* "An established connection was aborted by the software in
* your host machine", which relates to error WSAECONNABORTED.
* A SocketException caused by a "broken pipe" has been observed on
* other systems.
* These are very unlikely situations when client and server are on
* different machines.
*
* RFC 8446 does not put requirements when an NST needs to be
* sent, but it should be sent very soon after TLS Finished for
* clients that will quickly resume to create more sessions.
* TLS 1.3 is different from TLS 1.2, there is more data the client
* should be aware of
*/
// See note on TransportContext.needHandshakeFinishedStatus.
//
// Reset the needHandshakeFinishedStatus flag. The delivery
// of this post-handshake message will indicate the FINISHED
// handshake status. It is not needed to have a follow-on
// SSLEngine.wrap() any longer.
if (hc.conContext.needHandshakeFinishedStatus) {
hc.conContext.needHandshakeFinishedStatus = false;
}
// clean the post handshake context
hc.conContext.finishPostHandshake();
// The message has been delivered.
return null;
}
private NewSessionTicketMessage generateNST(HandshakeContext hc,
SSLSessionContextImpl sessionCache) throws IOException {
NewSessionTicketMessage nstm;
SessionId newId = new SessionId(true,
hc.sslContext.getSecureRandom());
// construct the PSK and handshake message
byte[] nonce = hc.handshakeSession.incrTicketNonceCounter();
SSLSessionImpl sessionCopy =
new SSLSessionImpl(hc.handshakeSession, newId);
sessionCopy.setPreSharedKey(psk);
new SSLSessionImpl(hc.handshakeSession, newId);
sessionCopy.setPreSharedKey(derivePreSharedKey(
hc.negotiatedCipherSuite.hashAlg,
hc.handshakeSession.getResumptionMasterSecret(), nonce));
sessionCopy.setPskIdentity(newId.getId());
// If a stateless ticket is allowed, attempt to make one
if (hc.statelessResumption &&
hc.handshakeSession.isStatelessable()) {
nstm = new T13NewSessionTicketMessage(hc,
sessionTimeoutSeconds,
sessionCache.getSessionTimeout(),
hc.sslContext.getSecureRandom(),
nonceArr,
nonce,
new SessionTicketSpec().encrypt(hc, sessionCopy));
// If ticket construction failed, switch to session cache
if (!nstm.isValid()) {
hc.statelessResumption = false;
} else {
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"Produced NewSessionTicket stateless " +
SSLLogger.fine("Produced NewSessionTicket stateless " +
"post-handshake message", nstm);
}
}
return nstm;
}
// If a session cache ticket is being used, make one
if (!hc.statelessResumption ||
!hc.handshakeSession.isStatelessable()) {
nstm = new T13NewSessionTicketMessage(hc, sessionTimeoutSeconds,
hc.sslContext.getSecureRandom(), nonceArr,
newId.getId());
nstm = new T13NewSessionTicketMessage(hc,
sessionCache.getSessionTimeout(),
hc.sslContext.getSecureRandom(), nonce,
newId.getId());
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"Produced NewSessionTicket post-handshake message",
nstm);
SSLLogger.fine("Produced NewSessionTicket " +
"post-handshake message", nstm);
}
// create and cache the new session
@ -453,29 +477,13 @@ final class NewSessionTicket {
hc.handshakeSession.addChild(sessionCopy);
sessionCopy.setTicketAgeAdd(nstm.getTicketAgeAdd());
sessionCache.put(sessionCopy);
return nstm;
}
// Output the handshake message.
if (nstm != null) {
// should never be null
nstm.write(hc.handshakeOutput);
hc.handshakeOutput.flush();
// See note on TransportContext.needHandshakeFinishedStatus.
//
// Reset the needHandshakeFinishedStatus flag. The delivery
// of this post-handshake message will indicate the FINISHED
// handshake status. It is not needed to have a follow-on
// SSLEngine.wrap() any longer.
if (hc.conContext.needHandshakeFinishedStatus) {
hc.conContext.needHandshakeFinishedStatus = false;
}
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine("No NewSessionTicket created");
}
// clean the post handshake context
hc.conContext.finishPostHandshake();
// The message has been delivered.
return null;
}
}
@ -497,8 +505,9 @@ final class NewSessionTicket {
ServerHandshakeContext shc = (ServerHandshakeContext)context;
// Is this session resumable?
if (!shc.handshakeSession.isRejoinable()) {
// Are new tickets allowed? If so, is this session resumable?
if (SSLConfiguration.serverNewSessionTicketCount == 0 ||
!shc.handshakeSession.isRejoinable()) {
return null;
}
@ -578,7 +587,6 @@ final class NewSessionTicket {
"Discarding NewSessionTicket with lifetime " +
nstm.ticketLifetime, nstm);
}
sessionCache.remove(hc.handshakeSession.getSessionId());
return;
}
@ -619,13 +627,19 @@ final class NewSessionTicket {
sessionCopy.setPreSharedKey(psk);
sessionCopy.setTicketAgeAdd(nstm.getTicketAgeAdd());
sessionCopy.setPskIdentity(nstm.ticket);
sessionCache.put(sessionCopy);
sessionCache.put(sessionCopy, sessionCopy.isPSK());
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine("MultiNST PSK (Server): " +
Utilities.toHexString(Arrays.copyOf(nstm.ticket, 16)));
}
// clean the post handshake context
hc.conContext.finishPostHandshake();
}
}
/* TLS 1.2 spec does not specify multiple NST behavior.*/
private static final
class T12NewSessionTicketConsumer implements SSLConsumer {
// Prevent instantiation of this class.
@ -674,8 +688,7 @@ final class NewSessionTicket {
hc.handshakeSession.setPskIdentity(nstm.ticket);
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine("Consuming NewSessionTicket\n" +
nstm.toString());
SSLLogger.fine("Consuming NewSessionTicket\n" + nstm);
}
}
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -699,11 +699,13 @@ final class PreSharedKeyExtension {
//The session cannot be used again. Remove it from the cache.
SSLSessionContextImpl sessionCache = (SSLSessionContextImpl)
chc.sslContext.engineGetClientSessionContext();
sessionCache.remove(chc.resumingSession.getSessionId());
sessionCache.remove(chc.resumingSession.getSessionId(), true);
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"Found resumable session. Preparing PSK message.");
SSLLogger.fine(
"MultiNST PSK (Client): " + Utilities.toHexString(Arrays.copyOf(chc.pskIdentity, 16)));
}
List<PskIdentity> identities = new ArrayList<>();

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -121,6 +121,11 @@ final class SSLConfiguration implements Cloneable {
static final boolean enableDtlsResumeCookie = Utilities.getBooleanProperty(
"jdk.tls.enableDtlsResumeCookie", true);
// Number of NewSessionTickets that will be sent by the server.
static final int serverNewSessionTicketCount;
// Default for NewSessionTickets
static final int SERVER_NST_DEFAULT = 1;
// Is the extended_master_secret extension supported?
static {
boolean supportExtendedMasterSecret = Utilities.getBooleanProperty(
@ -182,7 +187,7 @@ final class SSLConfiguration implements Cloneable {
* - Otherwise it is set to a default value of 10.
*/
Integer inboundServerLen = GetIntegerAction.privilegedGetProperty(
"jdk.tls.client.maxInboundCertificateChainLength");
"jdk.tls.client.maxInboundCertificateChainLength");
// Default for jdk.tls.client.maxInboundCertificateChainLength is 10
if (inboundServerLen == null || inboundServerLen < 0) {
@ -191,6 +196,33 @@ final class SSLConfiguration implements Cloneable {
} else {
maxInboundServerCertChainLen = inboundServerLen;
}
/*
* jdk.tls.server.newSessionTicketCount system property
* Sets the number of NewSessionTickets sent to a TLS 1.3 resumption
* client. The value must be between 0 and 10. Default is defined by
* SERVER_NST_DEFAULT.
*/
Integer nstServerCount = GetIntegerAction.privilegedGetProperty(
"jdk.tls.server.newSessionTicketCount");
if (nstServerCount == null || nstServerCount < 0 ||
nstServerCount > 10) {
serverNewSessionTicketCount = SERVER_NST_DEFAULT;
if (nstServerCount != null && SSLLogger.isOn &&
SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"jdk.tls.server.newSessionTicketCount defaults to " +
SERVER_NST_DEFAULT + " as the property was not " +
"between 0 and 10");
}
} else {
serverNewSessionTicketCount = nstServerCount;
if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
SSLLogger.fine(
"jdk.tls.server.newSessionTicketCount set to " +
serverNewSessionTicketCount);
}
}
}
SSLConfiguration(SSLContextImpl sslContext, boolean isClientMode) {

View file

@ -416,11 +416,12 @@ final class SSLEngineImpl extends SSLEngine implements SSLTransport {
HandshakeStatus currentHandshakeStatus) throws IOException {
// Don't bother to kickstart if handshaking is in progress, or if the
// connection is not duplex-open.
if ((conContext.handshakeContext == null) &&
conContext.protocolVersion.useTLS13PlusSpec() &&
!conContext.isOutboundClosed() &&
!conContext.isInboundClosed() &&
!conContext.isBroken) {
if (SSLConfiguration.serverNewSessionTicketCount > 0 &&
conContext.handshakeContext == null &&
conContext.protocolVersion.useTLS13PlusSpec() &&
!conContext.isOutboundClosed() &&
!conContext.isInboundClosed() &&
!conContext.isBroken) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
SSLLogger.finest("trigger NST");
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 1999, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1999, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -63,6 +63,7 @@ import sun.security.util.Cache;
final class SSLSessionContextImpl implements SSLSessionContext {
private static final int DEFAULT_MAX_CACHE_SIZE = 20480;
private static final int DEFAULT_MAX_QUEUE_SIZE = 10;
// Default lifetime of a session. 24 hours
static final int DEFAULT_SESSION_TIMEOUT = 86400;
@ -87,14 +88,17 @@ final class SSLSessionContextImpl implements SSLSessionContext {
cacheLimit = getDefaults(server); // default cache size
// use soft reference
sessionCache = Cache.newSoftMemoryCache(cacheLimit, timeout);
sessionHostPortCache = Cache.newSoftMemoryCache(cacheLimit, timeout);
if (server) {
sessionCache = Cache.newSoftMemoryCache(cacheLimit, timeout);
sessionHostPortCache = Cache.newSoftMemoryCache(cacheLimit, timeout);
keyHashMap = new ConcurrentHashMap<>();
// Should be "randomly generated" according to RFC 5077,
// but doesn't necessarily has to be a true random number.
// but doesn't necessarily have to be a true random number.
currentKeyID = new Random(System.nanoTime()).nextInt();
} else {
sessionCache = Cache.newSoftMemoryCache(cacheLimit, timeout);
sessionHostPortCache = Cache.newSoftMemoryQueue(cacheLimit, timeout,
DEFAULT_MAX_QUEUE_SIZE);
keyHashMap = Map.of();
}
}
@ -277,12 +281,22 @@ final class SSLSessionContextImpl implements SSLSessionContext {
// time it created, which is a little longer than the expected. So
// please do check isTimedout() while getting entry from the cache.
void put(SSLSessionImpl s) {
put(s, false);
}
/**
* Put an entry in the cache
* @param s SSLSessionImpl entry to be stored
* @param canQueue True if multiple entries may exist under one
* session entry.
*/
void put(SSLSessionImpl s, boolean canQueue) {
sessionCache.put(s.getSessionId(), s);
// If no hostname/port info is available, don't add this one.
if ((s.getPeerHost() != null) && (s.getPeerPort() != -1)) {
sessionHostPortCache.put(
getKey(s.getPeerHost(), s.getPeerPort()), s);
getKey(s.getPeerHost(), s.getPeerPort()), s, canQueue);
}
s.setContext(this);
@ -290,11 +304,17 @@ final class SSLSessionContextImpl implements SSLSessionContext {
// package-private method, remove a cached SSLSession
void remove(SessionId key) {
remove(key, false);
}
void remove(SessionId key, boolean isClient) {
SSLSessionImpl s = sessionCache.get(key);
if (s != null) {
sessionCache.remove(key);
sessionHostPortCache.remove(
// A client keeps the cache entry for queued NST resumption.
if (!isClient) {
sessionHostPortCache.remove(
getKey(s.getPeerHost(), s.getPeerPort()));
}
}
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 1996, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1996, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -38,7 +38,6 @@ import java.util.Arrays;
import java.util.Queue;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
@ -132,7 +131,11 @@ final class SSLSessionImpl extends ExtendedSSLSession {
private final List<SNIServerName> requestedServerNames;
// Counter used to create unique nonces in NewSessionTicket
private BigInteger ticketNonceCounter = BigInteger.ONE;
private byte ticketNonceCounter = 1;
// This boolean is true when a new set of NewSessionTickets are needed after
// the initial ones sent after the handshake.
boolean updateNST = false;
// The endpoint identification algorithm used to check certificates
// in this session.
@ -492,7 +495,7 @@ final class SSLSessionImpl extends ExtendedSSLSession {
// Length of pre-shared key algorithm (one byte)
i = buf.get();
b = new byte[i];
buf.get(b, 0 , i);
buf.get(b, 0, i);
String alg = new String(b);
// Get length of encoding
i = Short.toUnsignedInt(buf.getShort());
@ -501,8 +504,13 @@ final class SSLSessionImpl extends ExtendedSSLSession {
buf.get(b);
this.preSharedKey = new SecretKeySpec(b, alg);
// Get identity len
this.pskIdentity = new byte[buf.get()];
buf.get(pskIdentity);
i = buf.get();
if (i > 0) {
this.pskIdentity = new byte[buf.get()];
buf.get(pskIdentity);
} else {
this.pskIdentity = null;
}
break;
default:
throw new SSLException("Failed local certs of session.");
@ -715,14 +723,12 @@ final class SSLSessionImpl extends ExtendedSSLSession {
this.pskIdentity = pskIdentity;
}
BigInteger incrTicketNonceCounter() {
BigInteger result = ticketNonceCounter;
ticketNonceCounter = ticketNonceCounter.add(BigInteger.ONE);
return result;
byte[] incrTicketNonceCounter() {
return new byte[] {ticketNonceCounter++};
}
boolean isPSKable() {
return (ticketNonceCounter.compareTo(BigInteger.ZERO) > 0);
return (ticketNonceCounter > 0);
}
/**
@ -781,6 +787,10 @@ final class SSLSessionImpl extends ExtendedSSLSession {
return pskIdentity;
}
public boolean isPSK() {
return (pskIdentity != null && pskIdentity.length > 0);
}
void setPeerCertificates(X509Certificate[] peer) {
if (peerCerts == null) {
peerCerts = peer;
@ -1230,7 +1240,6 @@ final class SSLSessionImpl extends ExtendedSSLSession {
* sessions can be shared across different protection domains.
*/
private final ConcurrentHashMap<SecureKey, Object> boundValues;
boolean updateNST;
/**
* Assigns a session value. Session change events are given if

View file

@ -1321,7 +1321,6 @@ public final class SSLSocketImpl
}
// Check if NewSessionTicket PostHandshake message needs to be sent
if (conContext.conSession.updateNST) {
conContext.conSession.updateNST = false;
tryNewSessionTicket();
}
}
@ -1556,15 +1555,17 @@ public final class SSLSocketImpl
private void tryNewSessionTicket() throws IOException {
// Don't bother to kickstart if handshaking is in progress, or if the
// connection is not duplex-open.
if (!conContext.sslConfig.isClientMode &&
conContext.protocolVersion.useTLS13PlusSpec() &&
conContext.handshakeContext == null &&
!conContext.isOutboundClosed() &&
!conContext.isInboundClosed() &&
!conContext.isBroken) {
if (SSLConfiguration.serverNewSessionTicketCount > 0 &&
!conContext.sslConfig.isClientMode &&
conContext.protocolVersion.useTLS13PlusSpec() &&
conContext.handshakeContext == null &&
!conContext.isOutboundClosed() &&
!conContext.isInboundClosed() &&
!conContext.isBroken) {
if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
SSLLogger.finest("trigger new session ticket");
}
conContext.conSession.updateNST = false;
NewSessionTicket.t13PosthandshakeProducer.produce(
new PostHandshakeContext(conContext));
}

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2002, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2002, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -25,8 +25,12 @@
package sun.security.util;
import javax.net.ssl.SSLSession;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.SoftReference;
import java.util.*;
import java.lang.ref.*;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Abstract base class and factory for caches. A cache is a key-value mapping.
@ -90,6 +94,15 @@ public abstract class Cache<K,V> {
*/
public abstract void put(K key, V value);
/**
* Add V to the cache with the option to use a QueueCacheEntry if the
* cache is configured for it. If the cache is not configured for a queue,
* V will silently add the entry directly.
*/
public void put(K key, V value, boolean canQueue) {
put(key, value);
}
/**
* Get a value from the cache.
*/
@ -137,6 +150,11 @@ public abstract class Cache<K,V> {
return new MemoryCache<>(true, size, timeout);
}
public static <K,V> Cache<K,V> newSoftMemoryQueue(int size, int timeout,
int maxQueueSize) {
return new MemoryCache<>(true, size, timeout, maxQueueSize);
}
/**
* Return a new memory cache with the specified maximum size, unlimited
* lifetime for entries, with the values held by standard references.
@ -248,13 +266,12 @@ class NullCache<K,V> extends Cache<K,V> {
class MemoryCache<K,V> extends Cache<K,V> {
private static final float LOAD_FACTOR = 0.75f;
// XXXX
// Debugging
private static final boolean DEBUG = false;
private final Map<K, CacheEntry<K,V>> cacheMap;
private int maxSize;
final private int maxQueueSize;
private long lifetime;
private long nextExpirationTime = Long.MAX_VALUE;
@ -263,18 +280,25 @@ class MemoryCache<K,V> extends Cache<K,V> {
private final ReferenceQueue<V> queue;
public MemoryCache(boolean soft, int maxSize) {
this(soft, maxSize, 0);
this(soft, maxSize, 0, 0);
}
public MemoryCache(boolean soft, int maxSize, int lifetime) {
this.maxSize = maxSize;
this.lifetime = lifetime * 1000L;
if (soft)
this.queue = new ReferenceQueue<>();
else
this.queue = null;
this(soft, maxSize, lifetime, 0);
}
cacheMap = new LinkedHashMap<>(1, LOAD_FACTOR, true);
public MemoryCache(boolean soft, int maxSize, int lifetime, int qSize) {
this.maxSize = maxSize;
this.maxQueueSize = qSize;
this.lifetime = lifetime * 1000L;
if (soft) {
this.queue = new ReferenceQueue<>();
} else {
this.queue = null;
}
// LinkedHashMap is needed for its access order. 0.75f load factor is
// default.
cacheMap = new LinkedHashMap<>(1, 0.75f, true);
}
/**
@ -338,6 +362,10 @@ class MemoryCache<K,V> extends Cache<K,V> {
cnt++;
} else if (nextExpirationTime > entry.getExpirationTime()) {
nextExpirationTime = entry.getExpirationTime();
// If this is a queue, check for some expired entries
if (entry instanceof QueueCacheEntry<K,V> qe) {
qe.getQueue().removeIf(e -> !e.isValid(time));
}
}
}
if (DEBUG) {
@ -367,18 +395,60 @@ class MemoryCache<K,V> extends Cache<K,V> {
cacheMap.clear();
}
public synchronized void put(K key, V value) {
public void put(K key, V value) {
put(key, value, false);
}
/**
* This puts an entry into the cacheMap.
*
* If canQueue is true, V will be added using a QueueCacheEntry which
* is added to cacheMap. If false, V is added to the cacheMap directly.
* The caller must keep a consistent canQueue value, mixing them can
* result in a queue being replaced with a single entry.
*
* This method is synchronized to avoid multiple QueueCacheEntry
* overwriting the same key.
*
* @param key key to the cacheMap
* @param value value to be stored
* @param canQueue can the value be put into a QueueCacheEntry
*/
public synchronized void put(K key, V value, boolean canQueue) {
emptyQueue();
long expirationTime = (lifetime == 0) ? 0 :
System.currentTimeMillis() + lifetime;
long expirationTime =
(lifetime == 0) ? 0 : System.currentTimeMillis() + lifetime;
if (expirationTime < nextExpirationTime) {
nextExpirationTime = expirationTime;
}
CacheEntry<K,V> newEntry = newEntry(key, value, expirationTime, queue);
CacheEntry<K,V> oldEntry = cacheMap.put(key, newEntry);
if (oldEntry != null) {
oldEntry.invalidate();
return;
if (maxQueueSize == 0 || !canQueue) {
CacheEntry<K,V> oldEntry = cacheMap.put(key, newEntry);
if (oldEntry != null) {
oldEntry.invalidate();
}
} else {
CacheEntry<K, V> entry = cacheMap.get(key);
switch (entry) {
case QueueCacheEntry<K, V> qe -> {
qe.putValue(newEntry);
if (DEBUG) {
System.out.println("QueueCacheEntry= " + qe);
final AtomicInteger i = new AtomicInteger(1);
qe.queue.stream().forEach(e ->
System.out.println(i.getAndIncrement() + "= " + e));
}
}
case null, default ->
cacheMap.put(key, new QueueCacheEntry<>(key, newEntry,
expirationTime, maxQueueSize));
}
if (DEBUG) {
System.out.println("Cache entry added: key=" +
key.toString() + ", class=" +
(entry != null ? entry.getClass().getName() : null));
}
}
if (maxSize > 0 && cacheMap.size() > maxSize) {
expungeExpiredEntries();
@ -401,25 +471,37 @@ class MemoryCache<K,V> extends Cache<K,V> {
if (entry == null) {
return null;
}
long time = (lifetime == 0) ? 0 : System.currentTimeMillis();
if (!entry.isValid(time)) {
if (lifetime > 0 && !entry.isValid(System.currentTimeMillis())) {
cacheMap.remove(key);
if (DEBUG) {
System.out.println("Ignoring expired entry");
}
cacheMap.remove(key);
return null;
}
// If the value is a queue, return a queue entry.
if (entry instanceof QueueCacheEntry<K, V> qe) {
V result = qe.getValue(lifetime);
if (qe.isEmpty()) {
removeImpl(key);
}
return result;
}
return entry.getValue();
}
public synchronized void remove(Object key) {
emptyQueue();
removeImpl(key);
}
private void removeImpl(Object key) {
CacheEntry<K,V> entry = cacheMap.remove(key);
if (entry != null) {
entry.invalidate();
}
}
public synchronized V pull(Object key) {
emptyQueue();
CacheEntry<K,V> entry = cacheMap.remove(key);
@ -550,9 +632,8 @@ class MemoryCache<K,V> extends Cache<K,V> {
}
}
private static class SoftCacheEntry<K,V>
extends SoftReference<V>
implements CacheEntry<K,V> {
private static class SoftCacheEntry<K,V> extends SoftReference<V>
implements CacheEntry<K,V> {
private K key;
private long expirationTime;
@ -589,6 +670,116 @@ class MemoryCache<K,V> extends Cache<K,V> {
key = null;
expirationTime = -1;
}
@Override
public String toString() {
if (get() instanceof SSLSession se)
return HexFormat.of().formatHex(se.getId());
return super.toString();
}
}
}
/**
* This CacheEntry<K,V> type allows multiple V entries to be stored in
* one key in the cacheMap.
*
* This implementation is need for TLS clients that receive multiple
* PSKs or NewSessionTickets for server resumption.
*/
private static class QueueCacheEntry<K,V> implements CacheEntry<K,V> {
// Limit the number of queue entries.
private final int MAXQUEUESIZE;
final boolean DEBUG = false;
private K key;
private long expirationTime;
final Queue<CacheEntry<K,V>> queue = new ConcurrentLinkedQueue<>();
QueueCacheEntry(K key, CacheEntry<K,V> entry, long expirationTime,
int maxSize) {
this.key = key;
this.expirationTime = expirationTime;
this.MAXQUEUESIZE = maxSize;
queue.add(entry);
}
public K getKey() {
return key;
}
public V getValue() {
return getValue(0);
}
public V getValue(long lifetime) {
long time = (lifetime == 0) ? 0 : System.currentTimeMillis();
do {
var entry = queue.poll();
if (entry == null) {
return null;
}
if (entry.isValid(time)) {
return entry.getValue();
}
entry.invalidate();
} while (!queue.isEmpty());
return null;
}
public long getExpirationTime() {
return expirationTime;
}
public void setExpirationTime(long time) {
expirationTime = time;
}
public void putValue(CacheEntry<K,V> entry) {
if (DEBUG) {
System.out.println("Added to queue (size=" + queue.size() +
"): " + entry.getKey().toString() + ", " + entry);
}
// Update the cache entry's expiration time to the latest entry.
// The getValue() calls will remove expired tickets.
expirationTime = entry.getExpirationTime();
// Limit the number of queue entries, removing the oldest.
if (queue.size() >= MAXQUEUESIZE) {
queue.remove();
}
queue.add(entry);
}
public boolean isValid(long currentTime) {
boolean valid = (currentTime <= expirationTime) && !queue.isEmpty();
if (!valid) {
invalidate();
}
return valid;
}
public boolean isValid() {
return isValid(System.currentTimeMillis());
}
public void invalidate() {
clear();
key = null;
expirationTime = -1;
}
public void clear() {
queue.forEach(CacheEntry::invalidate);
queue.clear();
}
public boolean isEmpty() {
return queue.isEmpty();
}
public Queue<CacheEntry<K,V>> getQueue() {
return queue;
}
}
}