8303604: Passing by-value structs whose size is not power of 2 doesn't work on all platforms (mainline)

Reviewed-by: mcimadamore
This commit is contained in:
Jorn Vernee 2023-03-06 15:18:39 +00:00
parent dccfe8a2ee
commit 5977f266d0
11 changed files with 703 additions and 53 deletions

View file

@ -25,6 +25,7 @@
package java.lang.invoke;
import jdk.internal.foreign.Utils;
import sun.invoke.util.Wrapper;
import java.lang.reflect.Constructor;
@ -314,7 +315,7 @@ final class VarHandles {
if (!carrier.isPrimitive() || carrier == void.class || carrier == boolean.class) {
throw new IllegalArgumentException("Invalid carrier: " + carrier.getName());
}
long size = Wrapper.forPrimitiveType(carrier).bitWidth() / 8;
long size = Utils.byteWidthOfPrimitive(carrier);
boolean be = byteOrder == ByteOrder.BIG_ENDIAN;
boolean exact = false;

View file

@ -241,4 +241,8 @@ public final class Utils {
}
return MemoryLayout.structLayout(layouts.toArray(MemoryLayout[]::new));
}
public static int byteWidthOfPrimitive(Class<?> primitive) {
return Wrapper.forPrimitiveType(primitive).bitWidth() / 8;
}
}

View file

@ -35,10 +35,15 @@ import java.lang.foreign.SegmentAllocator;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED;
import static java.lang.foreign.ValueLayout.JAVA_SHORT_UNALIGNED;
/**
* The binding operators defined in the Binding class can be combined into argument and return value processing 'recipes'.
*
@ -317,6 +322,11 @@ public interface Binding {
throw new IllegalArgumentException("Negative offset: " + offset);
}
private static void checkByteWidth(int byteWidth, Class<?> type) {
if (byteWidth < 0 || byteWidth > Utils.byteWidthOfPrimitive(type))
throw new IllegalArgumentException("Illegal byteWidth: " + byteWidth);
}
static VMStore vmStore(VMStorage storage, Class<?> type) {
checkType(type);
return new VMStore(storage, type);
@ -328,15 +338,25 @@ public interface Binding {
}
static BufferStore bufferStore(long offset, Class<?> type) {
return bufferStore(offset, type, Utils.byteWidthOfPrimitive(type));
}
static BufferStore bufferStore(long offset, Class<?> type, int byteWidth) {
checkType(type);
checkOffset(offset);
return new BufferStore(offset, type);
checkByteWidth(byteWidth, type);
return new BufferStore(offset, type, byteWidth);
}
static BufferLoad bufferLoad(long offset, Class<?> type) {
return Binding.bufferLoad(offset, type, Utils.byteWidthOfPrimitive(type));
}
static BufferLoad bufferLoad(long offset, Class<?> type, int byteWidth) {
checkType(type);
checkOffset(offset);
return new BufferLoad(offset, type);
checkByteWidth(byteWidth, type);
return new BufferLoad(offset, type, byteWidth);
}
static Copy copy(MemoryLayout layout) {
@ -433,11 +453,21 @@ public interface Binding {
return this;
}
public Binding.Builder bufferStore(long offset, Class<?> type, int byteWidth) {
bindings.add(Binding.bufferStore(offset, type, byteWidth));
return this;
}
public Binding.Builder bufferLoad(long offset, Class<?> type) {
bindings.add(Binding.bufferLoad(offset, type));
return this;
}
public Binding.Builder bufferLoad(long offset, Class<?> type, int byteWidth) {
bindings.add(Binding.bufferLoad(offset, type, byteWidth));
return this;
}
public Binding.Builder copy(MemoryLayout layout) {
bindings.add(Binding.copy(layout));
return this;
@ -532,12 +562,12 @@ public interface Binding {
}
/**
* BUFFER_STORE([offset into memory region], [type])
* BUFFER_STORE([offset into memory region], [type], [width])
* Pops a [type] from the operand stack, then pops a MemorySegment from the operand stack.
* Stores the [type] to [offset into memory region].
* Stores [width] bytes of the value contained in the [type] to [offset into memory region].
* The [type] must be one of byte, short, char, int, long, float, or double
*/
record BufferStore(long offset, Class<?> type) implements Dereference {
record BufferStore(long offset, Class<?> type, int byteWidth) implements Dereference {
@Override
public Tag tag() {
return Tag.BUFFER_STORE;
@ -555,19 +585,50 @@ public interface Binding {
public void interpret(Deque<Object> stack, BindingInterpreter.StoreFunc storeFunc,
BindingInterpreter.LoadFunc loadFunc, Context context) {
Object value = stack.pop();
MemorySegment operand = (MemorySegment) stack.pop();
MemorySegment writeAddress = operand.asSlice(offset());
SharedUtils.write(writeAddress, type(), value);
MemorySegment writeAddress = (MemorySegment) stack.pop();
if (SharedUtils.isPowerOfTwo(byteWidth())) {
// exact size match
SharedUtils.write(writeAddress, offset(), type(), value);
} else {
// non-exact match, need to do chunked load
long longValue = ((Number) value).longValue();
// byteWidth is smaller than the width of 'type', so it will always be < 8 here
int remaining = byteWidth();
int chunkOffset = 0;
do {
int chunkSize = Integer.highestOneBit(remaining); // next power of 2, in bytes
long writeOffset = offset() + SharedUtils.pickChunkOffset(chunkOffset, byteWidth(), chunkSize);
int shiftAmount = chunkOffset * Byte.SIZE;
switch (chunkSize) {
case 4 -> {
int writeChunk = (int) (((0xFFFF_FFFFL << shiftAmount) & longValue) >>> shiftAmount);
writeAddress.set(JAVA_INT_UNALIGNED, writeOffset, writeChunk);
}
case 2 -> {
short writeChunk = (short) (((0xFFFFL << shiftAmount) & longValue) >>> shiftAmount);
writeAddress.set(JAVA_SHORT_UNALIGNED, writeOffset, writeChunk);
}
case 1 -> {
byte writeChunk = (byte) (((0xFFL << shiftAmount) & longValue) >>> shiftAmount);
writeAddress.set(JAVA_BYTE, writeOffset, writeChunk);
}
default ->
throw new IllegalStateException("Unexpected chunk size for chunked write: " + chunkSize);
}
remaining -= chunkSize;
chunkOffset += chunkSize;
} while (remaining != 0);
}
}
}
/**
* BUFFER_LOAD([offset into memory region], [type])
* Pops a [type], and then a MemorySegment from the operand stack,
* and then stores [type] to [offset into memory region] of the MemorySegment.
* BUFFER_LOAD([offset into memory region], [type], [width])
* Pops a MemorySegment from the operand stack,
* and then loads [width] bytes from it at [offset into memory region], into a [type].
* The [type] must be one of byte, short, char, int, long, float, or double
*/
record BufferLoad(long offset, Class<?> type) implements Dereference {
record BufferLoad(long offset, Class<?> type, int byteWidth) implements Dereference {
@Override
public Tag tag() {
return Tag.BUFFER_LOAD;
@ -584,9 +645,39 @@ public interface Binding {
@Override
public void interpret(Deque<Object> stack, BindingInterpreter.StoreFunc storeFunc,
BindingInterpreter.LoadFunc loadFunc, Context context) {
MemorySegment operand = (MemorySegment) stack.pop();
MemorySegment readAddress = operand.asSlice(offset());
stack.push(SharedUtils.read(readAddress, type()));
MemorySegment readAddress = (MemorySegment) stack.pop();
if (SharedUtils.isPowerOfTwo(byteWidth())) {
// exact size match
stack.push(SharedUtils.read(readAddress, offset(), type()));
} else {
// non-exact match, need to do chunked load
long result = 0;
// byteWidth is smaller than the width of 'type', so it will always be < 8 here
int remaining = byteWidth();
int chunkOffset = 0;
do {
int chunkSize = Integer.highestOneBit(remaining); // next power of 2
long readOffset = offset() + SharedUtils.pickChunkOffset(chunkOffset, byteWidth(), chunkSize);
long readChunk = switch (chunkSize) {
case 4 -> Integer.toUnsignedLong(readAddress.get(JAVA_INT_UNALIGNED, readOffset));
case 2 -> Short.toUnsignedLong(readAddress.get(JAVA_SHORT_UNALIGNED, readOffset));
case 1 -> Byte.toUnsignedLong(readAddress.get(JAVA_BYTE, readOffset));
default ->
throw new IllegalStateException("Unexpected chunk size for chunked write: " + chunkSize);
};
result |= readChunk << (chunkOffset * Byte.SIZE);
remaining -= chunkSize;
chunkOffset += chunkSize;
} while (remaining != 0);
if (type() == int.class) { // 3 byte write
stack.push((int) result);
} else if (type() == long.class) { // 5, 6, 7 byte write
stack.push(result);
} else {
throw new IllegalStateException("Unexpected type for chunked load: " + type());
}
}
}
}

View file

@ -95,6 +95,9 @@ public class BindingSpecializer {
private static final String CLASS_DATA_DESC = methodType(Object.class, MethodHandles.Lookup.class, String.class, Class.class).descriptorString();
private static final String RELEASE0_DESC = VOID_DESC;
private static final String ACQUIRE0_DESC = VOID_DESC;
private static final String INTEGER_TO_UNSIGNED_LONG_DESC = MethodType.methodType(long.class, int.class).descriptorString();
private static final String SHORT_TO_UNSIGNED_LONG_DESC = MethodType.methodType(long.class, short.class).descriptorString();
private static final String BYTE_TO_UNSIGNED_LONG_DESC = MethodType.methodType(long.class, byte.class).descriptorString();
private static final Handle BSM_CLASS_DATA = new Handle(
H_INVOKESTATIC,
@ -602,17 +605,82 @@ public class BindingSpecializer {
private void emitBufferStore(Binding.BufferStore bufferStore) {
Class<?> storeType = bufferStore.type();
long offset = bufferStore.offset();
int byteWidth = bufferStore.byteWidth();
popType(storeType);
popType(MemorySegment.class);
int valueIdx = newLocal(storeType);
emitStore(storeType, valueIdx);
Class<?> valueLayoutType = emitLoadLayoutConstant(storeType);
emitConst(offset);
emitLoad(storeType, valueIdx);
String descriptor = methodType(void.class, valueLayoutType, long.class, storeType).descriptorString();
emitInvokeInterface(MemorySegment.class, "set", descriptor);
if (SharedUtils.isPowerOfTwo(byteWidth)) {
int valueIdx = newLocal(storeType);
emitStore(storeType, valueIdx);
Class<?> valueLayoutType = emitLoadLayoutConstant(storeType);
emitConst(offset);
emitLoad(storeType, valueIdx);
String descriptor = methodType(void.class, valueLayoutType, long.class, storeType).descriptorString();
emitInvokeInterface(MemorySegment.class, "set", descriptor);
} else {
// long longValue = ((Number) value).longValue();
if (storeType == int.class) {
mv.visitInsn(I2L);
} else {
assert storeType == long.class; // chunking only for int and long
}
int longValueIdx = newLocal(long.class);
emitStore(long.class, longValueIdx);
int writeAddrIdx = newLocal(MemorySegment.class);
emitStore(MemorySegment.class, writeAddrIdx);
int remaining = byteWidth;
int chunkOffset = 0;
do {
int chunkSize = Integer.highestOneBit(remaining); // next power of 2, in bytes
Class<?> chunkStoreType;
long mask;
switch (chunkSize) {
case 4 -> {
chunkStoreType = int.class;
mask = 0xFFFF_FFFFL;
}
case 2 -> {
chunkStoreType = short.class;
mask = 0xFFFFL;
}
case 1 -> {
chunkStoreType = byte.class;
mask = 0xFFL;
}
default ->
throw new IllegalStateException("Unexpected chunk size for chunked write: " + chunkSize);
}
//int writeChunk = (int) (((0xFFFF_FFFFL << shiftAmount) & longValue) >>> shiftAmount);
int shiftAmount = chunkOffset * Byte.SIZE;
mask = mask << shiftAmount;
emitLoad(long.class, longValueIdx);
emitConst(mask);
mv.visitInsn(LAND);
if (shiftAmount != 0) {
emitConst(shiftAmount);
mv.visitInsn(LUSHR);
}
mv.visitInsn(L2I);
int chunkIdx = newLocal(chunkStoreType);
emitStore(chunkStoreType, chunkIdx);
// chunk done, now write it
//writeAddress.set(JAVA_SHORT_UNALIGNED, offset, writeChunk);
emitLoad(MemorySegment.class, writeAddrIdx);
Class<?> valueLayoutType = emitLoadLayoutConstant(chunkStoreType);
long writeOffset = offset + SharedUtils.pickChunkOffset(chunkOffset, byteWidth, chunkSize);
emitConst(writeOffset);
emitLoad(chunkStoreType, chunkIdx);
String descriptor = methodType(void.class, valueLayoutType, long.class, chunkStoreType).descriptorString();
emitInvokeInterface(MemorySegment.class, "set", descriptor);
remaining -= chunkSize;
chunkOffset += chunkSize;
} while (remaining != 0);
}
}
// VM_STORE and VM_LOAD are emulated, which is different for down/upcalls
@ -708,13 +776,82 @@ public class BindingSpecializer {
private void emitBufferLoad(Binding.BufferLoad bufferLoad) {
Class<?> loadType = bufferLoad.type();
long offset = bufferLoad.offset();
int byteWidth = bufferLoad.byteWidth();
popType(MemorySegment.class);
Class<?> valueLayoutType = emitLoadLayoutConstant(loadType);
emitConst(offset);
String descriptor = methodType(loadType, valueLayoutType, long.class).descriptorString();
emitInvokeInterface(MemorySegment.class, "get", descriptor);
if (SharedUtils.isPowerOfTwo(byteWidth)) {
Class<?> valueLayoutType = emitLoadLayoutConstant(loadType);
emitConst(offset);
String descriptor = methodType(loadType, valueLayoutType, long.class).descriptorString();
emitInvokeInterface(MemorySegment.class, "get", descriptor);
} else {
// chunked
int readAddrIdx = newLocal(MemorySegment.class);
emitStore(MemorySegment.class, readAddrIdx);
emitConstZero(long.class); // result
int resultIdx = newLocal(long.class);
emitStore(long.class, resultIdx);
int remaining = byteWidth;
int chunkOffset = 0;
do {
int chunkSize = Integer.highestOneBit(remaining); // next power of 2
Class<?> chunkType;
Class<?> toULongHolder;
String toULongDescriptor;
switch (chunkSize) {
case 4 -> {
chunkType = int.class;
toULongHolder = Integer.class;
toULongDescriptor = INTEGER_TO_UNSIGNED_LONG_DESC;
}
case 2 -> {
chunkType = short.class;
toULongHolder = Short.class;
toULongDescriptor = SHORT_TO_UNSIGNED_LONG_DESC;
}
case 1 -> {
chunkType = byte.class;
toULongHolder = Byte.class;
toULongDescriptor = BYTE_TO_UNSIGNED_LONG_DESC;
}
default ->
throw new IllegalStateException("Unexpected chunk size for chunked write: " + chunkSize);
}
// read from segment
emitLoad(MemorySegment.class, readAddrIdx);
Class<?> valueLayoutType = emitLoadLayoutConstant(chunkType);
String descriptor = methodType(chunkType, valueLayoutType, long.class).descriptorString();
long readOffset = offset + SharedUtils.pickChunkOffset(chunkOffset, byteWidth, chunkSize);
emitConst(readOffset);
emitInvokeInterface(MemorySegment.class, "get", descriptor);
emitInvokeStatic(toULongHolder, "toUnsignedLong", toULongDescriptor);
// shift to right offset
int shiftAmount = chunkOffset * Byte.SIZE;
if (shiftAmount != 0) {
emitConst(shiftAmount);
mv.visitInsn(LSHL);
}
// add to result
emitLoad(long.class, resultIdx);
mv.visitInsn(LOR);
emitStore(long.class, resultIdx);
remaining -= chunkSize;
chunkOffset += chunkSize;
} while (remaining != 0);
emitLoad(long.class, resultIdx);
if (loadType == int.class) {
mv.visitInsn(L2I);
} else {
assert loadType == long.class; // should not have chunking for other types
}
}
pushType(loadType);
}

View file

@ -186,7 +186,7 @@ public class DowncallLinker {
int retBufReadOffset = 0;
@Override
public Object load(VMStorage storage, Class<?> type) {
Object result1 = SharedUtils.read(finalReturnBuffer.asSlice(retBufReadOffset), type);
Object result1 = SharedUtils.read(finalReturnBuffer, retBufReadOffset, type);
retBufReadOffset += abi.arch.typeSize(storage.type());
return result1;
}

View file

@ -50,6 +50,7 @@ import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.VarHandle;
import java.lang.ref.Reference;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;
@ -336,6 +337,16 @@ public final class SharedUtils {
}
}
public static boolean isPowerOfTwo(int width) {
return Integer.bitCount(width) == 1;
}
static long pickChunkOffset(long chunkOffset, long byteWidth, int chunkWidth) {
return ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN
? byteWidth - chunkWidth - chunkOffset
: chunkOffset;
}
public static NoSuchElementException newVaListNSEE(MemoryLayout layout) {
return new NoSuchElementException("No such element: " + layout);
}
@ -431,45 +442,45 @@ public final class SharedUtils {
}
}
static void write(MemorySegment ptr, Class<?> type, Object o) {
static void write(MemorySegment ptr, long offset, Class<?> type, Object o) {
if (type == long.class) {
ptr.set(JAVA_LONG_UNALIGNED, 0, (long) o);
ptr.set(JAVA_LONG_UNALIGNED, offset, (long) o);
} else if (type == int.class) {
ptr.set(JAVA_INT_UNALIGNED, 0, (int) o);
ptr.set(JAVA_INT_UNALIGNED, offset, (int) o);
} else if (type == short.class) {
ptr.set(JAVA_SHORT_UNALIGNED, 0, (short) o);
ptr.set(JAVA_SHORT_UNALIGNED, offset, (short) o);
} else if (type == char.class) {
ptr.set(JAVA_CHAR_UNALIGNED, 0, (char) o);
ptr.set(JAVA_CHAR_UNALIGNED, offset, (char) o);
} else if (type == byte.class) {
ptr.set(JAVA_BYTE, 0, (byte) o);
ptr.set(JAVA_BYTE, offset, (byte) o);
} else if (type == float.class) {
ptr.set(JAVA_FLOAT_UNALIGNED, 0, (float) o);
ptr.set(JAVA_FLOAT_UNALIGNED, offset, (float) o);
} else if (type == double.class) {
ptr.set(JAVA_DOUBLE_UNALIGNED, 0, (double) o);
ptr.set(JAVA_DOUBLE_UNALIGNED, offset, (double) o);
} else if (type == boolean.class) {
ptr.set(JAVA_BOOLEAN, 0, (boolean) o);
ptr.set(JAVA_BOOLEAN, offset, (boolean) o);
} else {
throw new IllegalArgumentException("Unsupported carrier: " + type);
}
}
static Object read(MemorySegment ptr, Class<?> type) {
static Object read(MemorySegment ptr, long offset, Class<?> type) {
if (type == long.class) {
return ptr.get(JAVA_LONG_UNALIGNED, 0);
return ptr.get(JAVA_LONG_UNALIGNED, offset);
} else if (type == int.class) {
return ptr.get(JAVA_INT_UNALIGNED, 0);
return ptr.get(JAVA_INT_UNALIGNED, offset);
} else if (type == short.class) {
return ptr.get(JAVA_SHORT_UNALIGNED, 0);
return ptr.get(JAVA_SHORT_UNALIGNED, offset);
} else if (type == char.class) {
return ptr.get(JAVA_CHAR_UNALIGNED, 0);
return ptr.get(JAVA_CHAR_UNALIGNED, offset);
} else if (type == byte.class) {
return ptr.get(JAVA_BYTE, 0);
return ptr.get(JAVA_BYTE, offset);
} else if (type == float.class) {
return ptr.get(JAVA_FLOAT_UNALIGNED, 0);
return ptr.get(JAVA_FLOAT_UNALIGNED, offset);
} else if (type == double.class) {
return ptr.get(JAVA_DOUBLE_UNALIGNED, 0);
return ptr.get(JAVA_DOUBLE_UNALIGNED, offset);
} else if (type == boolean.class) {
return ptr.get(JAVA_BOOLEAN, 0);
return ptr.get(JAVA_BOOLEAN, offset);
} else {
throw new IllegalArgumentException("Unsupported carrier: " + type);
}

View file

@ -236,7 +236,7 @@ public abstract class CallArranger {
return carrier;
}
record StructStorage(long offset, Class<?> carrier, VMStorage storage) {}
record StructStorage(long offset, Class<?> carrier, int byteWidth, VMStorage storage) {}
/*
In the simplest case structs are copied in chunks. i.e. the fields don't matter, just the size.
@ -305,12 +305,14 @@ public abstract class CallArranger {
long offset = 0;
for (int i = 0; i < structStorages.length; i++) {
ValueLayout copyLayout;
long copySize;
if (isFieldWise) {
// We should only get here for HFAs, which can't have padding
copyLayout = (ValueLayout) scalarLayouts.get(i);
copySize = Utils.byteWidthOfPrimitive(copyLayout.carrier());
} else {
// chunk-wise copy
long copySize = Math.min(layout.byteSize() - offset, MAX_COPY_SIZE);
copySize = Math.min(layout.byteSize() - offset, MAX_COPY_SIZE);
boolean useFloat = false; // never use float for chunk-wise copies
copyLayout = SharedUtils.primitiveLayoutForSize(copySize, useFloat);
}
@ -322,7 +324,7 @@ public abstract class CallArranger {
// Don't use floats on the stack
carrier = adjustCarrierForStack(carrier);
}
structStorages[i] = new StructStorage(offset, carrier, storage);
structStorages[i] = new StructStorage(offset, carrier, (int) copySize, storage);
offset += copyLayout.byteSize();
}
@ -421,7 +423,7 @@ public abstract class CallArranger {
if (i < structStorages.length - 1) {
bindings.dup();
}
bindings.bufferLoad(structStorage.offset(), structStorage.carrier())
bindings.bufferLoad(structStorage.offset(), structStorage.carrier(), structStorage.byteWidth())
.vmStore(structStorage.storage(), structStorage.carrier());
}
}
@ -483,7 +485,7 @@ public abstract class CallArranger {
for (StorageCalculator.StructStorage structStorage : structStorages) {
bindings.dup();
bindings.vmLoad(structStorage.storage(), structStorage.carrier())
.bufferStore(structStorage.offset(), structStorage.carrier());
.bufferStore(structStorage.offset(), structStorage.carrier(), structStorage.byteWidth());
}
}
case STRUCT_REFERENCE -> {

View file

@ -263,7 +263,7 @@ public class CallArranger {
}
boolean useFloat = storage.type() == StorageType.VECTOR;
Class<?> type = SharedUtils.primitiveCarrierForSize(copy, useFloat);
bindings.bufferLoad(offset, type)
bindings.bufferLoad(offset, type, (int) copy)
.vmStore(storage, type);
offset += copy;
}
@ -311,7 +311,7 @@ public class CallArranger {
boolean useFloat = storage.type() == StorageType.VECTOR;
Class<?> type = SharedUtils.primitiveCarrierForSize(copy, useFloat);
bindings.vmLoad(storage, type)
.bufferStore(offset, type);
.bufferStore(offset, type, (int) copy);
offset += copy;
}
}