8181386: CipherSpi ByteBuffer to byte array conversion fails for certain data overlap conditions

Detect potential buffer overlap and use extra buffer if necessary

Reviewed-by: xuelei
This commit is contained in:
Valerie Peng 2019-07-10 18:43:45 +00:00
parent 019b9891d7
commit 29215b987b
2 changed files with 252 additions and 52 deletions

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 1997, 2018, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1997, 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@ -761,78 +761,87 @@ public abstract class CipherSpi {
+ " bytes of space in output buffer");
}
// detecting input and output buffer overlap may be tricky
// we can only write directly into output buffer when we
// are 100% sure it's safe to do so
boolean a1 = input.hasArray();
boolean a2 = output.hasArray();
int total = 0;
byte[] inArray, outArray;
if (a2) { // output has an accessible byte[]
outArray = output.array();
int outPos = output.position();
int outOfs = output.arrayOffset() + outPos;
if (a1) { // input also has an accessible byte[]
inArray = input.array();
int inOfs = input.arrayOffset() + inPos;
if (a1) { // input has an accessible byte[]
byte[] inArray = input.array();
int inOfs = input.arrayOffset() + inPos;
if (a2) { // output has an accessible byte[]
byte[] outArray = output.array();
int outPos = output.position();
int outOfs = output.arrayOffset() + outPos;
// check array address and offsets and use temp output buffer
// if output offset is larger than input offset and
// falls within the range of input data
boolean useTempOut = false;
if (inArray == outArray &&
((inOfs < outOfs) && (outOfs < inOfs + inLen))) {
useTempOut = true;
outArray = new byte[outLenNeeded];
outOfs = 0;
}
if (isUpdate) {
total = engineUpdate(inArray, inOfs, inLen, outArray, outOfs);
} else {
total = engineDoFinal(inArray, inOfs, inLen, outArray, outOfs);
}
if (useTempOut) {
output.put(outArray, outOfs, total);
} else {
// adjust output position manually
output.position(outPos + total);
}
// adjust input position manually
input.position(inLimit);
} else { // input does not have accessible byte[]
inArray = new byte[getTempArraySize(inLen)];
do {
int chunk = Math.min(inLen, inArray.length);
if (chunk > 0) {
input.get(inArray, 0, chunk);
}
int n;
if (isUpdate || (inLen > chunk)) {
n = engineUpdate(inArray, 0, chunk, outArray, outOfs);
} else {
n = engineDoFinal(inArray, 0, chunk, outArray, outOfs);
}
total += n;
outOfs += n;
inLen -= chunk;
} while (inLen > 0);
}
output.position(outPos + total);
} else { // output does not have an accessible byte[]
if (a1) { // but input has an accessible byte[]
inArray = input.array();
int inOfs = input.arrayOffset() + inPos;
} else { // output does not have an accessible byte[]
byte[] outArray = null;
if (isUpdate) {
outArray = engineUpdate(inArray, inOfs, inLen);
} else {
outArray = engineDoFinal(inArray, inOfs, inLen);
}
input.position(inLimit);
if (outArray != null && outArray.length != 0) {
output.put(outArray);
total = outArray.length;
}
} else { // input also does not have an accessible byte[]
inArray = new byte[getTempArraySize(inLen)];
do {
int chunk = Math.min(inLen, inArray.length);
if (chunk > 0) {
input.get(inArray, 0, chunk);
}
int n;
if (isUpdate || (inLen > chunk)) {
outArray = engineUpdate(inArray, 0, chunk);
} else {
outArray = engineDoFinal(inArray, 0, chunk);
}
if (outArray != null && outArray.length != 0) {
output.put(outArray);
total += outArray.length;
}
inLen -= chunk;
} while (inLen > 0);
// adjust input position manually
input.position(inLimit);
}
} else { // input does not have an accessible byte[]
// have to assume the worst, since we have no way of determine
// if input and output overlaps or not
byte[] tempOut = new byte[outLenNeeded];
int outOfs = 0;
byte[] tempIn = new byte[getTempArraySize(inLen)];
do {
int chunk = Math.min(inLen, tempIn.length);
if (chunk > 0) {
input.get(tempIn, 0, chunk);
}
int n;
if (isUpdate || (inLen > chunk)) {
n = engineUpdate(tempIn, 0, chunk, tempOut, outOfs);
} else {
n = engineDoFinal(tempIn, 0, chunk, tempOut, outOfs);
}
outOfs += n;
total += n;
inLen -= chunk;
} while (inLen > 0);
if (total > 0) {
output.put(tempOut, 0, total);
}
}
return total;
}