8307375: Alignment check on layouts used as sequence element is not correct

Reviewed-by: jvernee
This commit is contained in:
Maurizio Cimadamore 2023-05-05 15:59:13 +00:00
parent 3968ab5db5
commit 47422be2d1
6 changed files with 85 additions and 24 deletions

View file

@ -704,12 +704,12 @@ public sealed interface MemoryLayout permits SequenceLayout, GroupLayout, Paddin
* @param elementLayout the sequence element layout. * @param elementLayout the sequence element layout.
* @return the new sequence layout with the given element layout and size. * @return the new sequence layout with the given element layout and size.
* @throws IllegalArgumentException if {@code elementCount } is negative. * @throws IllegalArgumentException if {@code elementCount } is negative.
* @throws IllegalArgumentException if {@code elementLayout.bitAlignment() > elementLayout.bitSize()}. * @throws IllegalArgumentException if {@code elementLayout.bitSize() % elementLayout.bitAlignment() != 0}.
*/ */
static SequenceLayout sequenceLayout(long elementCount, MemoryLayout elementLayout) { static SequenceLayout sequenceLayout(long elementCount, MemoryLayout elementLayout) {
MemoryLayoutUtil.requireNonNegative(elementCount); MemoryLayoutUtil.requireNonNegative(elementCount);
Objects.requireNonNull(elementLayout); Objects.requireNonNull(elementLayout);
Utils.checkElementAlignment(elementLayout, "Element layout alignment greater than its size"); Utils.checkElementAlignment(elementLayout, "Element layout size is not multiple of alignment");
return wrapOverflow(() -> return wrapOverflow(() ->
SequenceLayoutImpl.of(elementCount, elementLayout)); SequenceLayoutImpl.of(elementCount, elementLayout));
} }
@ -725,7 +725,7 @@ public sealed interface MemoryLayout permits SequenceLayout, GroupLayout, Paddin
* *
* @param elementLayout the sequence element layout. * @param elementLayout the sequence element layout.
* @return a new sequence layout with the given element layout and maximum element count. * @return a new sequence layout with the given element layout and maximum element count.
* @throws IllegalArgumentException if {@code elementLayout.bitAlignment() > elementLayout.bitSize()}. * @throws IllegalArgumentException if {@code elementLayout.bitSize() % elementLayout.bitAlignment() != 0}.
*/ */
static SequenceLayout sequenceLayout(MemoryLayout elementLayout) { static SequenceLayout sequenceLayout(MemoryLayout elementLayout) {
Objects.requireNonNull(elementLayout); Objects.requireNonNull(elementLayout);

View file

@ -459,10 +459,11 @@ public sealed interface MemorySegment permits AbstractMemorySegmentImpl {
* *
* @param elementLayout the layout to be used for splitting. * @param elementLayout the layout to be used for splitting.
* @return the element spliterator for this segment * @return the element spliterator for this segment
* @throws IllegalArgumentException if the {@code elementLayout} size is zero, or the segment size modulo the * @throws IllegalArgumentException if {@code elementLayout.byteSize() == 0}.
* {@code elementLayout} size is greater than zero, if this segment is * @throws IllegalArgumentException if {@code byteSize() % elementLayout.byteSize() != 0}.
* <a href="MemorySegment.html#segment-alignment">incompatible with the alignment constraint</a> in the provided layout, * @throws IllegalArgumentException if {@code elementLayout.bitSize() % elementLayout.bitAlignment() != 0}.
* or if the {@code elementLayout} alignment is greater than its size. * @throws IllegalArgumentException if this segment is <a href="MemorySegment.html#segment-alignment">incompatible
* with the alignment constraint</a> in the provided layout.
*/ */
Spliterator<MemorySegment> spliterator(MemoryLayout elementLayout); Spliterator<MemorySegment> spliterator(MemoryLayout elementLayout);
@ -475,10 +476,11 @@ public sealed interface MemorySegment permits AbstractMemorySegmentImpl {
* *
* @param elementLayout the layout to be used for splitting. * @param elementLayout the layout to be used for splitting.
* @return a sequential {@code Stream} over disjoint slices in this segment. * @return a sequential {@code Stream} over disjoint slices in this segment.
* @throws IllegalArgumentException if the {@code elementLayout} size is zero, or the segment size modulo the * @throws IllegalArgumentException if {@code elementLayout.byteSize() == 0}.
* {@code elementLayout} size is greater than zero, if this segment is * @throws IllegalArgumentException if {@code byteSize() % elementLayout.byteSize() != 0}.
* <a href="MemorySegment.html#segment-alignment">incompatible with the alignment constraint</a> in the provided layout, * @throws IllegalArgumentException if {@code elementLayout.bitSize() % elementLayout.bitAlignment() != 0}.
* or if the {@code elementLayout} alignment is greater than its size. * @throws IllegalArgumentException if this segment is <a href="MemorySegment.html#segment-alignment">incompatible
* with the alignment constraint</a> in the provided layout.
*/ */
Stream<MemorySegment> elements(MemoryLayout elementLayout); Stream<MemorySegment> elements(MemoryLayout elementLayout);

View file

@ -165,7 +165,7 @@ public abstract sealed class AbstractMemorySegmentImpl
if (elementLayout.byteSize() == 0) { if (elementLayout.byteSize() == 0) {
throw new IllegalArgumentException("Element layout size cannot be zero"); throw new IllegalArgumentException("Element layout size cannot be zero");
} }
Utils.checkElementAlignment(elementLayout, "Element layout alignment greater than its size"); Utils.checkElementAlignment(elementLayout, "Element layout size is not multiple of alignment");
if (!isAlignedForElement(0, elementLayout)) { if (!isAlignedForElement(0, elementLayout)) {
throw new IllegalArgumentException("Incompatible alignment constraints"); throw new IllegalArgumentException("Incompatible alignment constraints");
} }

View file

@ -174,12 +174,22 @@ public final class Utils {
} }
@ForceInline @ForceInline
public static void checkElementAlignment(MemoryLayout layout, String msg) { public static void checkElementAlignment(ValueLayout layout, String msg) {
// Fast-path: if both size and alignment are powers of two, we can just
// check if one is greater than the other.
assert isPowerOfTwo(layout.bitSize());
if (layout.byteAlignment() > layout.byteSize()) { if (layout.byteAlignment() > layout.byteSize()) {
throw new IllegalArgumentException(msg); throw new IllegalArgumentException(msg);
} }
} }
@ForceInline
public static void checkElementAlignment(MemoryLayout layout, String msg) {
if (layout.byteSize() % layout.byteAlignment() != 0) {
throw new IllegalArgumentException(msg);
}
}
public static long pointeeByteSize(AddressLayout addressLayout) { public static long pointeeByteSize(AddressLayout addressLayout) {
return addressLayout.targetLayout() return addressLayout.targetLayout()
.map(MemoryLayout::byteSize) .map(MemoryLayout::byteSize)
@ -245,4 +255,8 @@ public final class Utils {
public static int byteWidthOfPrimitive(Class<?> primitive) { public static int byteWidthOfPrimitive(Class<?> primitive) {
return Wrapper.forPrimitiveType(primitive).bitWidth() / 8; return Wrapper.forPrimitiveType(primitive).bitWidth() / 8;
} }
public static boolean isPowerOfTwo(long value) {
return (value & (value - 1)) == 0L;
}
} }

View file

@ -25,6 +25,8 @@
*/ */
package jdk.internal.foreign.layout; package jdk.internal.foreign.layout;
import jdk.internal.foreign.Utils;
import java.lang.foreign.GroupLayout; import java.lang.foreign.GroupLayout;
import java.lang.foreign.MemoryLayout; import java.lang.foreign.MemoryLayout;
import java.lang.foreign.SequenceLayout; import java.lang.foreign.SequenceLayout;
@ -138,8 +140,8 @@ public abstract sealed class AbstractLayout<L extends AbstractLayout<L> & Memory
} }
private static long requirePowerOfTwoAndGreaterOrEqualToEight(long value) { private static long requirePowerOfTwoAndGreaterOrEqualToEight(long value) {
if (((value & (value - 1)) != 0L) || // value must be a power of two if (!Utils.isPowerOfTwo(value) || // value must be a power of two
(value < 8)) { // value must be greater or equal to 8 value < 8) { // value must be greater or equal to 8
throw new IllegalArgumentException("Invalid alignment: " + value); throw new IllegalArgumentException("Invalid alignment: " + value);
} }
return value; return value;

View file

@ -31,6 +31,7 @@ import java.lang.foreign.*;
import java.lang.invoke.VarHandle; import java.lang.invoke.VarHandle;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.function.LongFunction; import java.util.function.LongFunction;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -292,11 +293,46 @@ public class TestLayouts {
} }
@Test(dataProvider="layoutsAndAlignments", expectedExceptions = IllegalArgumentException.class) @Test(dataProvider="layoutsAndAlignments", expectedExceptions = IllegalArgumentException.class)
public void testBadSequence(MemoryLayout layout, long bitAlign) { public void testBadSequenceElementAlignmentTooBig(MemoryLayout layout, long bitAlign) {
layout = layout.withBitAlignment(layout.bitSize() * 2); // hyper-align layout = layout.withBitAlignment(layout.bitSize() * 2); // hyper-align
MemoryLayout.sequenceLayout(layout); MemoryLayout.sequenceLayout(layout);
} }
@Test(dataProvider="layoutsAndAlignments")
public void testBadSequenceElementSizeNotMultipleOfAlignment(MemoryLayout layout, long bitAlign) {
boolean shouldFail = layout.byteSize() % layout.byteAlignment() != 0;
try {
MemoryLayout.sequenceLayout(layout);
assertFalse(shouldFail);
} catch (IllegalArgumentException ex) {
assertTrue(shouldFail);
}
}
@Test(dataProvider="layoutsAndAlignments")
public void testBadSpliteratorElementSizeNotMultipleOfAlignment(MemoryLayout layout, long bitAlign) {
boolean shouldFail = layout.byteSize() % layout.byteAlignment() != 0;
try (Arena arena = Arena.ofConfined()) {
MemorySegment segment = arena.allocate(layout);
segment.spliterator(layout);
assertFalse(shouldFail);
} catch (IllegalArgumentException ex) {
assertTrue(shouldFail);
}
}
@Test(dataProvider="layoutsAndAlignments")
public void testBadElementsElementSizeNotMultipleOfAlignment(MemoryLayout layout, long bitAlign) {
boolean shouldFail = layout.byteSize() % layout.byteAlignment() != 0;
try (Arena arena = Arena.ofConfined()) {
MemorySegment segment = arena.allocate(layout);
segment.elements(layout);
assertFalse(shouldFail);
} catch (IllegalArgumentException ex) {
assertTrue(shouldFail);
}
}
@Test(dataProvider="layoutsAndAlignments", expectedExceptions = IllegalArgumentException.class) @Test(dataProvider="layoutsAndAlignments", expectedExceptions = IllegalArgumentException.class)
public void testBadStruct(MemoryLayout layout, long bitAlign) { public void testBadStruct(MemoryLayout layout, long bitAlign) {
layout = layout.withBitAlignment(layout.bitSize() * 2); // hyper-align layout = layout.withBitAlignment(layout.bitSize() * 2); // hyper-align
@ -392,25 +428,32 @@ public class TestLayouts {
@DataProvider(name = "layoutsAndAlignments") @DataProvider(name = "layoutsAndAlignments")
public Object[][] layoutsAndAlignments() { public Object[][] layoutsAndAlignments() {
Object[][] layoutsAndAlignments = new Object[basicLayouts.length * 4][]; List<Object[]> layoutsAndAlignments = new ArrayList<>();
int i = 0; int i = 0;
//add basic layouts //add basic layouts
for (MemoryLayout l : basicLayouts) { for (MemoryLayout l : basicLayouts) {
layoutsAndAlignments[i++] = new Object[] { l, l.bitAlignment() }; layoutsAndAlignments.add(new Object[] { l, l.bitAlignment() });
} }
//add basic layouts wrapped in a sequence with given size //add basic layouts wrapped in a sequence with given size
for (MemoryLayout l : basicLayouts) { for (MemoryLayout l : basicLayouts) {
layoutsAndAlignments[i++] = new Object[] { MemoryLayout.sequenceLayout(4, l), l.bitAlignment() }; layoutsAndAlignments.add(new Object[] { MemoryLayout.sequenceLayout(4, l), l.bitAlignment() });
} }
//add basic layouts wrapped in a struct //add basic layouts wrapped in a struct
for (MemoryLayout l : basicLayouts) { for (MemoryLayout l1 : basicLayouts) {
layoutsAndAlignments[i++] = new Object[] { MemoryLayout.structLayout(l), l.bitAlignment() }; for (MemoryLayout l2 : basicLayouts) {
if (l1.byteSize() % l2.byteAlignment() != 0) continue; // second element is not aligned, skip
long align = Math.max(l1.bitAlignment(), l2.bitAlignment());
layoutsAndAlignments.add(new Object[]{MemoryLayout.structLayout(l1, l2), align});
}
} }
//add basic layouts wrapped in a union //add basic layouts wrapped in a union
for (MemoryLayout l : basicLayouts) { for (MemoryLayout l1 : basicLayouts) {
layoutsAndAlignments[i++] = new Object[] { MemoryLayout.unionLayout(l), l.bitAlignment() }; for (MemoryLayout l2 : basicLayouts) {
long align = Math.max(l1.bitAlignment(), l2.bitAlignment());
layoutsAndAlignments.add(new Object[]{MemoryLayout.unionLayout(l1, l2), align});
} }
return layoutsAndAlignments; }
return layoutsAndAlignments.toArray(Object[][]::new);
} }
@DataProvider(name = "groupLayouts") @DataProvider(name = "groupLayouts")