8087327: CipherStream produces new byte array on every update or doFinal operation

Changed Cipher[In/Out]putStream to allocate a buffer and reuse it

Reviewed-by: weijun
This commit is contained in:
Valerie Peng 2020-06-04 20:30:16 +00:00
parent 9a88048a05
commit b94314a0d9
2 changed files with 96 additions and 43 deletions

View file

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 1997, 2018, Oracle and/or its affiliates. All rights reserved. * Copyright (c) 1997, 2020, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
* *
* This code is free software; you can redistribute it and/or modify it * This code is free software; you can redistribute it and/or modify it
@ -92,7 +92,7 @@ public class CipherInputStream extends FilterInputStream {
/* the buffer holding data that have been processed by the cipher /* the buffer holding data that have been processed by the cipher
engine, but have not been read out */ engine, but have not been read out */
private byte[] obuffer; private byte[] obuffer = null;
// the offset pointing to the next "new" byte // the offset pointing to the next "new" byte
private int ostart = 0; private int ostart = 0;
// the offset pointing to the last "new" byte // the offset pointing to the last "new" byte
@ -100,18 +100,36 @@ public class CipherInputStream extends FilterInputStream {
// stream status // stream status
private boolean closed = false; private boolean closed = false;
/* /**
* private convenience function. * Ensure obuffer is big enough for the next update or doFinal
* operation, given the input length <code>inLen</code> (in bytes)
* The ostart and ofinish indices are reset to 0.
*
* @param inLen the input length (in bytes)
*/
private void ensureCapacity(int inLen) {
int minLen = cipher.getOutputSize(inLen);
if (obuffer == null || obuffer.length < minLen) {
obuffer = new byte[minLen];
}
ostart = 0;
ofinish = 0;
}
/**
* Private convenience function, read in data from the underlying
* input stream and process them with cipher. This method is called
* when the processed bytes inside obuffer has been exhausted.
* *
* Entry condition: ostart = ofinish * Entry condition: ostart = ofinish
* *
* Exit condition: ostart <= ofinish * Exit condition: ostart = 0 AND ostart <= ofinish
* *
* return (ofinish-ostart) (we have this many bytes for you) * return (ofinish-ostart) (we have this many bytes for you)
* return 0 (no data now, but could have more later) * return 0 (no data now, but could have more later)
* return -1 (absolutely no more data) * return -1 (absolutely no more data)
* *
* Note: Exceptions are only thrown after the stream is completely read. * Note: Exceptions are only thrown after the stream is completely read.
* For AEAD ciphers a read() of any length will internally cause the * For AEAD ciphers a read() of any length will internally cause the
* whole stream to be read fully and verify the authentication tag before * whole stream to be read fully and verify the authentication tag before
* returning decrypted data or exceptions. * returning decrypted data or exceptions.
@ -119,32 +137,30 @@ public class CipherInputStream extends FilterInputStream {
private int getMoreData() throws IOException { private int getMoreData() throws IOException {
if (done) return -1; if (done) return -1;
int readin = input.read(ibuffer); int readin = input.read(ibuffer);
if (readin == -1) { if (readin == -1) {
done = true; done = true;
ensureCapacity(0);
try { try {
obuffer = cipher.doFinal(); ofinish = cipher.doFinal(obuffer, 0);
} catch (IllegalBlockSizeException | BadPaddingException e) { } catch (IllegalBlockSizeException | BadPaddingException
obuffer = null; | ShortBufferException e) {
throw new IOException(e); throw new IOException(e);
} }
if (obuffer == null) if (ofinish == 0) {
return -1; return -1;
else { } else {
ostart = 0;
ofinish = obuffer.length;
return ofinish; return ofinish;
} }
} }
ensureCapacity(readin);
try { try {
obuffer = cipher.update(ibuffer, 0, readin); ofinish = cipher.update(ibuffer, 0, readin, obuffer, ostart);
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
obuffer = null;
throw e; throw e;
} catch (ShortBufferException e) {
throw new IOException(e);
} }
ostart = 0;
if (obuffer == null)
ofinish = 0;
else ofinish = obuffer.length;
return ofinish; return ofinish;
} }
@ -190,6 +206,7 @@ public class CipherInputStream extends FilterInputStream {
* stream is reached. * stream is reached.
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
*/ */
@Override
public int read() throws IOException { public int read() throws IOException {
if (ostart >= ofinish) { if (ostart >= ofinish) {
// we loop for new data as the spec says we are blocking // we loop for new data as the spec says we are blocking
@ -215,6 +232,7 @@ public class CipherInputStream extends FilterInputStream {
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
* @see java.io.InputStream#read(byte[], int, int) * @see java.io.InputStream#read(byte[], int, int)
*/ */
@Override
public int read(byte b[]) throws IOException { public int read(byte b[]) throws IOException {
return read(b, 0, b.length); return read(b, 0, b.length);
} }
@ -235,6 +253,7 @@ public class CipherInputStream extends FilterInputStream {
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
* @see java.io.InputStream#read() * @see java.io.InputStream#read()
*/ */
@Override
public int read(byte b[], int off, int len) throws IOException { public int read(byte b[], int off, int len) throws IOException {
if (ostart >= ofinish) { if (ostart >= ofinish) {
// we loop for new data as the spec says we are blocking // we loop for new data as the spec says we are blocking
@ -271,6 +290,7 @@ public class CipherInputStream extends FilterInputStream {
* @return the actual number of bytes skipped. * @return the actual number of bytes skipped.
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
*/ */
@Override
public long skip(long n) throws IOException { public long skip(long n) throws IOException {
int available = ofinish - ostart; int available = ofinish - ostart;
if (n > available) { if (n > available) {
@ -293,6 +313,7 @@ public class CipherInputStream extends FilterInputStream {
* without blocking. * without blocking.
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
*/ */
@Override
public int available() throws IOException { public int available() throws IOException {
return (ofinish - ostart); return (ofinish - ostart);
} }
@ -307,11 +328,11 @@ public class CipherInputStream extends FilterInputStream {
* *
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
*/ */
@Override
public void close() throws IOException { public void close() throws IOException {
if (closed) { if (closed) {
return; return;
} }
closed = true; closed = true;
input.close(); input.close();
@ -319,15 +340,15 @@ public class CipherInputStream extends FilterInputStream {
// AEAD ciphers are fully readed before closing. Any authentication // AEAD ciphers are fully readed before closing. Any authentication
// exceptions would occur while reading. // exceptions would occur while reading.
if (!done) { if (!done) {
ensureCapacity(0);
try { try {
cipher.doFinal(); cipher.doFinal(obuffer, 0);
} } catch (BadPaddingException | IllegalBlockSizeException
catch (BadPaddingException | IllegalBlockSizeException ex) { | ShortBufferException ex) {
// Catch exceptions as the rest of the stream is unused. // Catch exceptions as the rest of the stream is unused.
} }
} }
ostart = 0; obuffer = null;
ofinish = 0;
} }
/** /**
@ -339,6 +360,7 @@ public class CipherInputStream extends FilterInputStream {
* @see java.io.InputStream#mark(int) * @see java.io.InputStream#mark(int)
* @see java.io.InputStream#reset() * @see java.io.InputStream#reset()
*/ */
@Override
public boolean markSupported() { public boolean markSupported() {
return false; return false;
} }

View file

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 1997, 2018, Oracle and/or its affiliates. All rights reserved. * Copyright (c) 1997, 2020, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
* *
* This code is free software; you can redistribute it and/or modify it * This code is free software; you can redistribute it and/or modify it
@ -80,11 +80,24 @@ public class CipherOutputStream extends FilterOutputStream {
private byte[] ibuffer = new byte[1]; private byte[] ibuffer = new byte[1];
// the buffer holding data ready to be written out // the buffer holding data ready to be written out
private byte[] obuffer; private byte[] obuffer = null;
// stream status // stream status
private boolean closed = false; private boolean closed = false;
/**
* Ensure obuffer is big enough for the next update or doFinal
* operation, given the input length <code>inLen</code> (in bytes)
*
* @param inLen the input length (in bytes)
*/
private void ensureCapacity(int inLen) {
int minLen = cipher.getOutputSize(inLen);
if (obuffer == null || obuffer.length < minLen) {
obuffer = new byte[minLen];
}
}
/** /**
* *
* Constructs a CipherOutputStream from an OutputStream and a * Constructs a CipherOutputStream from an OutputStream and a
@ -123,12 +136,18 @@ public class CipherOutputStream extends FilterOutputStream {
* @param b the <code>byte</code>. * @param b the <code>byte</code>.
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
*/ */
@Override
public void write(int b) throws IOException { public void write(int b) throws IOException {
ibuffer[0] = (byte) b; ibuffer[0] = (byte) b;
obuffer = cipher.update(ibuffer, 0, 1); ensureCapacity(1);
if (obuffer != null) { try {
output.write(obuffer); int ostored = cipher.update(ibuffer, 0, 1, obuffer);
obuffer = null; if (ostored > 0) {
output.write(obuffer, 0, ostored);
}
} catch (ShortBufferException sbe) {
// should never happen; re-throw just in case
throw new IOException(sbe);
} }
}; };
@ -146,6 +165,7 @@ public class CipherOutputStream extends FilterOutputStream {
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
* @see javax.crypto.CipherOutputStream#write(byte[], int, int) * @see javax.crypto.CipherOutputStream#write(byte[], int, int)
*/ */
@Override
public void write(byte b[]) throws IOException { public void write(byte b[]) throws IOException {
write(b, 0, b.length); write(b, 0, b.length);
} }
@ -159,11 +179,17 @@ public class CipherOutputStream extends FilterOutputStream {
* @param len the number of bytes to write. * @param len the number of bytes to write.
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
*/ */
@Override
public void write(byte b[], int off, int len) throws IOException { public void write(byte b[], int off, int len) throws IOException {
obuffer = cipher.update(b, off, len); ensureCapacity(len);
if (obuffer != null) { try {
output.write(obuffer); int ostored = cipher.update(b, off, len, obuffer);
obuffer = null; if (ostored > 0) {
output.write(obuffer, 0, ostored);
}
} catch (ShortBufferException e) {
// should never happen; re-throw just in case
throw new IOException(e);
} }
} }
@ -180,11 +206,10 @@ public class CipherOutputStream extends FilterOutputStream {
* *
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
*/ */
@Override
public void flush() throws IOException { public void flush() throws IOException {
if (obuffer != null) { // simply call output.flush() since 'obuffer' content is always
output.write(obuffer); // written out immediately
obuffer = null;
}
output.flush(); output.flush();
} }
@ -203,20 +228,26 @@ public class CipherOutputStream extends FilterOutputStream {
* *
* @exception IOException if an I/O error occurs. * @exception IOException if an I/O error occurs.
*/ */
@Override
public void close() throws IOException { public void close() throws IOException {
if (closed) { if (closed) {
return; return;
} }
closed = true; closed = true;
ensureCapacity(0);
try { try {
obuffer = cipher.doFinal(); int ostored = cipher.doFinal(obuffer, 0);
} catch (IllegalBlockSizeException | BadPaddingException e) { if (ostored > 0) {
obuffer = null; output.write(obuffer, 0, ostored);
}
} catch (IllegalBlockSizeException | BadPaddingException
| ShortBufferException e) {
} }
obuffer = null;
try { try {
flush(); flush();
} catch (IOException ignored) {} } catch (IOException ignored) {}
out.close(); output.close();
} }
} }