8278623: compiler/vectorapi/reshape/TestVectorCastAVX512.java after JDK-8259610

Reviewed-by: kvn, chagedorn, psandoz
This commit is contained in:
merykitty 2021-12-17 23:42:28 +00:00 committed by Vladimir Kozlov
parent 905b763942
commit cc44e13797
10 changed files with 385 additions and 330 deletions

View file

@ -74,8 +74,6 @@ compiler/whitebox/MakeMethodNotCompilableTest.java 8265360 macosx-aarch64
compiler/codecache/jmx/PoolsIndependenceTest.java 8264632 macosx-x64 compiler/codecache/jmx/PoolsIndependenceTest.java 8264632 macosx-x64
compiler/codecache/TestStressCodeBuffers.java 8272094 generic-aarch64 compiler/codecache/TestStressCodeBuffers.java 8272094 generic-aarch64
compiler/vectorapi/reshape/TestVectorCastAVX512.java 8278623 generic-x64
############################################################################# #############################################################################

View file

@ -0,0 +1,48 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package compiler.vectorapi.reshape;
import compiler.vectorapi.reshape.tests.TestVectorCast;
import compiler.vectorapi.reshape.utils.TestCastMethods;
import compiler.vectorapi.reshape.utils.VectorReshapeHelper;
/*
* @test
* @bug 8278623
* @modules jdk.incubator.vector
* @modules java.base/jdk.internal.misc
* @summary Test that vector cast intrinsics work as intended on avx512bw.
* @requires vm.cpu.features ~= ".*avx512bw.*"
* @library /test/lib /
* @run driver compiler.vectorapi.reshape.TestVectorCastAVX512BW
*/
public class TestVectorCastAVX512BW {
public static void main(String[] args) {
VectorReshapeHelper.runMainHelper(
TestVectorCast.class,
TestCastMethods.AVX512BW_CAST_TESTS.stream(),
"-XX:UseAVX=3");
}
}

View file

@ -74,8 +74,7 @@ public class TestVectorReinterpret {
.filter(ftype -> ftype != etype) .filter(ftype -> ftype != etype)
.map(ftype -> VectorSpeciesPair.makePair(VectorSpecies.of(etype, shape), .map(ftype -> VectorSpeciesPair.makePair(VectorSpecies.of(etype, shape),
VectorSpecies.of(ftype, shape))))) VectorSpecies.of(ftype, shape)))))
.filter(p -> p.isp().length() > 1 && p.osp().length() > 1), .filter(p -> p.isp().length() > 1 && p.osp().length() > 1)
"--add-exports=java.base/jdk.internal.misc=ALL-UNNAMED"
); );
} }
} }

View file

@ -32,6 +32,7 @@ import static compiler.vectorapi.reshape.utils.VectorReshapeHelper.*;
/** /**
* As spot in 8259353. We need to do a shrink and an expand together to not accidentally * As spot in 8259353. We need to do a shrink and an expand together to not accidentally
* zero out elements in the physical registers that may not be zero in general cases. * zero out elements in the physical registers that may not be zero in general cases.
*
* In some methods, 2 consecutive ReinterpretNodes may be optimized out. * In some methods, 2 consecutive ReinterpretNodes may be optimized out.
*/ */
public class TestVectorDoubleExpandShrink { public class TestVectorDoubleExpandShrink {

View file

@ -32,6 +32,7 @@ import static compiler.vectorapi.reshape.utils.VectorReshapeHelper.*;
/** /**
* This class contains method to ensure a resizing reinterpretation operations work as * This class contains method to ensure a resizing reinterpretation operations work as
* intended. * intended.
*
* In each test, the ReinterpretNode is expected to appear exactly once. * In each test, the ReinterpretNode is expected to appear exactly once.
*/ */
public class TestVectorExpandShrink { public class TestVectorExpandShrink {

View file

@ -31,10 +31,12 @@ import static compiler.vectorapi.reshape.utils.VectorReshapeHelper.*;
/** /**
* This class contains methods to test for reinterpretation operations that reinterpret * This class contains methods to test for reinterpretation operations that reinterpret
* a vector as a similar vector with another element type. It is complicated to verify * a vector as a similar vector with another element type.
* the IR in this case since a load/store with respect to byte array will result in *
* additional ReinterpretNodes if the vector element type is not byte. As a result, * It is complicated to verify the IR in this case since a load/store with respect to
* arguments need to be arrays of the correct type. * byte array will result in additional ReinterpretNodes if the vector element type is
* not byte. As a result, arguments need to be arrays of the correct type.
*
* In each test, the ReinterpretNode is expected to appear exactly once. * In each test, the ReinterpretNode is expected to appear exactly once.
*/ */
public class TestVectorRebracket { public class TestVectorRebracket {

View file

@ -39,7 +39,7 @@ public class TestCastMethods {
makePair(BSPEC64, SSPEC128), makePair(BSPEC64, SSPEC128),
makePair(BSPEC64, ISPEC128), makePair(BSPEC64, ISPEC128),
makePair(BSPEC64, FSPEC128), makePair(BSPEC64, FSPEC128),
// makePair(BSPEC64, DSPEC256), // makePair(BSPEC64, DSPEC256),
makePair(SSPEC64, BSPEC64), makePair(SSPEC64, BSPEC64),
makePair(SSPEC128, BSPEC64), makePair(SSPEC128, BSPEC64),
makePair(SSPEC64, ISPEC64), makePair(SSPEC64, ISPEC64),
@ -48,7 +48,7 @@ public class TestCastMethods {
makePair(SSPEC64, FSPEC64), makePair(SSPEC64, FSPEC64),
makePair(SSPEC64, FSPEC128), makePair(SSPEC64, FSPEC128),
makePair(SSPEC64, DSPEC128), makePair(SSPEC64, DSPEC128),
// makePair(SSPEC64, DSPEC256), // makePair(SSPEC64, DSPEC256),
makePair(ISPEC128, BSPEC64), makePair(ISPEC128, BSPEC64),
makePair(ISPEC64, SSPEC64), makePair(ISPEC64, SSPEC64),
makePair(ISPEC128, SSPEC64), makePair(ISPEC128, SSPEC64),
@ -63,11 +63,11 @@ public class TestCastMethods {
makePair(FSPEC128, ISPEC128), makePair(FSPEC128, ISPEC128),
makePair(FSPEC64, DSPEC128), makePair(FSPEC64, DSPEC128),
makePair(FSPEC128, DSPEC256), makePair(FSPEC128, DSPEC256),
makePair(DSPEC128, FSPEC64) makePair(DSPEC128, FSPEC64),
// makePair(DSPEC256, FSPEC128) makePair(DSPEC256, FSPEC128)
); );
public static final List<VectorSpeciesPair> AVX2_CAST_TESTS = Stream.concat(AVX1_CAST_TESTS.stream(), List.of( public static final List<VectorSpeciesPair> AVX2_CAST_TESTS = Stream.concat(AVX1_CAST_TESTS.stream(), Stream.of(
makePair(BSPEC128, SSPEC256), makePair(BSPEC128, SSPEC256),
makePair(BSPEC64, ISPEC256), makePair(BSPEC64, ISPEC256),
makePair(BSPEC64, LSPEC256), makePair(BSPEC64, LSPEC256),
@ -84,15 +84,13 @@ public class TestCastMethods {
makePair(LSPEC256, SSPEC64), makePair(LSPEC256, SSPEC64),
makePair(LSPEC256, ISPEC128), makePair(LSPEC256, ISPEC128),
makePair(FSPEC256, ISPEC256) makePair(FSPEC256, ISPEC256)
).stream()).toList(); )).toList();
public static final List<VectorSpeciesPair> AVX512_CAST_TESTS = Stream.concat(AVX2_CAST_TESTS.stream(), List.of( public static final List<VectorSpeciesPair> AVX512_CAST_TESTS = Stream.concat(AVX2_CAST_TESTS.stream(), Stream.of(
makePair(BSPEC256, SSPEC512),
makePair(BSPEC128, ISPEC512), makePair(BSPEC128, ISPEC512),
makePair(BSPEC64, LSPEC512), makePair(BSPEC64, LSPEC512),
makePair(BSPEC128, FSPEC512), makePair(BSPEC128, FSPEC512),
makePair(BSPEC64, DSPEC512), makePair(BSPEC64, DSPEC512),
makePair(SSPEC512, BSPEC256),
makePair(SSPEC256, ISPEC512), makePair(SSPEC256, ISPEC512),
makePair(SSPEC128, LSPEC512), makePair(SSPEC128, LSPEC512),
makePair(SSPEC256, FSPEC512), makePair(SSPEC256, FSPEC512),
@ -108,16 +106,21 @@ public class TestCastMethods {
makePair(FSPEC512, ISPEC512), makePair(FSPEC512, ISPEC512),
makePair(FSPEC256, DSPEC512), makePair(FSPEC256, DSPEC512),
makePair(DSPEC512, FSPEC256) makePair(DSPEC512, FSPEC256)
).stream()).toList(); )).toList();
public static final List<VectorSpeciesPair> AVX512DQ_CAST_TESTS = Stream.concat(AVX512_CAST_TESTS.stream(), List.of( public static final List<VectorSpeciesPair> AVX512BW_CAST_TESTS = Stream.concat(AVX512_CAST_TESTS.stream(), Stream.of(
makePair(BSPEC256, SSPEC512),
makePair(SSPEC512, BSPEC256)
)).toList();
public static final List<VectorSpeciesPair> AVX512DQ_CAST_TESTS = Stream.concat(AVX512_CAST_TESTS.stream(), Stream.of(
makePair(LSPEC128, DSPEC128), makePair(LSPEC128, DSPEC128),
makePair(LSPEC256, DSPEC256), makePair(LSPEC256, DSPEC256),
makePair(LSPEC512, DSPEC512), makePair(LSPEC512, DSPEC512),
makePair(DSPEC128, LSPEC128), makePair(DSPEC128, LSPEC128),
makePair(DSPEC256, LSPEC256), makePair(DSPEC256, LSPEC256),
makePair(DSPEC512, LSPEC512) makePair(DSPEC512, LSPEC512)
).stream()).toList(); )).toList();
public static final List<VectorSpeciesPair> SVE_CAST_TESTS = List.of( public static final List<VectorSpeciesPair> SVE_CAST_TESTS = List.of(
makePair(BSPEC64, SSPEC128), makePair(BSPEC64, SSPEC128),

View file

@ -35,12 +35,53 @@ public class UnsafeUtils {
return UNSAFE.arrayBaseOffset(etype.arrayType()); return UNSAFE.arrayBaseOffset(etype.arrayType());
} }
public static int getByte(Object o, long base, int i) { public static byte getByte(Object o, long base, int i) {
// This is technically an UB, what we need is UNSAFE.getByteUnaligned but they seem to be equivalent // This technically leads to UB, what we need is UNSAFE.getByteUnaligned but they seem to be equivalent
return UNSAFE.getByte(o, base + i); return UNSAFE.getByte(o, base + (long)i * Unsafe.ARRAY_BYTE_INDEX_SCALE);
} }
public static void putByte(Object o, long base, int i, int value) { public static void putByte(Object o, long base, int i, byte value) {
UNSAFE.putByte(o, base + i, (byte)value); // This technically leads to UB, what we need is UNSAFE.putByteUnaligned but they seem to be equivalent
UNSAFE.putByte(o, base + (long)i * Unsafe.ARRAY_BYTE_INDEX_SCALE, value);
}
public static short getShort(Object o, long base, int i) {
return UNSAFE.getShort(o, base + (long)i * Unsafe.ARRAY_SHORT_INDEX_SCALE);
}
public static void putShort(Object o, long base, int i, short value) {
UNSAFE.putShort(o, base + (long)i * Unsafe.ARRAY_SHORT_INDEX_SCALE, value);
}
public static int getInt(Object o, long base, int i) {
return UNSAFE.getInt(o, base + (long)i * Unsafe.ARRAY_INT_INDEX_SCALE);
}
public static void putInt(Object o, long base, int i, int value) {
UNSAFE.putInt(o, base + (long)i * Unsafe.ARRAY_INT_INDEX_SCALE, value);
}
public static long getLong(Object o, long base, int i) {
return UNSAFE.getLong(o, base + (long)i * Unsafe.ARRAY_LONG_INDEX_SCALE);
}
public static void putLong(Object o, long base, int i, long value) {
UNSAFE.putLong(o, base + (long)i * Unsafe.ARRAY_LONG_INDEX_SCALE, value);
}
public static float getFloat(Object o, long base, int i) {
return UNSAFE.getFloat(o, base + (long)i * Unsafe.ARRAY_FLOAT_INDEX_SCALE);
}
public static void putFloat(Object o, long base, int i, float value) {
UNSAFE.putFloat(o, base + (long)i * Unsafe.ARRAY_FLOAT_INDEX_SCALE, value);
}
public static double getDouble(Object o, long base, int i) {
return UNSAFE.getDouble(o, base + (long)i * Unsafe.ARRAY_DOUBLE_INDEX_SCALE);
}
public static void putDouble(Object o, long base, int i, double value) {
UNSAFE.putDouble(o, base + (long)i * Unsafe.ARRAY_DOUBLE_INDEX_SCALE, value);
} }
} }

View file

@ -81,7 +81,7 @@ public class VectorReshapeHelper {
var test = new TestFramework(testClass); var test = new TestFramework(testClass);
test.setDefaultWarmup(1); test.setDefaultWarmup(1);
test.addHelperClasses(VectorReshapeHelper.class); test.addHelperClasses(VectorReshapeHelper.class);
test.addFlags("--add-modules=jdk.incubator.vector"); test.addFlags("--add-modules=jdk.incubator.vector", "--add-exports=java.base/jdk.internal.misc=ALL-UNNAMED");
test.addFlags(flags); test.addFlags(flags);
String testMethodNames = testMethods String testMethodNames = testMethods
.filter(p -> p.isp().length() <= VectorSpecies.ofLargestShape(p.isp().elementType()).length()) .filter(p -> p.isp().length() <= VectorSpecies.ofLargestShape(p.isp().elementType()).length())
@ -94,10 +94,10 @@ public class VectorReshapeHelper {
@ForceInline @ForceInline
public static <T, U> void vectorCast(VectorOperators.Conversion<T, U> cop, public static <T, U> void vectorCast(VectorOperators.Conversion<T, U> cop,
VectorSpecies<T> isp, VectorSpecies<U> osp, byte[] input, byte[] output) { VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) {
isp.fromByteArray(input, 0, ByteOrder.nativeOrder()) var outputVector = readVector(isp, input)
.convertShape(cop, osp, 0) .convertShape(cop, osp, 0);
.intoByteArray(output, 0, ByteOrder.nativeOrder()); writeVector(osp, outputVector, output);
} }
public static <T, U> void runCastHelper(VectorOperators.Conversion<T, U> castOp, public static <T, U> void runCastHelper(VectorOperators.Conversion<T, U> castOp,
@ -108,31 +108,34 @@ public class VectorReshapeHelper {
var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
var testMethod = MethodHandles.lookup().findStatic(caller, var testMethod = MethodHandles.lookup().findStatic(caller,
testMethodName, testMethodName,
MethodType.methodType(void.class, byte.class.arrayType(), byte.class.arrayType())); MethodType.methodType(void.class, isp.elementType().arrayType(), osp.elementType().arrayType()))
byte[] input = new byte[isp.vectorByteSize()]; .asType(MethodType.methodType(void.class, Object.class, Object.class));
byte[] output = new byte[osp.vectorByteSize()]; Object input = Array.newInstance(isp.elementType(), isp.length());
Object output = Array.newInstance(osp.elementType(), osp.length());
long ibase = UnsafeUtils.arrayBase(isp.elementType());
long obase = UnsafeUtils.arrayBase(osp.elementType());
for (int iter = 0; iter < INVOCATIONS; iter++) { for (int iter = 0; iter < INVOCATIONS; iter++) {
// We need to generate arrays with NaN or very large values occasionally // We need to generate arrays with NaN or very large values occasionally
boolean normalArray = random.nextBoolean(); boolean normalArray = random.nextBoolean();
var abnormalValue = List.of(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, -1e30, 1e30); var abnormalValue = List.of(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, -1e30, 1e30);
for (int i = 0; i < isp.length(); i++) { for (int i = 0; i < isp.length(); i++) {
switch (isp.elementType().getName()) { switch (isp.elementType().getName()) {
case "byte" -> setByte(input, i, (byte)random.nextInt()); case "byte" -> UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt());
case "short" -> setShort(input, i, (short)random.nextInt()); case "short" -> UnsafeUtils.putShort(input, ibase, i, (short)random.nextInt());
case "int" -> setInt(input, i, random.nextInt()); case "int" -> UnsafeUtils.putInt(input, ibase, i, random.nextInt());
case "long" -> setLong(input, i, random.nextLong()); case "long" -> UnsafeUtils.putLong(input, ibase, i, random.nextLong());
case "float" -> { case "float" -> {
if (normalArray || random.nextBoolean()) { if (normalArray || random.nextBoolean()) {
setFloat(input, i, random.nextFloat(Byte.MIN_VALUE, Byte.MAX_VALUE)); UnsafeUtils.putFloat(input, ibase, i, random.nextFloat(Byte.MIN_VALUE, Byte.MAX_VALUE));
} else { } else {
setFloat(input, i, abnormalValue.get(random.nextInt(abnormalValue.size())).floatValue()); UnsafeUtils.putFloat(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())).floatValue());
} }
} }
case "double" -> { case "double" -> {
if (normalArray || random.nextBoolean()) { if (normalArray || random.nextBoolean()) {
setDouble(input, i, random.nextDouble(Byte.MIN_VALUE, Byte.MAX_VALUE)); UnsafeUtils.putDouble(input, ibase, i, random.nextDouble(Byte.MIN_VALUE, Byte.MAX_VALUE));
} else { } else {
setDouble(input, i, abnormalValue.get(random.nextInt(abnormalValue.size()))); UnsafeUtils.putDouble(input, ibase, i, abnormalValue.get(random.nextInt(abnormalValue.size())));
} }
} }
default -> throw new AssertionError(); default -> throw new AssertionError();
@ -145,12 +148,12 @@ public class VectorReshapeHelper {
Number expected, actual; Number expected, actual;
if (i < isp.length()) { if (i < isp.length()) {
Number initial = switch (isp.elementType().getName()) { Number initial = switch (isp.elementType().getName()) {
case "byte" -> getByte(input, i); case "byte" -> UnsafeUtils.getByte(input, ibase, i);
case "short" -> getShort(input, i); case "short" -> UnsafeUtils.getShort(input, ibase, i);
case "int" -> getInt(input, i); case "int" -> UnsafeUtils.getInt(input, ibase, i);
case "long" -> getLong(input, i); case "long" -> UnsafeUtils.getLong(input, ibase, i);
case "float" -> getFloat(input, i); case "float" -> UnsafeUtils.getFloat(input, ibase, i);
case "double" -> getDouble(input, i); case "double" -> UnsafeUtils.getDouble(input, ibase, i);
default -> throw new AssertionError(); default -> throw new AssertionError();
}; };
expected = switch (osp.elementType().getName()) { expected = switch (osp.elementType().getName()) {
@ -192,12 +195,12 @@ public class VectorReshapeHelper {
}; };
} }
actual = switch (osp.elementType().getName()) { actual = switch (osp.elementType().getName()) {
case "byte" -> getByte(output, i); case "byte" -> UnsafeUtils.getByte(output, obase, i);
case "short" -> getShort(output, i); case "short" -> UnsafeUtils.getShort(output, obase, i);
case "int" -> getInt(output, i); case "int" -> UnsafeUtils.getInt(output, obase, i);
case "long" -> getLong(output, i); case "long" -> UnsafeUtils.getLong(output, obase, i);
case "float" -> getFloat(output, i); case "float" -> UnsafeUtils.getFloat(output, obase, i);
case "double" -> getDouble(output, i); case "double" -> UnsafeUtils.getDouble(output, obase, i);
default -> throw new AssertionError(); default -> throw new AssertionError();
}; };
Asserts.assertEquals(expected, actual); Asserts.assertEquals(expected, actual);
@ -206,13 +209,13 @@ public class VectorReshapeHelper {
} }
@ForceInline @ForceInline
public static <T, U> void vectorExpandShrink(VectorSpecies<T> isp, VectorSpecies<U> osp, byte[] input, byte[] output) { public static void vectorExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, byte[] input, byte[] output) {
isp.fromByteArray(input, 0, ByteOrder.nativeOrder()) isp.fromByteArray(input, 0, ByteOrder.nativeOrder())
.reinterpretShape(osp, 0) .reinterpretShape(osp, 0)
.intoByteArray(output, 0, ByteOrder.nativeOrder()); .intoByteArray(output, 0, ByteOrder.nativeOrder());
} }
public static <T, U> void runExpandShrinkHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable { public static void runExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
var random = RandomGenerator.getDefault(); var random = RandomGenerator.getDefault();
String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
@ -235,14 +238,14 @@ public class VectorReshapeHelper {
} }
@ForceInline @ForceInline
public static <T, U> void vectorDoubleExpandShrink(VectorSpecies<T> isp, VectorSpecies<U> osp, byte[] input, byte[] output) { public static void vectorDoubleExpandShrink(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp, byte[] input, byte[] output) {
isp.fromByteArray(input, 0, ByteOrder.nativeOrder()) isp.fromByteArray(input, 0, ByteOrder.nativeOrder())
.reinterpretShape(osp, 0) .reinterpretShape(osp, 0)
.reinterpretShape(isp, 0) .reinterpretShape(isp, 0)
.intoByteArray(output, 0, ByteOrder.nativeOrder()); .intoByteArray(output, 0, ByteOrder.nativeOrder());
} }
public static <T, U> void runDoubleExpandShrinkHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable { public static void runDoubleExpandShrinkHelper(VectorSpecies<Byte> isp, VectorSpecies<Byte> osp) throws Throwable {
var random = RandomGenerator.getDefault(); var random = RandomGenerator.getDefault();
String testMethodName = VectorSpeciesPair.makePair(isp, osp).format(); String testMethodName = VectorSpeciesPair.makePair(isp, osp).format();
var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass(); var caller = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE).getCallerClass();
@ -264,28 +267,11 @@ public class VectorReshapeHelper {
} }
} }
// All this complication is due to the fact that vector load and store with respect to byte array introduce
// additional ReinterpretNodes, several ReinterpretNodes back to back being optimized make the number of
// nodes remaining in the IR becomes unpredictable.
@ForceInline @ForceInline
public static <T, U> void vectorRebracket(VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) { public static <T, U> void vectorRebracket(VectorSpecies<T> isp, VectorSpecies<U> osp, Object input, Object output) {
var outputVector = isp.fromArray(input, 0).reinterpretShape(osp, 0); var outputVector = readVector(isp, input)
var otype = osp.elementType(); .reinterpretShape(osp, 0);
if (otype == byte.class) { writeVector(osp, outputVector, output);
((ByteVector)outputVector).intoArray((byte[])output, 0);
} else if (otype == short.class) {
((ShortVector)outputVector).intoArray((short[])output, 0);
} else if (otype == int.class) {
((IntVector)outputVector).intoArray((int[])output, 0);
} else if (otype == long.class) {
((LongVector)outputVector).intoArray((long[])output, 0);
} else if (otype == float.class) {
((FloatVector)outputVector).intoArray((float[])output, 0);
} else if (otype == double.class) {
((DoubleVector)outputVector).intoArray((double[])output, 0);
} else {
throw new AssertionError();
}
} }
public static <T, U> void runRebracketHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable { public static <T, U> void runRebracketHelper(VectorSpecies<T> isp, VectorSpecies<U> osp) throws Throwable {
@ -302,7 +288,7 @@ public class VectorReshapeHelper {
long obase = UnsafeUtils.arrayBase(osp.elementType()); long obase = UnsafeUtils.arrayBase(osp.elementType());
for (int iter = 0; iter < INVOCATIONS; iter++) { for (int iter = 0; iter < INVOCATIONS; iter++) {
for (int i = 0; i < isp.vectorByteSize(); i++) { for (int i = 0; i < isp.vectorByteSize(); i++) {
UnsafeUtils.putByte(input, ibase, i, random.nextInt()); UnsafeUtils.putByte(input, ibase, i, (byte)random.nextInt());
} }
testMethod.invokeExact(input, output); testMethod.invokeExact(input, output);
@ -315,58 +301,28 @@ public class VectorReshapeHelper {
} }
} }
public static byte getByte(byte[] array, int index) { @ForceInline
return (byte)BYTE_ACCESS.get(array, index * Byte.BYTES); private static <T> Vector<T> readVector(VectorSpecies<T> isp, Object input) {
return isp.fromArray(input, 0);
} }
public static short getShort(byte[] array, int index) { @ForceInline
return (short)SHORT_ACCESS.get(array, index * Short.BYTES); private static <U> void writeVector(VectorSpecies<U> osp, Vector<U> vector, Object output) {
var otype = osp.elementType();
if (otype == byte.class) {
((ByteVector)vector).intoArray((byte[])output, 0);
} else if (otype == short.class) {
((ShortVector)vector).intoArray((short[])output, 0);
} else if (otype == int.class) {
((IntVector)vector).intoArray((int[])output, 0);
} else if (otype == long.class) {
((LongVector)vector).intoArray((long[])output, 0);
} else if (otype == float.class) {
((FloatVector)vector).intoArray((float[])output, 0);
} else if (otype == double.class) {
((DoubleVector)vector).intoArray((double[])output, 0);
} else {
throw new AssertionError();
} }
public static int getInt(byte[] array, int index) {
return (int)INT_ACCESS.get(array, index * Integer.BYTES);
} }
public static long getLong(byte[] array, int index) {
return (long)LONG_ACCESS.get(array, index * Long.BYTES);
}
public static float getFloat(byte[] array, int index) {
return (float)FLOAT_ACCESS.get(array, index * Float.BYTES);
}
public static double getDouble(byte[] array, int index) {
return (double)DOUBLE_ACCESS.get(array, index * Double.BYTES);
}
public static void setByte(byte[] array, int index, byte value) {
BYTE_ACCESS.set(array, index * Byte.BYTES, value);
}
public static void setShort(byte[] array, int index, short value) {
SHORT_ACCESS.set(array, index * Short.BYTES, value);
}
public static void setInt(byte[] array, int index, int value) {
INT_ACCESS.set(array, index * Integer.BYTES, value);
}
public static void setLong(byte[] array, int index, long value) {
LONG_ACCESS.set(array, index * Long.BYTES, value);
}
public static void setFloat(byte[] array, int index, float value) {
FLOAT_ACCESS.set(array, index * Float.BYTES, value);
}
public static void setDouble(byte[] array, int index, double value) {
DOUBLE_ACCESS.set(array, index * Double.BYTES, value);
}
private static final VarHandle BYTE_ACCESS = MethodHandles.arrayElementVarHandle(byte.class.arrayType());
private static final VarHandle SHORT_ACCESS = MethodHandles.byteArrayViewVarHandle(short.class.arrayType(), ByteOrder.nativeOrder());
private static final VarHandle INT_ACCESS = MethodHandles.byteArrayViewVarHandle(int.class.arrayType(), ByteOrder.nativeOrder());
private static final VarHandle LONG_ACCESS = MethodHandles.byteArrayViewVarHandle(long.class.arrayType(), ByteOrder.nativeOrder());
private static final VarHandle FLOAT_ACCESS = MethodHandles.byteArrayViewVarHandle(float.class.arrayType(), ByteOrder.nativeOrder());
private static final VarHandle DOUBLE_ACCESS = MethodHandles.byteArrayViewVarHandle(double.class.arrayType(), ByteOrder.nativeOrder());
} }