8338591: Improve performance of MemorySegment::copy

Reviewed-by: mcimadamore
This commit is contained in:
Per Minborg 2024-09-05 13:10:24 +00:00
parent cb9f5c5791
commit 6be927260a
4 changed files with 293 additions and 13 deletions

View file

@ -1570,8 +1570,9 @@ public sealed interface MemorySegment permits AbstractMemorySegmentImpl {
@ForceInline @ForceInline
static void copy(MemorySegment srcSegment, long srcOffset, static void copy(MemorySegment srcSegment, long srcOffset,
MemorySegment dstSegment, long dstOffset, long bytes) { MemorySegment dstSegment, long dstOffset, long bytes) {
copy(srcSegment, ValueLayout.JAVA_BYTE, srcOffset,
dstSegment, ValueLayout.JAVA_BYTE, dstOffset, AbstractMemorySegmentImpl.copy((AbstractMemorySegmentImpl) srcSegment, srcOffset,
(AbstractMemorySegmentImpl) dstSegment, dstOffset,
bytes); bytes);
} }

View file

@ -304,20 +304,25 @@ public abstract sealed class AbstractMemorySegmentImpl
@Override @Override
public final Optional<MemorySegment> asOverlappingSlice(MemorySegment other) { public final Optional<MemorySegment> asOverlappingSlice(MemorySegment other) {
AbstractMemorySegmentImpl that = (AbstractMemorySegmentImpl)Objects.requireNonNull(other); final AbstractMemorySegmentImpl that = (AbstractMemorySegmentImpl)Objects.requireNonNull(other);
if (unsafeGetBase() == that.unsafeGetBase()) { // both either native or heap if (overlaps(that)) {
final long offsetToThat = that.address() - this.address();
final long newOffset = offsetToThat >= 0 ? offsetToThat : 0;
return Optional.of(asSlice(newOffset, Math.min(this.byteSize() - newOffset, that.byteSize() + offsetToThat)));
}
return Optional.empty();
}
@ForceInline
private boolean overlaps(AbstractMemorySegmentImpl that) {
if (unsafeGetBase() == that.unsafeGetBase()) { // both either native or the same heap segment
final long thisStart = this.unsafeGetOffset(); final long thisStart = this.unsafeGetOffset();
final long thatStart = that.unsafeGetOffset(); final long thatStart = that.unsafeGetOffset();
final long thisEnd = thisStart + this.byteSize(); final long thisEnd = thisStart + this.byteSize();
final long thatEnd = thatStart + that.byteSize(); final long thatEnd = thatStart + that.byteSize();
return (thisStart < thatEnd && thisEnd > thatStart); //overlap occurs?
if (thisStart < thatEnd && thisEnd > thatStart) { //overlap occurs
long offsetToThat = that.address() - this.address();
long newOffset = offsetToThat >= 0 ? offsetToThat : 0;
return Optional.of(asSlice(newOffset, Math.min(this.byteSize() - newOffset, that.byteSize() + offsetToThat)));
}
} }
return Optional.empty(); return false;
} }
@Override @Override
@ -645,6 +650,64 @@ public abstract sealed class AbstractMemorySegmentImpl
} }
} }
// COPY_NATIVE_THRESHOLD must be a power of two and should be greater than 2^3
private static final long COPY_NATIVE_THRESHOLD = 1 << 6;
@ForceInline
public static void copy(AbstractMemorySegmentImpl src, long srcOffset,
AbstractMemorySegmentImpl dst, long dstOffset,
long size) {
Utils.checkNonNegativeIndex(size, "size");
// Implicit null check for src and dst
src.checkAccess(srcOffset, size, true);
dst.checkAccess(dstOffset, size, false);
if (size <= 0) {
// Do nothing
} else if (size < COPY_NATIVE_THRESHOLD && !src.overlaps(dst)) {
// 0 < size < FILL_NATIVE_LIMIT : 0...0X...XXXX
//
// Strictly, we could check for !src.asSlice(srcOffset, size).overlaps(dst.asSlice(dstOffset, size) but
// this is a bit slower and it likely very unusual there is any difference in the outcome. Also, if there
// is an overlap, we could tolerate one particular direction of overlap (but not the other).
// 0...0X...X000
final int limit = (int) (size & (COPY_NATIVE_THRESHOLD - 8));
int offset = 0;
for (; offset < limit; offset += 8) {
final long v = SCOPED_MEMORY_ACCESS.getLong(src.sessionImpl(), src.unsafeGetBase(), src.unsafeGetOffset() + srcOffset + offset);
SCOPED_MEMORY_ACCESS.putLong(dst.sessionImpl(), dst.unsafeGetBase(), dst.unsafeGetOffset() + dstOffset + offset, v);
}
int remaining = (int) size - offset;
// 0...0X00
if (remaining >= 4) {
final int v = SCOPED_MEMORY_ACCESS.getInt(src.sessionImpl(), src.unsafeGetBase(),src.unsafeGetOffset() + srcOffset + offset);
SCOPED_MEMORY_ACCESS.putInt(dst.sessionImpl(), dst.unsafeGetBase(), dst.unsafeGetOffset() + dstOffset + offset, v);
offset += 4;
remaining -= 4;
}
// 0...00X0
if (remaining >= 2) {
final short v = SCOPED_MEMORY_ACCESS.getShort(src.sessionImpl(), src.unsafeGetBase(), src.unsafeGetOffset() + srcOffset + offset);
SCOPED_MEMORY_ACCESS.putShort(dst.sessionImpl(), dst.unsafeGetBase(), dst.unsafeGetOffset() + dstOffset + offset, v);
offset += 2;
remaining -=2;
}
// 0...000X
if (remaining == 1) {
final byte v = SCOPED_MEMORY_ACCESS.getByte(src.sessionImpl(), src.unsafeGetBase(), src.unsafeGetOffset() + srcOffset + offset);
SCOPED_MEMORY_ACCESS.putByte(dst.sessionImpl(), dst.unsafeGetBase(), dst.unsafeGetOffset() + dstOffset + offset, v);
}
// We have now fully handled 0...0X...XXXX
} else {
// For larger sizes, the transition to native code pays off
SCOPED_MEMORY_ACCESS.copyMemory(src.sessionImpl(), dst.sessionImpl(),
src.unsafeGetBase(), src.unsafeGetOffset() + srcOffset,
dst.unsafeGetBase(), dst.unsafeGetOffset() + dstOffset, size);
}
}
@ForceInline @ForceInline
public static void copy(MemorySegment srcSegment, ValueLayout srcElementLayout, long srcOffset, public static void copy(MemorySegment srcSegment, ValueLayout srcElementLayout, long srcOffset,
MemorySegment dstSegment, ValueLayout dstElementLayout, long dstOffset, MemorySegment dstSegment, ValueLayout dstElementLayout, long dstOffset,

View file

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
* *
* This code is free software; you can redistribute it and/or modify it * This code is free software; you can redistribute it and/or modify it
@ -30,7 +30,6 @@
import java.lang.foreign.Arena; import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment; import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout; import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle; import java.lang.invoke.VarHandle;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.util.ArrayList; import java.util.ArrayList;
@ -76,6 +75,94 @@ public class TestSegmentCopy {
} }
} }
@Test(dataProvider = "conjunctSegments")
public void testCopy5ArgInvariants(MemorySegment src, MemorySegment dst) {
assertThrows(IndexOutOfBoundsException.class, () -> MemorySegment.copy(src, 0, dst, 0, -1));
assertThrows(IndexOutOfBoundsException.class, () -> MemorySegment.copy(src, -1, dst, 0, src.byteSize()));
assertThrows(IndexOutOfBoundsException.class, () -> MemorySegment.copy(src, 0, dst, -1, src.byteSize()));
assertThrows(IndexOutOfBoundsException.class, () -> MemorySegment.copy(src, 1, dst, 0, src.byteSize()));
assertThrows(IndexOutOfBoundsException.class, () -> MemorySegment.copy(src, 0, dst, 1, src.byteSize()));
}
@Test(dataProvider = "conjunctSegments")
public void testConjunctCopy7ArgRight(MemorySegment src, MemorySegment dst) {
testConjunctCopy(src, 0, dst, 1, CopyOp.of7Arg());
}
@Test(dataProvider = "conjunctSegments")
public void testConjunctCopy5ArgRight(MemorySegment src, MemorySegment dst) {
testConjunctCopy(src, 0, dst, 1, CopyOp.of5Arg());
}
@Test(dataProvider = "conjunctSegments")
public void testConjunctCopy7ArgLeft(MemorySegment src, MemorySegment dst) {
testConjunctCopy(src, 1, dst, 0, CopyOp.of7Arg());
}
@Test(dataProvider = "conjunctSegments")
public void testConjunctCopy5ArgLeft(MemorySegment src, MemorySegment dst) {
testConjunctCopy(src, 1, dst, 0, CopyOp.of5Arg());
}
void testConjunctCopy(MemorySegment src, long srcOffset, MemorySegment dst, long dstOffset, CopyOp op) {
if (src.byteSize() < 4 || src.address() != dst.address()) {
// Only test larger segments where the skew is zero
return;
}
try (var arena = Arena.ofConfined()) {
// Create a disjoint segment for expected behavior
MemorySegment disjoint = arena.allocate(dst.byteSize());
disjoint.copyFrom(src);
op.copy(src, srcOffset, disjoint, dstOffset, 3);
byte[] expected = disjoint.toArray(JAVA_BYTE);
// Do a conjoint copy
op.copy(src, srcOffset, dst, dstOffset, 3);
byte[] actual = dst.toArray(JAVA_BYTE);
assertEquals(actual, expected);
}
}
@FunctionalInterface
interface CopyOp {
void copy(MemorySegment src, long srcOffset, MemorySegment dst, long dstOffset, long bytes);
static CopyOp of5Arg() {
return MemorySegment::copy;
}
static CopyOp of7Arg() {
return (MemorySegment src, long srcOffset, MemorySegment dst, long dstOffset, long bytes) ->
MemorySegment.copy(src, JAVA_BYTE, srcOffset, dst, JAVA_BYTE, dstOffset, bytes);
}
}
@Test(dataProvider = "segmentKinds")
public void testByteCopySizes(SegmentKind kind1, SegmentKind kind2) {
record Offsets(int src, int dst){}
for (Offsets offsets : List.of(new Offsets(3, 7), new Offsets(7, 3))) {
for (int size = 0; size < 513; size++) {
MemorySegment src = kind1.makeSegment(size + offsets.src());
MemorySegment dst = kind2.makeSegment(size + offsets.dst());
//prepare source slice
for (int i = 0; i < size; i++) {
src.set(JAVA_BYTE, i + offsets.src(), (byte) i);
}
//perform copy
MemorySegment.copy(src, offsets.src(), dst, offsets.dst(), size);
//check that copy actually worked
for (int i = 0; i < size; i++) {
assertEquals(dst.get(JAVA_BYTE, i + offsets.dst()), (byte) i);
}
}
}
}
@Test(expectedExceptions = IllegalArgumentException.class, dataProvider = "segmentKinds") @Test(expectedExceptions = IllegalArgumentException.class, dataProvider = "segmentKinds")
public void testReadOnlyCopy(SegmentKind kind1, SegmentKind kind2) { public void testReadOnlyCopy(SegmentKind kind1, SegmentKind kind2) {
MemorySegment s1 = kind1.makeSegment(TEST_BYTE_SIZE); MemorySegment s1 = kind1.makeSegment(TEST_BYTE_SIZE);
@ -277,6 +364,28 @@ public class TestSegmentCopy {
return cases.toArray(Object[][]::new); return cases.toArray(Object[][]::new);
} }
@DataProvider
static Object[][] conjunctSegments() {
List<Object[]> cases = new ArrayList<>();
for (SegmentKind kind : SegmentKind.values()) {
// Different paths might be taken in the implementation depending on the
// size, type, and address of the underlying segments.
for (int len : new int[]{0, 1, 7, 512}) {
for (int offset : new int[]{-1, 0, 1}) {
MemorySegment segment = kind.makeSegment(len + 2);
MemorySegment src = segment.asSlice(1 + offset, len);
MemorySegment dst = segment.asSlice(1, len);
for (int i = 0; i < len; i++) {
src.set(JAVA_BYTE, i, (byte) i);
}
// src = 0, 1, ... , len-1
cases.add(new Object[]{src, dst});
}
}
}
return cases.toArray(Object[][]::new);
}
@DataProvider @DataProvider
static Object[][] types() { static Object[][] types() {
return Arrays.stream(Type.values()) return Arrays.stream(Type.values())

View file

@ -0,0 +1,107 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*
*/
package org.openjdk.bench.java.lang.foreign;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit;
import static java.lang.foreign.ValueLayout.*;
@BenchmarkMode(Mode.AverageTime)
@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(value = 3)
public class CopyTest {
@Param({"0", "1", "2", "3", "4", "5", "6", "7", "8",
"9", "10", "11", "12", "13", "14", "15", "16",
"17", "18", "19", "20", "21", "22", "23", "24",
"25", "26", "27", "28", "29", "30", "31", "32",
"33", "36", "40", "44", "48", "52", "56", "60", "63", "64", "128"})
public int ELEM_SIZE;
byte[] srcArray;
byte[] dstArray;
MemorySegment heapSrcSegment;
MemorySegment heapDstSegment;
MemorySegment nativeSrcSegment;
MemorySegment nativeDstSegment;
ByteBuffer srcBuffer;
ByteBuffer dstBuffer;
@Setup
public void setup() {
srcArray = new byte[ELEM_SIZE];
dstArray = new byte[ELEM_SIZE];
heapSrcSegment = MemorySegment.ofArray(srcArray);
heapDstSegment = MemorySegment.ofArray(dstArray);
nativeSrcSegment = Arena.ofAuto().allocate(ELEM_SIZE);
nativeDstSegment = Arena.ofAuto().allocate(ELEM_SIZE);
srcBuffer = ByteBuffer.wrap(srcArray);
dstBuffer = ByteBuffer.wrap(dstArray);
}
@Benchmark
public void array_copy() {
System.arraycopy(srcArray, 0, dstArray, 0, ELEM_SIZE);
}
@Benchmark
public void heap_segment_copy5Arg() {
MemorySegment.copy(heapSrcSegment, 0, heapDstSegment, 0, ELEM_SIZE);
}
@Benchmark
public void native_segment_copy5Arg() {
MemorySegment.copy(nativeSrcSegment, 0, nativeDstSegment, 0, ELEM_SIZE);
}
@Benchmark
public void heap_segment_copy7arg() {
MemorySegment.copy(heapSrcSegment, JAVA_BYTE, 0, heapDstSegment, JAVA_BYTE, 0, ELEM_SIZE);
}
@Benchmark
public void buffer_copy() {
dstBuffer.put(srcBuffer);
}
}