8336768: Allow captureCallState and critical linker options to be combined

Reviewed-by: mcimadamore
This commit is contained in:
Jorn Vernee 2024-12-03 12:28:17 +00:00
parent 63af2f42b7
commit 8cad0431ff
14 changed files with 178 additions and 86 deletions

View file

@ -852,8 +852,6 @@ public sealed interface Linker permits AbstractLinker {
* // use errno * // use errno
* } * }
* } * }
* <p>
* This linker option can not be combined with {@link #critical}.
* *
* @param capturedState the names of the values to save * @param capturedState the names of the values to save
* @throws IllegalArgumentException if at least one of the provided * @throws IllegalArgumentException if at least one of the provided

View file

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved. * Copyright (c) 2020, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
* *
* This code is free software; you can redistribute it and/or modify it * This code is free software; you can redistribute it and/or modify it
@ -195,6 +195,10 @@ public class CallingSequence {
return !linkerOptions.isCritical(); return !linkerOptions.isCritical();
} }
public boolean usingAddressPairs() {
return linkerOptions.allowsHeapAccess();
}
public int numLeadingParams() { public int numLeadingParams() {
return 2 + (linkerOptions.hasCapturedCallState() ? 1 : 0); // 2 for addr, allocator return 2 + (linkerOptions.hasCapturedCallState() ? 1 : 0); // 2 for addr, allocator
} }

View file

@ -108,9 +108,18 @@ public class CallingSequenceBuilder {
MethodType calleeMethodType; MethodType calleeMethodType;
if (!forUpcall) { if (!forUpcall) {
if (linkerOptions.hasCapturedCallState()) { if (linkerOptions.hasCapturedCallState()) {
addArgumentBinding(0, MemorySegment.class, ValueLayout.ADDRESS, List.of( if (linkerOptions.allowsHeapAccess()) {
Binding.unboxAddress(), addArgumentBinding(0, MemorySegment.class, ValueLayout.ADDRESS, List.of(
Binding.vmStore(abi.capturedStateStorage(), long.class))); Binding.dup(),
Binding.segmentBase(),
Binding.vmStore(abi.capturedStateStorage(), Object.class),
Binding.segmentOffsetAllowHeap(),
Binding.vmStore(null, long.class)));
} else {
addArgumentBinding(0, MemorySegment.class, ValueLayout.ADDRESS, List.of(
Binding.unboxAddress(),
Binding.vmStore(abi.capturedStateStorage(), long.class)));
}
} }
addArgumentBinding(0, MemorySegment.class, ValueLayout.ADDRESS, List.of( addArgumentBinding(0, MemorySegment.class, ValueLayout.ADDRESS, List.of(
Binding.unboxAddress(), Binding.unboxAddress(),

View file

@ -84,7 +84,8 @@ public class DowncallLinker {
leafType, leafType,
callingSequence.needsReturnBuffer(), callingSequence.needsReturnBuffer(),
callingSequence.capturedStateMask(), callingSequence.capturedStateMask(),
callingSequence.needsTransition() callingSequence.needsTransition(),
callingSequence.usingAddressPairs()
); );
MethodHandle handle = JLIA.nativeMethodHandle(nep); MethodHandle handle = JLIA.nativeMethodHandle(nep);

View file

@ -63,11 +63,7 @@ public class LinkerOptions {
optionMap.put(option.getClass(), opImpl); optionMap.put(option.getClass(), opImpl);
} }
LinkerOptions linkerOptions = new LinkerOptions(optionMap); return new LinkerOptions(optionMap);
if (linkerOptions.hasCapturedCallState() && linkerOptions.isCritical()) {
throw new IllegalArgumentException("Incompatible linker options: captureCallState, critical");
}
return linkerOptions;
} }
public static LinkerOptions empty() { public static LinkerOptions empty() {

View file

@ -60,11 +60,12 @@ public class NativeEntryPoint {
MethodType methodType, MethodType methodType,
boolean needsReturnBuffer, boolean needsReturnBuffer,
int capturedStateMask, int capturedStateMask,
boolean needsTransition) { boolean needsTransition,
boolean usingAddressPairs) {
if (returnMoves.length > 1 != needsReturnBuffer) { if (returnMoves.length > 1 != needsReturnBuffer) {
throw new AssertionError("Multiple register return, but needsReturnBuffer was false"); throw new AssertionError("Multiple register return, but needsReturnBuffer was false");
} }
checkType(methodType, needsReturnBuffer, capturedStateMask); checkMethodType(methodType, needsReturnBuffer, capturedStateMask, usingAddressPairs);
CacheKey key = new CacheKey(methodType, abi, Arrays.asList(argMoves), Arrays.asList(returnMoves), CacheKey key = new CacheKey(methodType, abi, Arrays.asList(argMoves), Arrays.asList(returnMoves),
needsReturnBuffer, capturedStateMask, needsTransition); needsReturnBuffer, capturedStateMask, needsTransition);
@ -80,14 +81,26 @@ public class NativeEntryPoint {
}); });
} }
private static void checkType(MethodType methodType, boolean needsReturnBuffer, int savedValueMask) { private static void checkMethodType(MethodType methodType, boolean needsReturnBuffer, int savedValueMask,
if (methodType.parameterType(0) != long.class) { boolean usingAddressPairs) {
throw new AssertionError("Address expected as first param: " + methodType); int checkIdx = 0;
checkParamType(methodType, checkIdx++, long.class, "Function address");
if (needsReturnBuffer) {
checkParamType(methodType, checkIdx++, long.class, "Return buffer address");
} }
int checkIdx = 1; if (savedValueMask != 0) { // capturing call state
if ((needsReturnBuffer && methodType.parameterType(checkIdx++) != long.class) if (usingAddressPairs) {
|| (savedValueMask != 0 && methodType.parameterType(checkIdx) != long.class)) { checkParamType(methodType, checkIdx++, Object.class, "Capture state heap base");
throw new AssertionError("return buffer and/or preserved value address expected: " + methodType); checkParamType(methodType, checkIdx, long.class, "Capture state offset");
} else {
checkParamType(methodType, checkIdx, long.class, "Capture state address");
}
}
}
private static void checkParamType(MethodType methodType, int checkIdx, Class<?> expectedType, String name) {
if (methodType.parameterType(checkIdx) != expectedType) {
throw new AssertionError(name + " expected at index " + checkIdx + ": " + methodType);
} }
} }

View file

@ -163,8 +163,14 @@ public final class FallbackLinker extends AbstractLinker {
acquiredSessions.add(targetImpl); acquiredSessions.add(targetImpl);
MemorySegment capturedState = null; MemorySegment capturedState = null;
Object captureStateHeapBase = null;
if (invData.capturedStateMask() != 0) { if (invData.capturedStateMask() != 0) {
capturedState = SharedUtils.checkCaptureSegment((MemorySegment) args[argStart++]); capturedState = SharedUtils.checkCaptureSegment((MemorySegment) args[argStart++]);
if (!invData.allowsHeapAccess) {
SharedUtils.checkNative(capturedState);
} else {
captureStateHeapBase = capturedState.heapBase().orElse(null);
}
MemorySessionImpl capturedStateImpl = ((AbstractMemorySegmentImpl) capturedState).sessionImpl(); MemorySessionImpl capturedStateImpl = ((AbstractMemorySegmentImpl) capturedState).sessionImpl();
capturedStateImpl.acquire0(); capturedStateImpl.acquire0();
acquiredSessions.add(capturedStateImpl); acquiredSessions.add(capturedStateImpl);
@ -199,7 +205,8 @@ public final class FallbackLinker extends AbstractLinker {
retSeg = (invData.returnLayout() instanceof GroupLayout ? returnAllocator : arena).allocate(invData.returnLayout); retSeg = (invData.returnLayout() instanceof GroupLayout ? returnAllocator : arena).allocate(invData.returnLayout);
} }
LibFallback.doDowncall(invData.cif, target, retSeg, argPtrs, capturedState, invData.capturedStateMask(), LibFallback.doDowncall(invData.cif, target, retSeg, argPtrs,
captureStateHeapBase, capturedState, invData.capturedStateMask(),
heapBases, args.length); heapBases, args.length);
Reference.reachabilityFence(invData.cif()); Reference.reachabilityFence(invData.cif());

View file

@ -90,10 +90,11 @@ final class LibFallback {
* @see jdk.internal.foreign.abi.CapturableState * @see jdk.internal.foreign.abi.CapturableState
*/ */
static void doDowncall(MemorySegment cif, MemorySegment target, MemorySegment retPtr, MemorySegment argPtrs, static void doDowncall(MemorySegment cif, MemorySegment target, MemorySegment retPtr, MemorySegment argPtrs,
MemorySegment capturedState, int capturedStateMask, Object captureStateHeapBase, MemorySegment capturedState, int capturedStateMask,
Object[] heapBases, int numArgs) { Object[] heapBases, int numArgs) {
doDowncall(cif.address(), target.address(), doDowncall(cif.address(), target.address(),
retPtr == null ? 0 : retPtr.address(), argPtrs.address(), retPtr == null ? 0 : retPtr.address(), argPtrs.address(),
captureStateHeapBase,
capturedState == null ? 0 : capturedState.address(), capturedStateMask, capturedState == null ? 0 : capturedState.address(), capturedStateMask,
heapBases, numArgs); heapBases, numArgs);
} }
@ -212,7 +213,7 @@ final class LibFallback {
private static native int createClosure(long cif, Object userData, long[] ptrs); private static native int createClosure(long cif, Object userData, long[] ptrs);
private static native void freeClosure(long closureAddress, long globalTarget); private static native void freeClosure(long closureAddress, long globalTarget);
private static native void doDowncall(long cif, long fn, long rvalue, long avalues, private static native void doDowncall(long cif, long fn, long rvalue, long avalues,
long capturedState, int capturedStateMask, Object captureStateHeapBase, long capturedState, int capturedStateMask,
Object[] heapBases, int numArgs); Object[] heapBases, int numArgs);
private static native int ffi_prep_cif(long cif, int abi, int nargs, long rtype, long atypes); private static native int ffi_prep_cif(long cif, int abi, int nargs, long rtype, long atypes);

View file

@ -112,12 +112,16 @@ static void do_capture_state(int32_t* value_ptr, int captured_state_mask) {
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_jdk_internal_foreign_abi_fallback_LibFallback_doDowncall(JNIEnv* env, jclass cls, jlong cif, jlong fn, jlong rvalue, Java_jdk_internal_foreign_abi_fallback_LibFallback_doDowncall(JNIEnv* env, jclass cls, jlong cif, jlong fn, jlong rvalue,
jlong avalues, jlong jcaptured_state, jint captured_state_mask, jlong avalues,
jarray capture_state_heap_base, jlong captured_state_offset,
jint captured_state_mask,
jarray heapBases, jint numArgs) { jarray heapBases, jint numArgs) {
void** carrays; void** carrays;
int capture_state_hb_offset = numArgs;
int32_t* captured_state_addr = jlong_to_ptr(captured_state_offset);
if (heapBases != NULL) { if (heapBases != NULL) {
void** aptrs = jlong_to_ptr(avalues); void** aptrs = jlong_to_ptr(avalues);
carrays = malloc(sizeof(void*) * numArgs); carrays = malloc(sizeof(void*) * (numArgs + 1));
for (int i = 0; i < numArgs; i++) { for (int i = 0; i < numArgs; i++) {
jarray hb = (jarray) (*env)->GetObjectArrayElement(env, heapBases, i); jarray hb = (jarray) (*env)->GetObjectArrayElement(env, heapBases, i);
if (hb != NULL) { if (hb != NULL) {
@ -130,10 +134,20 @@ Java_jdk_internal_foreign_abi_fallback_LibFallback_doDowncall(JNIEnv* env, jclas
*((void**)aptrs[i]) = arrayPtr + offset; *((void**)aptrs[i]) = arrayPtr + offset;
} }
} }
if (capture_state_heap_base != NULL) {
jboolean isCopy;
jbyte* arrayPtr = (*env)->GetPrimitiveArrayCritical(env, capture_state_heap_base, &isCopy);
carrays[capture_state_hb_offset] = arrayPtr;
captured_state_addr = (int32_t*) (arrayPtr + captured_state_offset);
}
} }
ffi_call(jlong_to_ptr(cif), jlong_to_ptr(fn), jlong_to_ptr(rvalue), jlong_to_ptr(avalues)); ffi_call(jlong_to_ptr(cif), jlong_to_ptr(fn), jlong_to_ptr(rvalue), jlong_to_ptr(avalues));
if (captured_state_mask != 0) {
do_capture_state(captured_state_addr, captured_state_mask);
}
if (heapBases != NULL) { if (heapBases != NULL) {
for (int i = 0; i < numArgs; i++) { for (int i = 0; i < numArgs; i++) {
jarray hb = (jarray) (*env)->GetObjectArrayElement(env, heapBases, i); jarray hb = (jarray) (*env)->GetObjectArrayElement(env, heapBases, i);
@ -141,13 +155,11 @@ Java_jdk_internal_foreign_abi_fallback_LibFallback_doDowncall(JNIEnv* env, jclas
(*env)->ReleasePrimitiveArrayCritical(env, hb, carrays[i], JNI_COMMIT); (*env)->ReleasePrimitiveArrayCritical(env, hb, carrays[i], JNI_COMMIT);
} }
} }
if (capture_state_heap_base != NULL) {
(*env)->ReleasePrimitiveArrayCritical(env, capture_state_heap_base, carrays[capture_state_hb_offset], JNI_COMMIT);
}
free(carrays); free(carrays);
} }
if (captured_state_mask != 0) {
int32_t* captured_state = jlong_to_ptr(jcaptured_state);
do_capture_state(captured_state, captured_state_mask);
}
} }
static void do_upcall(ffi_cif* cif, void* ret, void** args, void* user_data) { static void do_upcall(ffi_cif* cif, void* ret, void** args, void* user_data) {

View file

@ -192,11 +192,6 @@ public class TestIllegalLink extends NativeTestHelper {
NO_OPTIONS, NO_OPTIONS,
"has unexpected size" "has unexpected size"
}, },
{
FunctionDescriptor.ofVoid(),
new Linker.Option[]{Linker.Option.critical(false), Linker.Option.captureCallState("errno")},
"Incompatible linker options: captureCallState, critical"
},
})); }));
for (ValueLayout illegalLayout : List.of(C_CHAR, ValueLayout.JAVA_CHAR, C_BOOL, C_SHORT, C_FLOAT)) { for (ValueLayout illegalLayout : List.of(C_CHAR, ValueLayout.JAVA_CHAR, C_BOOL, C_SHORT, C_FLOAT)) {

View file

@ -61,12 +61,18 @@ public class TestCaptureCallState extends NativeTestHelper {
} }
} }
private record SaveValuesCase(String nativeTarget, FunctionDescriptor nativeDesc, String threadLocalName, Consumer<Object> resultCheck) {} private record SaveValuesCase(String nativeTarget, FunctionDescriptor nativeDesc, String threadLocalName,
Consumer<Object> resultCheck, boolean critical) {}
@Test(dataProvider = "cases") @Test(dataProvider = "cases")
public void testSavedThreadLocal(SaveValuesCase testCase) throws Throwable { public void testSavedThreadLocal(SaveValuesCase testCase) throws Throwable {
Linker.Option stl = Linker.Option.captureCallState(testCase.threadLocalName()); List<Linker.Option> options = new ArrayList<>();
MethodHandle handle = downcallHandle(testCase.nativeTarget(), testCase.nativeDesc(), stl); options.add(Linker.Option.captureCallState(testCase.threadLocalName()));
if (testCase.critical()) {
options.add(Linker.Option.critical(false));
}
MethodHandle handle = downcallHandle(testCase.nativeTarget(), testCase.nativeDesc(),
options.toArray(Linker.Option[]::new));
StructLayout capturedStateLayout = Linker.Option.captureStateLayout(); StructLayout capturedStateLayout = Linker.Option.captureStateLayout();
VarHandle errnoHandle = capturedStateLayout.varHandle(groupElement(testCase.threadLocalName())); VarHandle errnoHandle = capturedStateLayout.varHandle(groupElement(testCase.threadLocalName()));
@ -86,9 +92,14 @@ public class TestCaptureCallState extends NativeTestHelper {
@Test(dataProvider = "invalidCaptureSegmentCases") @Test(dataProvider = "invalidCaptureSegmentCases")
public void testInvalidCaptureSegment(MemorySegment captureSegment, public void testInvalidCaptureSegment(MemorySegment captureSegment,
Class<?> expectedExceptionType, String expectedExceptionMessage) { Class<?> expectedExceptionType, String expectedExceptionMessage,
Linker.Option stl = Linker.Option.captureCallState("errno"); Linker.Option[] extraOptions) {
MethodHandle handle = downcallHandle("set_errno_V", FunctionDescriptor.ofVoid(C_INT), stl); List<Linker.Option> options = new ArrayList<>();
options.add(Linker.Option.captureCallState("errno"));
for (Linker.Option extra : extraOptions) {
options.add(extra);
}
MethodHandle handle = downcallHandle("set_errno_V", FunctionDescriptor.ofVoid(C_INT), options.toArray(Linker.Option[]::new));
try { try {
int testValue = 42; int testValue = 42;
@ -103,32 +114,39 @@ public class TestCaptureCallState extends NativeTestHelper {
public static Object[][] cases() { public static Object[][] cases() {
List<SaveValuesCase> cases = new ArrayList<>(); List<SaveValuesCase> cases = new ArrayList<>();
cases.add(new SaveValuesCase("set_errno_V", FunctionDescriptor.ofVoid(JAVA_INT), "errno", o -> {})); for (boolean critical : new boolean[]{ true, false }) {
cases.add(new SaveValuesCase("set_errno_I", FunctionDescriptor.of(JAVA_INT, JAVA_INT), "errno", o -> assertEquals((int) o, 42))); cases.add(new SaveValuesCase("set_errno_V", FunctionDescriptor.ofVoid(JAVA_INT),
cases.add(new SaveValuesCase("set_errno_D", FunctionDescriptor.of(JAVA_DOUBLE, JAVA_INT), "errno", o -> assertEquals((double) o, 42.0))); "errno", o -> {}, critical));
cases.add(new SaveValuesCase("set_errno_I", FunctionDescriptor.of(JAVA_INT, JAVA_INT),
"errno", o -> assertEquals((int) o, 42), critical));
cases.add(new SaveValuesCase("set_errno_D", FunctionDescriptor.of(JAVA_DOUBLE, JAVA_INT),
"errno", o -> assertEquals((double) o, 42.0), critical));
cases.add(structCase("SL", Map.of(JAVA_LONG.withName("x"), 42L))); cases.add(structCase("SL", Map.of(JAVA_LONG.withName("x"), 42L), critical));
cases.add(structCase("SLL", Map.of(JAVA_LONG.withName("x"), 42L, cases.add(structCase("SLL", Map.of(JAVA_LONG.withName("x"), 42L,
JAVA_LONG.withName("y"), 42L))); JAVA_LONG.withName("y"), 42L), critical));
cases.add(structCase("SLLL", Map.of(JAVA_LONG.withName("x"), 42L, cases.add(structCase("SLLL", Map.of(JAVA_LONG.withName("x"), 42L,
JAVA_LONG.withName("y"), 42L, JAVA_LONG.withName("y"), 42L,
JAVA_LONG.withName("z"), 42L))); JAVA_LONG.withName("z"), 42L), critical));
cases.add(structCase("SD", Map.of(JAVA_DOUBLE.withName("x"), 42D))); cases.add(structCase("SD", Map.of(JAVA_DOUBLE.withName("x"), 42D), critical));
cases.add(structCase("SDD", Map.of(JAVA_DOUBLE.withName("x"), 42D, cases.add(structCase("SDD", Map.of(JAVA_DOUBLE.withName("x"), 42D,
JAVA_DOUBLE.withName("y"), 42D))); JAVA_DOUBLE.withName("y"), 42D), critical));
cases.add(structCase("SDDD", Map.of(JAVA_DOUBLE.withName("x"), 42D, cases.add(structCase("SDDD", Map.of(JAVA_DOUBLE.withName("x"), 42D,
JAVA_DOUBLE.withName("y"), 42D, JAVA_DOUBLE.withName("y"), 42D,
JAVA_DOUBLE.withName("z"), 42D))); JAVA_DOUBLE.withName("z"), 42D), critical));
if (IS_WINDOWS) { if (IS_WINDOWS) {
cases.add(new SaveValuesCase("SetLastError", FunctionDescriptor.ofVoid(JAVA_INT), "GetLastError", o -> {})); cases.add(new SaveValuesCase("SetLastError", FunctionDescriptor.ofVoid(JAVA_INT),
cases.add(new SaveValuesCase("WSASetLastError", FunctionDescriptor.ofVoid(JAVA_INT), "WSAGetLastError", o -> {})); "GetLastError", o -> {}, critical));
cases.add(new SaveValuesCase("WSASetLastError", FunctionDescriptor.ofVoid(JAVA_INT),
"WSAGetLastError", o -> {}, critical));
}
} }
return cases.stream().map(tc -> new Object[] {tc}).toArray(Object[][]::new); return cases.stream().map(tc -> new Object[] {tc}).toArray(Object[][]::new);
} }
static SaveValuesCase structCase(String name, Map<MemoryLayout, Object> fields) { static SaveValuesCase structCase(String name, Map<MemoryLayout, Object> fields, boolean critical) {
StructLayout layout = MemoryLayout.structLayout(fields.keySet().toArray(MemoryLayout[]::new)); StructLayout layout = MemoryLayout.structLayout(fields.keySet().toArray(MemoryLayout[]::new));
Consumer<Object> check = o -> {}; Consumer<Object> check = o -> {};
@ -139,16 +157,19 @@ public class TestCaptureCallState extends NativeTestHelper {
check = check.andThen(o -> assertEquals(fieldHandle.get(o, 0L), value)); check = check.andThen(o -> assertEquals(fieldHandle.get(o, 0L), value));
} }
return new SaveValuesCase("set_errno_" + name, FunctionDescriptor.of(layout, JAVA_INT), "errno", check); return new SaveValuesCase("set_errno_" + name, FunctionDescriptor.of(layout, JAVA_INT),
"errno", check, critical);
} }
@DataProvider @DataProvider
public static Object[][] invalidCaptureSegmentCases() { public static Object[][] invalidCaptureSegmentCases() {
return new Object[][]{ return new Object[][]{
{Arena.ofAuto().allocate(1), IndexOutOfBoundsException.class, ".*Out of bound access on segment.*"}, {Arena.ofAuto().allocate(1), IndexOutOfBoundsException.class, ".*Out of bound access on segment.*", new Linker.Option[0]},
{MemorySegment.NULL, IllegalArgumentException.class, ".*Capture segment is NULL.*"}, {MemorySegment.NULL, IllegalArgumentException.class, ".*Capture segment is NULL.*", new Linker.Option[0]},
{Arena.ofAuto().allocate(Linker.Option.captureStateLayout().byteSize() + 3).asSlice(3), // misaligned {Arena.ofAuto().allocate(Linker.Option.captureStateLayout().byteSize() + 3).asSlice(3), // misaligned
IllegalArgumentException.class, ".*Target offset incompatible with alignment constraints.*"}, IllegalArgumentException.class, ".*Target offset incompatible with alignment constraints.*", new Linker.Option[0]},
{MemorySegment.ofArray(new byte[(int) Linker.Option.captureStateLayout().byteSize()]), // misaligned
IllegalArgumentException.class, ".*Target offset incompatible with alignment constraints.*", new Linker.Option[0]},
}; };
} }
} }

View file

@ -45,12 +45,16 @@ import java.lang.invoke.VarHandle;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.function.IntFunction; import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertEquals;
public class TestCritical extends NativeTestHelper { public class TestCritical extends NativeTestHelper {
static final MemoryLayout CAPTURE_STATE_LAYOUT = Linker.Option.captureStateLayout();
static final VarHandle ERRNO_HANDLE = CAPTURE_STATE_LAYOUT.varHandle(MemoryLayout.PathElement.groupElement("errno"));
static { static {
System.loadLibrary("Critical"); System.loadLibrary("Critical");
} }
@ -87,11 +91,16 @@ public class TestCritical extends NativeTestHelper {
} }
public record AllowHeapCase(IntFunction<MemorySegment> newArraySegment, ValueLayout elementLayout, public record AllowHeapCase(IntFunction<MemorySegment> newArraySegment, ValueLayout elementLayout,
String fName, FunctionDescriptor fDesc, boolean readOnly) {} String fName, FunctionDescriptor fDesc, boolean readOnly, boolean captureErrno) {}
@Test(dataProvider = "allowHeapCases") @Test(dataProvider = "allowHeapCases")
public void testAllowHeap(AllowHeapCase testCase) throws Throwable { public void testAllowHeap(AllowHeapCase testCase) throws Throwable {
MethodHandle handle = downcallHandle(testCase.fName(), testCase.fDesc(), Linker.Option.critical(true)); List<Linker.Option> options = new ArrayList<>();
options.add(Linker.Option.critical(true));
if (testCase.captureErrno()) {
options.add(Linker.Option.captureCallState("errno"));
}
MethodHandle handle = downcallHandle(testCase.fName(), testCase.fDesc(), options.toArray(Linker.Option[]::new));
int elementCount = 10; int elementCount = 10;
MemorySegment heapSegment = testCase.newArraySegment().apply(elementCount); MemorySegment heapSegment = testCase.newArraySegment().apply(elementCount);
if (testCase.readOnly()) { if (testCase.readOnly()) {
@ -101,29 +110,36 @@ public class TestCritical extends NativeTestHelper {
try (Arena arena = Arena.ofConfined()) { try (Arena arena = Arena.ofConfined()) {
TestValue[] tvs = genTestArgs(testCase.fDesc(), arena); TestValue[] tvs = genTestArgs(testCase.fDesc(), arena);
Object[] args = Stream.of(tvs).map(TestValue::value).toArray(); List<Object> args = Stream.of(tvs).map(TestValue::value).collect(Collectors.toCollection(ArrayList::new));
MemorySegment captureSegment = testCase.captureErrno()
? MemorySegment.ofArray(new int[((int) CAPTURE_STATE_LAYOUT.byteSize() + 3) / 4])
: null;
// inject our custom last three arguments // inject our custom last three arguments
args[args.length - 1] = (int) sequence.byteSize(); args.set(args.size() - 1, (int) sequence.byteSize());
TestValue sourceSegment = genTestValue(sequence, arena); TestValue sourceSegment = genTestValue(sequence, arena);
args[args.length - 2] = sourceSegment.value(); args.set(args.size() - 2, sourceSegment.value());
args[args.length - 3] = heapSegment; args.set(args.size() - 3, heapSegment);
if (testCase.captureErrno()) {
args.add(0, captureSegment);
}
if (handle.type().parameterType(0) == SegmentAllocator.class) { if (handle.type().parameterType(0) == SegmentAllocator.class) {
Object[] newArgs = new Object[args.length + 1]; args.add(0, arena);
newArgs[0] = arena;
System.arraycopy(args, 0, newArgs, 1, args.length);
args = newArgs;
} }
Object o = handle.invokeWithArguments(args); Object o = handle.invokeWithArguments(args);
if (o != null) { if (o != null) {
tvs[0].check(o); tvs[0].check(o);
} }
// check that writes went through to array // check that writes went through to array
sourceSegment.check(heapSegment); sourceSegment.check(heapSegment);
if (testCase.captureErrno()) {
int errno = (int) ERRNO_HANDLE.get(captureSegment, 0L);
assertEquals(errno, 42);
}
} }
} }
@ -149,14 +165,16 @@ public class TestCritical extends NativeTestHelper {
List<AllowHeapCase> cases = new ArrayList<>(); List<AllowHeapCase> cases = new ArrayList<>();
for (HeapSegmentFactory hsf : HeapSegmentFactory.values()) { for (boolean doCapture : new boolean[]{ true, false }) {
cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_void", voidDesc, false)); for (HeapSegmentFactory hsf : HeapSegmentFactory.values()) {
cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_int", intDesc, false)); cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_void", voidDesc, false, doCapture));
cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_return_buffer", L2Desc, false)); cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_int", intDesc, false, doCapture));
cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_imr", L3Desc, false)); cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_return_buffer", L2Desc, false, doCapture));
cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_void_stack", stackDesc, false)); cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_imr", L3Desc, false, doCapture));
// readOnly cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_void_stack", stackDesc, false, doCapture));
cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_void", voidDesc, true)); // readOnly
cases.add(new AllowHeapCase(hsf.newArray, hsf.elementLayout, "test_allow_heap_void", voidDesc, true, doCapture));
}
} }
return cases.stream().map(e -> new Object[]{ e }).toArray(Object[][]::new); return cases.stream().map(e -> new Object[]{ e }).toArray(Object[][]::new);

View file

@ -53,12 +53,14 @@ EXPORT void test_allow_heap_void(unsigned char* heapArr, unsigned char* nativeAr
for (int i = 0; i < numBytes; i++) { for (int i = 0; i < numBytes; i++) {
heapArr[i] = nativeArr[i]; heapArr[i] = nativeArr[i];
} }
errno = 42;
} }
EXPORT int test_allow_heap_int(int a0, unsigned char* heapArr, unsigned char* nativeArr, int numBytes) { EXPORT int test_allow_heap_int(int a0, unsigned char* heapArr, unsigned char* nativeArr, int numBytes) {
for (int i = 0; i < numBytes; i++) { for (int i = 0; i < numBytes; i++) {
heapArr[i] = nativeArr[i]; heapArr[i] = nativeArr[i];
} }
errno = 42;
return a0; return a0;
} }
@ -71,6 +73,7 @@ EXPORT struct L2 test_allow_heap_return_buffer(struct L2 a0, unsigned char* heap
for (int i = 0; i < numBytes; i++) { for (int i = 0; i < numBytes; i++) {
heapArr[i] = nativeArr[i]; heapArr[i] = nativeArr[i];
} }
errno = 42;
return a0; return a0;
} }
@ -84,6 +87,7 @@ EXPORT struct L3 test_allow_heap_imr(struct L3 a0, unsigned char* heapArr, unsig
for (int i = 0; i < numBytes; i++) { for (int i = 0; i < numBytes; i++) {
heapArr[i] = nativeArr[i]; heapArr[i] = nativeArr[i];
} }
errno = 42;
return a0; return a0;
} }
@ -94,4 +98,5 @@ EXPORT void test_allow_heap_void_stack(long long a0, long long a1, long long a2,
for (int i = 0; i < numBytes; i++) { for (int i = 0; i < numBytes; i++) {
heapArr[i] = nativeArr[i]; heapArr[i] = nativeArr[i];
} }
errno = 42;
} }

View file

@ -31,8 +31,7 @@ import org.testng.annotations.DataProvider;
import org.testng.annotations.Test; import org.testng.annotations.Test;
import java.io.IOException; import java.io.IOException;
import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.*;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
import static java.lang.foreign.ValueLayout.ADDRESS; import static java.lang.foreign.ValueLayout.ADDRESS;
@ -51,6 +50,19 @@ public class TestPassHeapSegment extends UpcallTestHelper {
handle.invoke(segment); // should throw handle.invoke(segment); // should throw
} }
@Test(expectedExceptions = IllegalArgumentException.class,
expectedExceptionsMessageRegExp = ".*Heap segment not allowed.*")
public void testNoHeapCaptureCallState() throws Throwable {
MethodHandle handle = downcallHandle("test_args", FunctionDescriptor.ofVoid(ADDRESS),
Linker.Option.captureCallState("errno"));
try (Arena arena = Arena.ofConfined()) {
assert Linker.Option.captureStateLayout().byteAlignment() % 4 == 0;
MemorySegment captureHeap = MemorySegment.ofArray(new int[(int) Linker.Option.captureStateLayout().byteSize() / 4]);
MemorySegment segment = arena.allocateFrom(C_CHAR, new byte[]{ 0, 1, 2 });
handle.invoke(captureHeap, segment); // should throw for captureHeap
}
}
@Test(dataProvider = "specs") @Test(dataProvider = "specs")
public void testNoHeapReturns(boolean spec) throws IOException, InterruptedException { public void testNoHeapReturns(boolean spec) throws IOException, InterruptedException {
runInNewProcess(Runner.class, spec) runInNewProcess(Runner.class, spec)