8328608: Multiple NewSessionTicket support for TLS

Reviewed-by: djelinski
This commit is contained in:
Anthony Scarpino 2024-08-28 17:24:33 +00:00
parent 379f3db001
commit 0c2b175898
16 changed files with 1161 additions and 300 deletions

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2002, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2002, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -25,8 +25,12 @@
package sun.security.util;
import javax.net.ssl.SSLSession;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.SoftReference;
import java.util.*;
import java.lang.ref.*;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Abstract base class and factory for caches. A cache is a key-value mapping.
@ -90,6 +94,15 @@ public abstract class Cache<K,V> {
*/
public abstract void put(K key, V value);
/**
* Add V to the cache with the option to use a QueueCacheEntry if the
* cache is configured for it. If the cache is not configured for a queue,
* V will silently add the entry directly.
*/
public void put(K key, V value, boolean canQueue) {
put(key, value);
}
/**
* Get a value from the cache.
*/
@ -137,6 +150,11 @@ public abstract class Cache<K,V> {
return new MemoryCache<>(true, size, timeout);
}
public static <K,V> Cache<K,V> newSoftMemoryQueue(int size, int timeout,
int maxQueueSize) {
return new MemoryCache<>(true, size, timeout, maxQueueSize);
}
/**
* Return a new memory cache with the specified maximum size, unlimited
* lifetime for entries, with the values held by standard references.
@ -248,13 +266,12 @@ class NullCache<K,V> extends Cache<K,V> {
class MemoryCache<K,V> extends Cache<K,V> {
private static final float LOAD_FACTOR = 0.75f;
// XXXX
// Debugging
private static final boolean DEBUG = false;
private final Map<K, CacheEntry<K,V>> cacheMap;
private int maxSize;
final private int maxQueueSize;
private long lifetime;
private long nextExpirationTime = Long.MAX_VALUE;
@ -263,18 +280,25 @@ class MemoryCache<K,V> extends Cache<K,V> {
private final ReferenceQueue<V> queue;
public MemoryCache(boolean soft, int maxSize) {
this(soft, maxSize, 0);
this(soft, maxSize, 0, 0);
}
public MemoryCache(boolean soft, int maxSize, int lifetime) {
this.maxSize = maxSize;
this.lifetime = lifetime * 1000L;
if (soft)
this.queue = new ReferenceQueue<>();
else
this.queue = null;
this(soft, maxSize, lifetime, 0);
}
cacheMap = new LinkedHashMap<>(1, LOAD_FACTOR, true);
public MemoryCache(boolean soft, int maxSize, int lifetime, int qSize) {
this.maxSize = maxSize;
this.maxQueueSize = qSize;
this.lifetime = lifetime * 1000L;
if (soft) {
this.queue = new ReferenceQueue<>();
} else {
this.queue = null;
}
// LinkedHashMap is needed for its access order. 0.75f load factor is
// default.
cacheMap = new LinkedHashMap<>(1, 0.75f, true);
}
/**
@ -338,6 +362,10 @@ class MemoryCache<K,V> extends Cache<K,V> {
cnt++;
} else if (nextExpirationTime > entry.getExpirationTime()) {
nextExpirationTime = entry.getExpirationTime();
// If this is a queue, check for some expired entries
if (entry instanceof QueueCacheEntry<K,V> qe) {
qe.getQueue().removeIf(e -> !e.isValid(time));
}
}
}
if (DEBUG) {
@ -367,18 +395,60 @@ class MemoryCache<K,V> extends Cache<K,V> {
cacheMap.clear();
}
public synchronized void put(K key, V value) {
public void put(K key, V value) {
put(key, value, false);
}
/**
* This puts an entry into the cacheMap.
*
* If canQueue is true, V will be added using a QueueCacheEntry which
* is added to cacheMap. If false, V is added to the cacheMap directly.
* The caller must keep a consistent canQueue value, mixing them can
* result in a queue being replaced with a single entry.
*
* This method is synchronized to avoid multiple QueueCacheEntry
* overwriting the same key.
*
* @param key key to the cacheMap
* @param value value to be stored
* @param canQueue can the value be put into a QueueCacheEntry
*/
public synchronized void put(K key, V value, boolean canQueue) {
emptyQueue();
long expirationTime = (lifetime == 0) ? 0 :
System.currentTimeMillis() + lifetime;
long expirationTime =
(lifetime == 0) ? 0 : System.currentTimeMillis() + lifetime;
if (expirationTime < nextExpirationTime) {
nextExpirationTime = expirationTime;
}
CacheEntry<K,V> newEntry = newEntry(key, value, expirationTime, queue);
CacheEntry<K,V> oldEntry = cacheMap.put(key, newEntry);
if (oldEntry != null) {
oldEntry.invalidate();
return;
if (maxQueueSize == 0 || !canQueue) {
CacheEntry<K,V> oldEntry = cacheMap.put(key, newEntry);
if (oldEntry != null) {
oldEntry.invalidate();
}
} else {
CacheEntry<K, V> entry = cacheMap.get(key);
switch (entry) {
case QueueCacheEntry<K, V> qe -> {
qe.putValue(newEntry);
if (DEBUG) {
System.out.println("QueueCacheEntry= " + qe);
final AtomicInteger i = new AtomicInteger(1);
qe.queue.stream().forEach(e ->
System.out.println(i.getAndIncrement() + "= " + e));
}
}
case null, default ->
cacheMap.put(key, new QueueCacheEntry<>(key, newEntry,
expirationTime, maxQueueSize));
}
if (DEBUG) {
System.out.println("Cache entry added: key=" +
key.toString() + ", class=" +
(entry != null ? entry.getClass().getName() : null));
}
}
if (maxSize > 0 && cacheMap.size() > maxSize) {
expungeExpiredEntries();
@ -401,25 +471,37 @@ class MemoryCache<K,V> extends Cache<K,V> {
if (entry == null) {
return null;
}
long time = (lifetime == 0) ? 0 : System.currentTimeMillis();
if (!entry.isValid(time)) {
if (lifetime > 0 && !entry.isValid(System.currentTimeMillis())) {
cacheMap.remove(key);
if (DEBUG) {
System.out.println("Ignoring expired entry");
}
cacheMap.remove(key);
return null;
}
// If the value is a queue, return a queue entry.
if (entry instanceof QueueCacheEntry<K, V> qe) {
V result = qe.getValue(lifetime);
if (qe.isEmpty()) {
removeImpl(key);
}
return result;
}
return entry.getValue();
}
public synchronized void remove(Object key) {
emptyQueue();
removeImpl(key);
}
private void removeImpl(Object key) {
CacheEntry<K,V> entry = cacheMap.remove(key);
if (entry != null) {
entry.invalidate();
}
}
public synchronized V pull(Object key) {
emptyQueue();
CacheEntry<K,V> entry = cacheMap.remove(key);
@ -550,9 +632,8 @@ class MemoryCache<K,V> extends Cache<K,V> {
}
}
private static class SoftCacheEntry<K,V>
extends SoftReference<V>
implements CacheEntry<K,V> {
private static class SoftCacheEntry<K,V> extends SoftReference<V>
implements CacheEntry<K,V> {
private K key;
private long expirationTime;
@ -589,6 +670,116 @@ class MemoryCache<K,V> extends Cache<K,V> {
key = null;
expirationTime = -1;
}
@Override
public String toString() {
if (get() instanceof SSLSession se)
return HexFormat.of().formatHex(se.getId());
return super.toString();
}
}
}
/**
* This CacheEntry<K,V> type allows multiple V entries to be stored in
* one key in the cacheMap.
*
* This implementation is need for TLS clients that receive multiple
* PSKs or NewSessionTickets for server resumption.
*/
private static class QueueCacheEntry<K,V> implements CacheEntry<K,V> {
// Limit the number of queue entries.
private final int MAXQUEUESIZE;
final boolean DEBUG = false;
private K key;
private long expirationTime;
final Queue<CacheEntry<K,V>> queue = new ConcurrentLinkedQueue<>();
QueueCacheEntry(K key, CacheEntry<K,V> entry, long expirationTime,
int maxSize) {
this.key = key;
this.expirationTime = expirationTime;
this.MAXQUEUESIZE = maxSize;
queue.add(entry);
}
public K getKey() {
return key;
}
public V getValue() {
return getValue(0);
}
public V getValue(long lifetime) {
long time = (lifetime == 0) ? 0 : System.currentTimeMillis();
do {
var entry = queue.poll();
if (entry == null) {
return null;
}
if (entry.isValid(time)) {
return entry.getValue();
}
entry.invalidate();
} while (!queue.isEmpty());
return null;
}
public long getExpirationTime() {
return expirationTime;
}
public void setExpirationTime(long time) {
expirationTime = time;
}
public void putValue(CacheEntry<K,V> entry) {
if (DEBUG) {
System.out.println("Added to queue (size=" + queue.size() +
"): " + entry.getKey().toString() + ", " + entry);
}
// Update the cache entry's expiration time to the latest entry.
// The getValue() calls will remove expired tickets.
expirationTime = entry.getExpirationTime();
// Limit the number of queue entries, removing the oldest.
if (queue.size() >= MAXQUEUESIZE) {
queue.remove();
}
queue.add(entry);
}
public boolean isValid(long currentTime) {
boolean valid = (currentTime <= expirationTime) && !queue.isEmpty();
if (!valid) {
invalidate();
}
return valid;
}
public boolean isValid() {
return isValid(System.currentTimeMillis());
}
public void invalidate() {
clear();
key = null;
expirationTime = -1;
}
public void clear() {
queue.forEach(CacheEntry::invalidate);
queue.clear();
}
public boolean isEmpty() {
return queue.isEmpty();
}
public Queue<CacheEntry<K,V>> getQueue() {
return queue;
}
}
}