diff --git a/src/java.base/share/classes/java/lang/foreign/MemorySegment.java b/src/java.base/share/classes/java/lang/foreign/MemorySegment.java index 8f171fd8bfb..38fd36bbb15 100644 --- a/src/java.base/share/classes/java/lang/foreign/MemorySegment.java +++ b/src/java.base/share/classes/java/lang/foreign/MemorySegment.java @@ -1570,8 +1570,9 @@ public sealed interface MemorySegment permits AbstractMemorySegmentImpl { @ForceInline static void copy(MemorySegment srcSegment, long srcOffset, MemorySegment dstSegment, long dstOffset, long bytes) { - copy(srcSegment, ValueLayout.JAVA_BYTE, srcOffset, - dstSegment, ValueLayout.JAVA_BYTE, dstOffset, + + AbstractMemorySegmentImpl.copy((AbstractMemorySegmentImpl) srcSegment, srcOffset, + (AbstractMemorySegmentImpl) dstSegment, dstOffset, bytes); } diff --git a/src/java.base/share/classes/jdk/internal/foreign/AbstractMemorySegmentImpl.java b/src/java.base/share/classes/jdk/internal/foreign/AbstractMemorySegmentImpl.java index 5d43c28a667..83b11b7ce68 100644 --- a/src/java.base/share/classes/jdk/internal/foreign/AbstractMemorySegmentImpl.java +++ b/src/java.base/share/classes/jdk/internal/foreign/AbstractMemorySegmentImpl.java @@ -304,20 +304,25 @@ public abstract sealed class AbstractMemorySegmentImpl @Override public final Optional asOverlappingSlice(MemorySegment other) { - AbstractMemorySegmentImpl that = (AbstractMemorySegmentImpl)Objects.requireNonNull(other); - if (unsafeGetBase() == that.unsafeGetBase()) { // both either native or heap + final AbstractMemorySegmentImpl that = (AbstractMemorySegmentImpl)Objects.requireNonNull(other); + if (overlaps(that)) { + final long offsetToThat = that.address() - this.address(); + final long newOffset = offsetToThat >= 0 ? offsetToThat : 0; + return Optional.of(asSlice(newOffset, Math.min(this.byteSize() - newOffset, that.byteSize() + offsetToThat))); + } + return Optional.empty(); + } + + @ForceInline + private boolean overlaps(AbstractMemorySegmentImpl that) { + if (unsafeGetBase() == that.unsafeGetBase()) { // both either native or the same heap segment final long thisStart = this.unsafeGetOffset(); final long thatStart = that.unsafeGetOffset(); final long thisEnd = thisStart + this.byteSize(); final long thatEnd = thatStart + that.byteSize(); - - if (thisStart < thatEnd && thisEnd > thatStart) { //overlap occurs - long offsetToThat = that.address() - this.address(); - long newOffset = offsetToThat >= 0 ? offsetToThat : 0; - return Optional.of(asSlice(newOffset, Math.min(this.byteSize() - newOffset, that.byteSize() + offsetToThat))); - } + return (thisStart < thatEnd && thisEnd > thatStart); //overlap occurs? } - return Optional.empty(); + return false; } @Override @@ -645,6 +650,64 @@ public abstract sealed class AbstractMemorySegmentImpl } } + // COPY_NATIVE_THRESHOLD must be a power of two and should be greater than 2^3 + private static final long COPY_NATIVE_THRESHOLD = 1 << 6; + + @ForceInline + public static void copy(AbstractMemorySegmentImpl src, long srcOffset, + AbstractMemorySegmentImpl dst, long dstOffset, + long size) { + + Utils.checkNonNegativeIndex(size, "size"); + // Implicit null check for src and dst + src.checkAccess(srcOffset, size, true); + dst.checkAccess(dstOffset, size, false); + + if (size <= 0) { + // Do nothing + } else if (size < COPY_NATIVE_THRESHOLD && !src.overlaps(dst)) { + // 0 < size < FILL_NATIVE_LIMIT : 0...0X...XXXX + // + // Strictly, we could check for !src.asSlice(srcOffset, size).overlaps(dst.asSlice(dstOffset, size) but + // this is a bit slower and it likely very unusual there is any difference in the outcome. Also, if there + // is an overlap, we could tolerate one particular direction of overlap (but not the other). + + // 0...0X...X000 + final int limit = (int) (size & (COPY_NATIVE_THRESHOLD - 8)); + int offset = 0; + for (; offset < limit; offset += 8) { + final long v = SCOPED_MEMORY_ACCESS.getLong(src.sessionImpl(), src.unsafeGetBase(), src.unsafeGetOffset() + srcOffset + offset); + SCOPED_MEMORY_ACCESS.putLong(dst.sessionImpl(), dst.unsafeGetBase(), dst.unsafeGetOffset() + dstOffset + offset, v); + } + int remaining = (int) size - offset; + // 0...0X00 + if (remaining >= 4) { + final int v = SCOPED_MEMORY_ACCESS.getInt(src.sessionImpl(), src.unsafeGetBase(),src.unsafeGetOffset() + srcOffset + offset); + SCOPED_MEMORY_ACCESS.putInt(dst.sessionImpl(), dst.unsafeGetBase(), dst.unsafeGetOffset() + dstOffset + offset, v); + offset += 4; + remaining -= 4; + } + // 0...00X0 + if (remaining >= 2) { + final short v = SCOPED_MEMORY_ACCESS.getShort(src.sessionImpl(), src.unsafeGetBase(), src.unsafeGetOffset() + srcOffset + offset); + SCOPED_MEMORY_ACCESS.putShort(dst.sessionImpl(), dst.unsafeGetBase(), dst.unsafeGetOffset() + dstOffset + offset, v); + offset += 2; + remaining -=2; + } + // 0...000X + if (remaining == 1) { + final byte v = SCOPED_MEMORY_ACCESS.getByte(src.sessionImpl(), src.unsafeGetBase(), src.unsafeGetOffset() + srcOffset + offset); + SCOPED_MEMORY_ACCESS.putByte(dst.sessionImpl(), dst.unsafeGetBase(), dst.unsafeGetOffset() + dstOffset + offset, v); + } + // We have now fully handled 0...0X...XXXX + } else { + // For larger sizes, the transition to native code pays off + SCOPED_MEMORY_ACCESS.copyMemory(src.sessionImpl(), dst.sessionImpl(), + src.unsafeGetBase(), src.unsafeGetOffset() + srcOffset, + dst.unsafeGetBase(), dst.unsafeGetOffset() + dstOffset, size); + } + } + @ForceInline public static void copy(MemorySegment srcSegment, ValueLayout srcElementLayout, long srcOffset, MemorySegment dstSegment, ValueLayout dstElementLayout, long dstOffset, diff --git a/test/jdk/java/foreign/TestSegmentCopy.java b/test/jdk/java/foreign/TestSegmentCopy.java index 88636bf5420..9a4500b2f5a 100644 --- a/test/jdk/java/foreign/TestSegmentCopy.java +++ b/test/jdk/java/foreign/TestSegmentCopy.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -30,7 +30,6 @@ import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; -import java.lang.invoke.MethodHandles; import java.lang.invoke.VarHandle; import java.nio.ByteOrder; import java.util.ArrayList; @@ -76,6 +75,94 @@ public class TestSegmentCopy { } } + @Test(dataProvider = "conjunctSegments") + public void testCopy5ArgInvariants(MemorySegment src, MemorySegment dst) { + assertThrows(IndexOutOfBoundsException.class, () -> MemorySegment.copy(src, 0, dst, 0, -1)); + assertThrows(IndexOutOfBoundsException.class, () -> MemorySegment.copy(src, -1, dst, 0, src.byteSize())); + assertThrows(IndexOutOfBoundsException.class, () -> MemorySegment.copy(src, 0, dst, -1, src.byteSize())); + assertThrows(IndexOutOfBoundsException.class, () -> MemorySegment.copy(src, 1, dst, 0, src.byteSize())); + assertThrows(IndexOutOfBoundsException.class, () -> MemorySegment.copy(src, 0, dst, 1, src.byteSize())); + } + + @Test(dataProvider = "conjunctSegments") + public void testConjunctCopy7ArgRight(MemorySegment src, MemorySegment dst) { + testConjunctCopy(src, 0, dst, 1, CopyOp.of7Arg()); + } + + @Test(dataProvider = "conjunctSegments") + public void testConjunctCopy5ArgRight(MemorySegment src, MemorySegment dst) { + testConjunctCopy(src, 0, dst, 1, CopyOp.of5Arg()); + } + + @Test(dataProvider = "conjunctSegments") + public void testConjunctCopy7ArgLeft(MemorySegment src, MemorySegment dst) { + testConjunctCopy(src, 1, dst, 0, CopyOp.of7Arg()); + } + + @Test(dataProvider = "conjunctSegments") + public void testConjunctCopy5ArgLeft(MemorySegment src, MemorySegment dst) { + testConjunctCopy(src, 1, dst, 0, CopyOp.of5Arg()); + } + + void testConjunctCopy(MemorySegment src, long srcOffset, MemorySegment dst, long dstOffset, CopyOp op) { + if (src.byteSize() < 4 || src.address() != dst.address()) { + // Only test larger segments where the skew is zero + return; + } + + try (var arena = Arena.ofConfined()) { + // Create a disjoint segment for expected behavior + MemorySegment disjoint = arena.allocate(dst.byteSize()); + disjoint.copyFrom(src); + op.copy(src, srcOffset, disjoint, dstOffset, 3); + byte[] expected = disjoint.toArray(JAVA_BYTE); + + // Do a conjoint copy + op.copy(src, srcOffset, dst, dstOffset, 3); + byte[] actual = dst.toArray(JAVA_BYTE); + + assertEquals(actual, expected); + } + } + + @FunctionalInterface + interface CopyOp { + void copy(MemorySegment src, long srcOffset, MemorySegment dst, long dstOffset, long bytes); + + static CopyOp of5Arg() { + return MemorySegment::copy; + } + + static CopyOp of7Arg() { + return (MemorySegment src, long srcOffset, MemorySegment dst, long dstOffset, long bytes) -> + MemorySegment.copy(src, JAVA_BYTE, srcOffset, dst, JAVA_BYTE, dstOffset, bytes); + } + + } + + @Test(dataProvider = "segmentKinds") + public void testByteCopySizes(SegmentKind kind1, SegmentKind kind2) { + + record Offsets(int src, int dst){} + + for (Offsets offsets : List.of(new Offsets(3, 7), new Offsets(7, 3))) { + for (int size = 0; size < 513; size++) { + MemorySegment src = kind1.makeSegment(size + offsets.src()); + MemorySegment dst = kind2.makeSegment(size + offsets.dst()); + //prepare source slice + for (int i = 0; i < size; i++) { + src.set(JAVA_BYTE, i + offsets.src(), (byte) i); + } + //perform copy + MemorySegment.copy(src, offsets.src(), dst, offsets.dst(), size); + //check that copy actually worked + for (int i = 0; i < size; i++) { + assertEquals(dst.get(JAVA_BYTE, i + offsets.dst()), (byte) i); + } + } + } + } + @Test(expectedExceptions = IllegalArgumentException.class, dataProvider = "segmentKinds") public void testReadOnlyCopy(SegmentKind kind1, SegmentKind kind2) { MemorySegment s1 = kind1.makeSegment(TEST_BYTE_SIZE); @@ -277,6 +364,28 @@ public class TestSegmentCopy { return cases.toArray(Object[][]::new); } + @DataProvider + static Object[][] conjunctSegments() { + List cases = new ArrayList<>(); + for (SegmentKind kind : SegmentKind.values()) { + // Different paths might be taken in the implementation depending on the + // size, type, and address of the underlying segments. + for (int len : new int[]{0, 1, 7, 512}) { + for (int offset : new int[]{-1, 0, 1}) { + MemorySegment segment = kind.makeSegment(len + 2); + MemorySegment src = segment.asSlice(1 + offset, len); + MemorySegment dst = segment.asSlice(1, len); + for (int i = 0; i < len; i++) { + src.set(JAVA_BYTE, i, (byte) i); + } + // src = 0, 1, ... , len-1 + cases.add(new Object[]{src, dst}); + } + } + } + return cases.toArray(Object[][]::new); + } + @DataProvider static Object[][] types() { return Arrays.stream(Type.values()) diff --git a/test/micro/org/openjdk/bench/java/lang/foreign/CopyTest.java b/test/micro/org/openjdk/bench/java/lang/foreign/CopyTest.java new file mode 100644 index 00000000000..8996b1de117 --- /dev/null +++ b/test/micro/org/openjdk/bench/java/lang/foreign/CopyTest.java @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + * + */ + +package org.openjdk.bench.java.lang.foreign; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; + +import static java.lang.foreign.ValueLayout.*; + +@BenchmarkMode(Mode.AverageTime) +@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS) +@State(Scope.Thread) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@Fork(value = 3) +public class CopyTest { + + @Param({"0", "1", "2", "3", "4", "5", "6", "7", "8", + "9", "10", "11", "12", "13", "14", "15", "16", + "17", "18", "19", "20", "21", "22", "23", "24", + "25", "26", "27", "28", "29", "30", "31", "32", + "33", "36", "40", "44", "48", "52", "56", "60", "63", "64", "128"}) + public int ELEM_SIZE; + + byte[] srcArray; + byte[] dstArray; + MemorySegment heapSrcSegment; + MemorySegment heapDstSegment; + MemorySegment nativeSrcSegment; + MemorySegment nativeDstSegment; + ByteBuffer srcBuffer; + ByteBuffer dstBuffer; + + @Setup + public void setup() { + srcArray = new byte[ELEM_SIZE]; + dstArray = new byte[ELEM_SIZE]; + heapSrcSegment = MemorySegment.ofArray(srcArray); + heapDstSegment = MemorySegment.ofArray(dstArray); + nativeSrcSegment = Arena.ofAuto().allocate(ELEM_SIZE); + nativeDstSegment = Arena.ofAuto().allocate(ELEM_SIZE); + srcBuffer = ByteBuffer.wrap(srcArray); + dstBuffer = ByteBuffer.wrap(dstArray); + } + + @Benchmark + public void array_copy() { + System.arraycopy(srcArray, 0, dstArray, 0, ELEM_SIZE); + } + + @Benchmark + public void heap_segment_copy5Arg() { + MemorySegment.copy(heapSrcSegment, 0, heapDstSegment, 0, ELEM_SIZE); + } + + @Benchmark + public void native_segment_copy5Arg() { + MemorySegment.copy(nativeSrcSegment, 0, nativeDstSegment, 0, ELEM_SIZE); + } + + @Benchmark + public void heap_segment_copy7arg() { + MemorySegment.copy(heapSrcSegment, JAVA_BYTE, 0, heapDstSegment, JAVA_BYTE, 0, ELEM_SIZE); + } + + @Benchmark + public void buffer_copy() { + dstBuffer.put(srcBuffer); + } + +}