mirror of
https://github.com/openjdk/jdk.git
synced 2025-09-17 01:24:33 +02:00
8294588: Auto vectorize half precision floating point conversion APIs
Reviewed-by: sviswanathan, kvn, jbhateja, fgao, xgong
This commit is contained in:
parent
46cd457b0f
commit
073897c88b
12 changed files with 231 additions and 10 deletions
|
@ -1931,14 +1931,14 @@ void Assembler::vcvtdq2pd(XMMRegister dst, XMMRegister src, int vector_len) {
|
|||
}
|
||||
|
||||
// Register-to-register VCVTPS2PH emitter.
// Builds the prefix through vex_prefix_and_encode() with the 66 SIMD prefix and the
// 0F 3A opcode map, passing src's encoding as the first argument and dst's as the
// third, then emits three bytes: opcode 0x1D, the ModRM byte (0xC0 | encode) and imm8.
void Assembler::vcvtps2ph(XMMRegister dst, XMMRegister src, int imm8, int vector_len) {
|
||||
// NOTE(review): this assert and the one after it look like the removed and added
// sides of the same diff line (avx512vl requirement relaxed to evex); presumably only
// the second exists in the real file -- confirm against the repository.
assert(VM_Version::supports_avx512vl() || VM_Version::supports_f16c(), "");
|
||||
assert(VM_Version::supports_evex() || VM_Version::supports_f16c(), "");
|
||||
// No opmask register is used for this form (no_mask_reg = true).
InstructionAttr attributes(vector_len, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /*uses_vl */ true);
|
||||
int encode = vex_prefix_and_encode(src->encoding(), 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_0F_3A, &attributes);
|
||||
emit_int24(0x1D, (0xC0 | encode), imm8);
|
||||
}
|
||||
|
||||
void Assembler::evcvtps2ph(Address dst, KRegister mask, XMMRegister src, int imm8, int vector_len) {
|
||||
assert(VM_Version::supports_avx512vl(), "");
|
||||
assert(VM_Version::supports_evex(), "");
|
||||
InstructionMark im(this);
|
||||
InstructionAttr attributes(vector_len, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /*uses_vl */ true);
|
||||
attributes.set_address_attributes(/* tuple_type */ EVEX_HVM, /* input_size_in_bits */ EVEX_64bit);
|
||||
|
@ -1951,13 +1951,34 @@ void Assembler::evcvtps2ph(Address dst, KRegister mask, XMMRegister src, int imm
|
|||
emit_int8(imm8);
|
||||
}
|
||||
|
||||
// Store form of VCVTPS2PH: the destination is a memory operand, the source a vector
// register. Requires EVEX or F16C support (asserted below). Emits the prefix, opcode
// 0x1D, the memory operand and finally the imm8 byte.
void Assembler::vcvtps2ph(Address dst, XMMRegister src, int imm8, int vector_len) {
|
||||
assert(VM_Version::supports_evex() || VM_Version::supports_f16c(), "");
|
||||
InstructionMark im(this);
|
||||
// No opmask register is used for this form (no_mask_reg = true); the masked store
// variant is the separate evcvtps2ph(Address, KRegister, ...) overload.
InstructionAttr attributes(vector_len, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /*uses_vl */ true);
|
||||
attributes.set_address_attributes(/* tuple_type */ EVEX_HVM, /* input_size_in_bits */ EVEX_NObit);
|
||||
vex_prefix(dst, 0, src->encoding(), VEX_SIMD_66, VEX_OPCODE_0F_3A, &attributes);
|
||||
emit_int8(0x1D);
|
||||
// NOTE(review): the trailing 1 presumably tells emit_operand that one immediate byte
// (imm8, emitted on the next line) follows the operand -- confirm against emit_operand.
emit_operand(src, dst, 1);
|
||||
emit_int8(imm8);
|
||||
}
|
||||
|
||||
// Register-to-register VCVTPH2PS emitter.
// Builds the prefix through vex_prefix_and_encode() with the 66 SIMD prefix and the
// 0F 38 opcode map, passing dst's encoding as the first argument and src's as the
// third, then emits two bytes: opcode 0x13 and the ModRM byte (0xC0 | encode).
void Assembler::vcvtph2ps(XMMRegister dst, XMMRegister src, int vector_len) {
|
||||
// NOTE(review): this assert and the one after it look like the removed and added
// sides of the same diff line (avx512vl requirement relaxed to evex); presumably only
// the second exists in the real file -- confirm against the repository.
assert(VM_Version::supports_avx512vl() || VM_Version::supports_f16c(), "");
|
||||
assert(VM_Version::supports_evex() || VM_Version::supports_f16c(), "");
|
||||
// No opmask register is used for this form (no_mask_reg = true).
InstructionAttr attributes(vector_len, /* rex_w */ false, /* legacy_mode */false, /* no_mask_reg */ true, /* uses_vl */ true);
|
||||
int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_66, VEX_OPCODE_0F_38, &attributes);
|
||||
emit_int16(0x13, (0xC0 | encode));
|
||||
}
|
||||
|
||||
// Load form of VCVTPH2PS: the source is a memory operand, the destination a vector
// register. Requires EVEX or F16C support (asserted below). Emits the prefix, opcode
// 0x13 and the memory operand; there is no immediate byte.
void Assembler::vcvtph2ps(XMMRegister dst, Address src, int vector_len) {
|
||||
assert(VM_Version::supports_evex() || VM_Version::supports_f16c(), "");
|
||||
InstructionMark im(this);
|
||||
// NOTE(review): no_mask_reg is false here although this overload takes no KRegister
// and never sets a mask on `attributes`; the other unmasked vcvtps2ph/vcvtph2ps
// overloads in this file pass true. Confirm whether false is intended.
InstructionAttr attributes(vector_len, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /*uses_vl */ true);
|
||||
attributes.set_address_attributes(/* tuple_type */ EVEX_HVM, /* input_size_in_bits */ EVEX_NObit);
|
||||
vex_prefix(src, 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_0F_38, &attributes);
|
||||
emit_int8(0x13);
|
||||
// Trailing 0: nothing is emitted after the operand in this function.
emit_operand(dst, src, 0);
|
||||
}
|
||||
|
||||
void Assembler::cvtdq2ps(XMMRegister dst, XMMRegister src) {
|
||||
NOT_LP64(assert(VM_Version::supports_sse2(), ""));
|
||||
InstructionAttr attributes(AVX_128bit, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true);
|
||||
|
|
|
@ -1160,6 +1160,8 @@ private:
|
|||
void vcvtps2ph(XMMRegister dst, XMMRegister src, int imm8, int vector_len);
|
||||
void vcvtph2ps(XMMRegister dst, XMMRegister src, int vector_len);
|
||||
void evcvtps2ph(Address dst, KRegister mask, XMMRegister src, int imm8, int vector_len);
|
||||
void vcvtps2ph(Address dst, XMMRegister src, int imm8, int vector_len);
|
||||
void vcvtph2ps(XMMRegister dst, Address src, int vector_len);
|
||||
|
||||
// Convert Packed Signed Doubleword Integers to Packed Single-Precision Floating-Point Value
|
||||
void cvtdq2ps(XMMRegister dst, XMMRegister src);
|
||||
|
|
|
@ -956,6 +956,7 @@ void VM_Version::get_processor_features() {
|
|||
if (UseAVX < 1) {
|
||||
_features &= ~CPU_AVX;
|
||||
_features &= ~CPU_VZEROUPPER;
|
||||
_features &= ~CPU_F16C;
|
||||
}
|
||||
|
||||
if (logical_processors_per_package() == 1) {
|
||||
|
|
|
@ -1687,6 +1687,12 @@ const bool Matcher::match_rule_supported(int opcode) {
|
|||
return false;
|
||||
}
|
||||
break;
|
||||
case Op_VectorCastF2HF:
|
||||
case Op_VectorCastHF2F:
|
||||
if (!VM_Version::supports_f16c() && !VM_Version::supports_evex()) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
return true; // Match rules are supported by default.
|
||||
}
|
||||
|
@ -1901,6 +1907,14 @@ const bool Matcher::match_rule_supported_vector(int opcode, int vlen, BasicType
|
|||
return false;
|
||||
}
|
||||
break;
|
||||
case Op_VectorCastF2HF:
|
||||
case Op_VectorCastHF2F:
|
||||
if (!VM_Version::supports_f16c() &&
|
||||
((!VM_Version::supports_evex() ||
|
||||
((size_in_bits != 512) && !VM_Version::supports_avx512vl())))) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case Op_RoundVD:
|
||||
if (!VM_Version::supports_avx512dq()) {
|
||||
return false;
|
||||
|
@ -3673,6 +3687,26 @@ instruct convF2HF_mem_reg(memory mem, regF src, kReg ktmp, rRegI rtmp) %{
|
|||
ins_pipe( pipe_slow );
|
||||
%}
|
||||
|
||||
// Matches a VectorCastF2HF node with a vector-register input and a vector-register
// result, and emits the register form of vcvtps2ph with imm8 = 0x04.
instruct vconvF2HF(vec dst, vec src) %{
|
||||
match(Set dst (VectorCastF2HF src));
|
||||
format %{ "vector_conv_F2HF $dst $src" %}
|
||||
ins_encode %{
|
||||
// The vector length encoding is computed with the $src operand passed explicitly.
// NOTE(review): presumably because the float input is wider than the half-float
// result -- confirm against vector_length_encoding().
int vlen_enc = vector_length_encoding(this, $src);
|
||||
// NOTE(review): 0x04 is the imm8 rounding-control operand of VCVTPS2PH; per the
// Intel SDM bit 2 should select the MXCSR rounding mode -- confirm.
__ vcvtps2ph($dst$$XMMRegister, $src$$XMMRegister, 0x04, vlen_enc);
|
||||
%}
|
||||
ins_pipe( pipe_slow );
|
||||
%}
|
||||
|
||||
// Matches a StoreVector whose stored value is a VectorCastF2HF of a vector register,
// and emits the store form of vcvtps2ph (memory destination) with imm8 = 0x04, so the
// conversion and the store are covered by one instruction.
instruct vconvF2HF_mem_reg(memory mem, vec src) %{
|
||||
match(Set mem (StoreVector mem (VectorCastF2HF src)));
|
||||
format %{ "vcvtps2ph $mem,$src" %}
|
||||
ins_encode %{
|
||||
// The vector length encoding is computed with the $src operand passed explicitly.
int vlen_enc = vector_length_encoding(this, $src);
|
||||
// NOTE(review): 0x04 is the imm8 rounding-control operand of VCVTPS2PH; per the
// Intel SDM bit 2 should select the MXCSR rounding mode -- confirm.
__ vcvtps2ph($mem$$Address, $src$$XMMRegister, 0x04, vlen_enc);
|
||||
%}
|
||||
ins_pipe( pipe_slow );
|
||||
%}
|
||||
|
||||
instruct convHF2F_reg_reg(regF dst, rRegI src) %{
|
||||
match(Set dst (ConvHF2F src));
|
||||
format %{ "vcvtph2ps $dst,$src" %}
|
||||
|
@ -3683,6 +3717,27 @@ instruct convHF2F_reg_reg(regF dst, rRegI src) %{
|
|||
ins_pipe( pipe_slow );
|
||||
%}
|
||||
|
||||
// Matches a VectorCastHF2F whose input is a LoadVector from memory, and emits the load
// form of vcvtph2ps (memory source), so the load and the conversion are covered by one
// instruction.
instruct vconvHF2F_reg_mem(vec dst, memory mem) %{
|
||||
match(Set dst (VectorCastHF2F (LoadVector mem)));
|
||||
format %{ "vcvtph2ps $dst,$mem" %}
|
||||
ins_encode %{
|
||||
// The vector length encoding is computed from this node only (no operand argument).
int vlen_enc = vector_length_encoding(this);
|
||||
__ vcvtph2ps($dst$$XMMRegister, $mem$$Address, vlen_enc);
|
||||
%}
|
||||
ins_pipe( pipe_slow );
|
||||
%}
|
||||
|
||||
// Matches a VectorCastHF2F node with a vector-register input and emits the
// register-to-register form of vcvtph2ps.
instruct vconvHF2F(vec dst, vec src) %{
|
||||
match(Set dst (VectorCastHF2F src));
|
||||
// NOTE(review): explicit cost of 125, whereas vconvHF2F_reg_mem declares no ins_cost;
// presumably this makes the memory-operand rule preferred when both match -- confirm.
ins_cost(125);
|
||||
format %{ "vector_conv_HF2F $dst,$src" %}
|
||||
ins_encode %{
|
||||
// The vector length encoding is computed from this node only (no operand argument).
int vlen_enc = vector_length_encoding(this);
|
||||
__ vcvtph2ps($dst$$XMMRegister, $src$$XMMRegister, vlen_enc);
|
||||
%}
|
||||
ins_pipe( pipe_slow );
|
||||
%}
|
||||
|
||||
// ---------------------------------------- VectorReinterpret ------------------------------------
|
||||
instruct reinterpret_mask(kReg dst) %{
|
||||
predicate(n->bottom_type()->isa_vectmask() &&
|
||||
|
|
|
@ -4223,7 +4223,7 @@ bool MatchRule::is_vector() const {
|
|||
"VectorTest", "VectorLoadMask", "VectorStoreMask", "VectorBlend", "VectorInsert",
|
||||
"VectorRearrange","VectorLoadShuffle", "VectorLoadConst",
|
||||
"VectorCastB2X", "VectorCastS2X", "VectorCastI2X",
|
||||
"VectorCastL2X", "VectorCastF2X", "VectorCastD2X",
|
||||
"VectorCastL2X", "VectorCastF2X", "VectorCastD2X", "VectorCastF2HF", "VectorCastHF2F",
|
||||
"VectorUCastB2X", "VectorUCastS2X", "VectorUCastI2X",
|
||||
"VectorMaskWrapper","VectorMaskCmp","VectorReinterpret","LoadVectorMasked","StoreVectorMasked",
|
||||
"FmaVD","FmaVF","PopCountVI","PopCountVL","PopulateIndex","VectorLongToMask",
|
||||
|
|
|
@ -506,6 +506,8 @@ macro(VectorCastI2X)
|
|||
macro(VectorCastL2X)
|
||||
macro(VectorCastF2X)
|
||||
macro(VectorCastD2X)
|
||||
macro(VectorCastF2HF)
|
||||
macro(VectorCastHF2F)
|
||||
macro(VectorUCastB2X)
|
||||
macro(VectorUCastS2X)
|
||||
macro(VectorUCastI2X)
|
||||
|
|
|
@ -2712,7 +2712,7 @@ bool SuperWord::output() {
|
|||
assert(n->req() == 2, "only one input expected");
|
||||
BasicType bt = velt_basic_type(n);
|
||||
Node* in = vector_opd(p, 1);
|
||||
int vopc = VectorCastNode::opcode(in->bottom_type()->is_vect()->element_basic_type());
|
||||
int vopc = VectorCastNode::opcode(opc, in->bottom_type()->is_vect()->element_basic_type());
|
||||
vn = VectorCastNode::make(vopc, in, bt, vlen);
|
||||
vlen_in_bytes = vn->as_Vector()->length_in_bytes();
|
||||
} else if (is_cmov_pack(p)) {
|
||||
|
|
|
@ -775,7 +775,7 @@ bool LibraryCallKit::inline_vector_shuffle_to_vector() {
|
|||
return false;
|
||||
}
|
||||
|
||||
int cast_vopc = VectorCastNode::opcode(T_BYTE); // from shuffle of type T_BYTE
|
||||
int cast_vopc = VectorCastNode::opcode(-1, T_BYTE); // from shuffle of type T_BYTE
|
||||
// Make sure that cast is implemented to particular type/size combination.
|
||||
if (!arch_supports_vector(cast_vopc, num_elem, elem_bt, VecMaskNotUsed)) {
|
||||
if (C->print_intrinsics()) {
|
||||
|
@ -2489,7 +2489,7 @@ bool LibraryCallKit::inline_vector_convert() {
|
|||
Node* op = opd1;
|
||||
if (is_cast) {
|
||||
assert(!is_mask || num_elem_from == num_elem_to, "vector mask cast needs the same elem num");
|
||||
int cast_vopc = VectorCastNode::opcode(elem_bt_from, !is_ucast);
|
||||
int cast_vopc = VectorCastNode::opcode(-1, elem_bt_from, !is_ucast);
|
||||
|
||||
// Make sure that vector cast is implemented to particular type/size combination if it is
|
||||
// not a mask casting.
|
||||
|
|
|
@ -467,6 +467,8 @@ bool VectorNode::is_convert_opcode(int opc) {
|
|||
case Op_ConvD2F:
|
||||
case Op_ConvF2D:
|
||||
case Op_ConvD2I:
|
||||
case Op_ConvF2HF:
|
||||
case Op_ConvHF2F:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
@ -1328,14 +1330,31 @@ VectorCastNode* VectorCastNode::make(int vopc, Node* n1, BasicType bt, uint vlen
|
|||
case Op_VectorUCastB2X: return new VectorUCastB2XNode(n1, vt);
|
||||
case Op_VectorUCastS2X: return new VectorUCastS2XNode(n1, vt);
|
||||
case Op_VectorUCastI2X: return new VectorUCastI2XNode(n1, vt);
|
||||
case Op_VectorCastHF2F: return new VectorCastHF2FNode(n1, vt);
|
||||
case Op_VectorCastF2HF: return new VectorCastF2HFNode(n1, vt);
|
||||
default:
|
||||
assert(false, "unknown node: %s", NodeClassNames[vopc]);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
int VectorCastNode::opcode(BasicType bt, bool is_signed) {
|
||||
int VectorCastNode::opcode(int sopc, BasicType bt, bool is_signed) {
|
||||
assert((is_integral_type(bt) && bt != T_LONG) || is_signed, "");
|
||||
|
||||
// Handle special case for to/from Half Float conversions
|
||||
switch (sopc) {
|
||||
case Op_ConvHF2F:
|
||||
assert(bt == T_SHORT, "");
|
||||
return Op_VectorCastHF2F;
|
||||
case Op_ConvF2HF:
|
||||
assert(bt == T_FLOAT, "");
|
||||
return Op_VectorCastF2HF;
|
||||
default:
|
||||
// Handled normally below
|
||||
break;
|
||||
}
|
||||
|
||||
// Handle normal conversions
|
||||
switch (bt) {
|
||||
case T_BYTE: return is_signed ? Op_VectorCastB2X : Op_VectorUCastB2X;
|
||||
case T_SHORT: return is_signed ? Op_VectorCastS2X : Op_VectorUCastS2X;
|
||||
|
@ -1354,7 +1373,7 @@ bool VectorCastNode::implemented(int opc, uint vlen, BasicType src_type, BasicTy
|
|||
is_java_primitive(src_type) &&
|
||||
(vlen > 1) && is_power_of_2(vlen) &&
|
||||
VectorNode::vector_size_supported(dst_type, vlen)) {
|
||||
int vopc = VectorCastNode::opcode(src_type);
|
||||
int vopc = VectorCastNode::opcode(opc, src_type);
|
||||
return vopc > 0 && Matcher::match_rule_supported_superword(vopc, vlen, dst_type);
|
||||
}
|
||||
return false;
|
||||
|
|
|
@ -1542,7 +1542,7 @@ class VectorCastNode : public VectorNode {
|
|||
virtual int Opcode() const;
|
||||
|
||||
static VectorCastNode* make(int vopc, Node* n1, BasicType bt, uint vlen);
|
||||
static int opcode(BasicType bt, bool is_signed = true);
|
||||
static int opcode(int opc, BasicType bt, bool is_signed = true);
|
||||
static bool implemented(int opc, uint vlen, BasicType src_type, BasicType dst_type);
|
||||
|
||||
virtual Node* Identity(PhaseGVN* phase);
|
||||
|
@ -1628,6 +1628,22 @@ class VectorUCastS2XNode : public VectorCastNode {
|
|||
virtual int Opcode() const;
|
||||
};
|
||||
|
||||
// Vector cast node for the Op_VectorCastHF2F opcode (VectorCastNode::make() creates it
// for that opcode). The input must be a vector whose element type is T_SHORT; this is
// checked by the assert in the constructor.
class VectorCastHF2FNode : public VectorCastNode {
|
||||
public:
|
||||
VectorCastHF2FNode(Node* in, const TypeVect* vt) : VectorCastNode(in, vt) {
|
||||
assert(in->bottom_type()->is_vect()->element_basic_type() == T_SHORT, "must be short");
|
||||
}
|
||||
// Declared here; the definition is not part of this class body.
virtual int Opcode() const;
|
||||
};
|
||||
|
||||
// Vector cast node for the Op_VectorCastF2HF opcode (VectorCastNode::make() creates it
// for that opcode). The input must be a vector whose element type is T_FLOAT; this is
// checked by the assert in the constructor.
class VectorCastF2HFNode : public VectorCastNode {
|
||||
public:
|
||||
VectorCastF2HFNode(Node* in, const TypeVect* vt) : VectorCastNode(in, vt) {
|
||||
assert(in->bottom_type()->is_vect()->element_basic_type() == T_FLOAT, "must be float");
|
||||
}
|
||||
// Declared here; the definition is not part of this class body.
virtual int Opcode() const;
|
||||
};
|
||||
|
||||
class VectorUCastI2XNode : public VectorCastNode {
|
||||
public:
|
||||
VectorUCastI2XNode(Node* in, const TypeVect* vt) : VectorCastNode(in, vt) {
|
||||
|
|
|
@ -1079,6 +1079,16 @@ public class IRNode {
|
|||
beforeMatchingNameRegex(VECTOR_CAST_S2X, "VectorCastS2X");
|
||||
}
|
||||
|
||||
// Placeholder string for the C2 node named "VectorCastF2HF", for use in @IR rules.
// The static initializer registers it through beforeMatchingNameRegex().
public static final String VECTOR_CAST_F2HF = PREFIX + "VECTOR_CAST_F2HF" + POSTFIX;
|
||||
static {
|
||||
beforeMatchingNameRegex(VECTOR_CAST_F2HF, "VectorCastF2HF");
|
||||
}
|
||||
|
||||
// Placeholder string for the C2 node named "VectorCastHF2F", for use in @IR rules.
// The static initializer registers it through beforeMatchingNameRegex().
public static final String VECTOR_CAST_HF2F = PREFIX + "VECTOR_CAST_HF2F" + POSTFIX;
|
||||
static {
|
||||
beforeMatchingNameRegex(VECTOR_CAST_HF2F, "VectorCastHF2F");
|
||||
}
|
||||
|
||||
public static final String VECTOR_MASK_CAST = PREFIX + "VECTOR_MASK_CAST" + POSTFIX;
|
||||
static {
|
||||
beforeMatchingNameRegex(VECTOR_MASK_CAST, "VectorMaskCast");
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
|
||||
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
|
||||
*
|
||||
* This code is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License version 2 only, as
|
||||
* published by the Free Software Foundation.
|
||||
*
|
||||
* This code is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
|
||||
* version 2 for more details (a copy is included in the LICENSE file that
|
||||
* accompanied this code).
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License version
|
||||
* 2 along with this work; if not, write to the Free Software Foundation,
|
||||
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
*
|
||||
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
|
||||
* or visit www.oracle.com if you need additional information or have any
|
||||
* questions.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @test
|
||||
* @bug 8294588
|
||||
* @summary Auto-vectorize Float.floatToFloat16, Float.float16ToFloat APIs
|
||||
* @requires vm.compiler2.enabled
|
||||
* @requires os.simpleArch == "x64"
|
||||
* @library /test/lib /
|
||||
* @run driver compiler.vectorization.TestFloatConversionsVector
|
||||
*/
|
||||
|
||||
package compiler.vectorization;
|
||||
|
||||
import compiler.lib.ir_framework.*;
|
||||
|
||||
// IR test for auto-vectorization of Float.floatToFloat16 and Float.float16ToFloat.
// Each @Test method is a simple element-wise loop; its @IR rule requires at least one
// VectorCastF2HF / VectorCastHF2F node when the CPU reports avx512f or f16c.
// NOTE(review): the kernels only drive the loops; nothing compares the converted
// values against a reference, so only the IR shape is checked here.
public class TestFloatConversionsVector {
|
||||
// Length of every input/output array.
private static final int ARRLEN = 1024;
|
||||
// Number of times each kernel invokes its test method.
private static final int ITERS = 11000;
|
||||
// Arrays allocated by the kernels. The @Test methods take parameters with the same
// names, which shadow these fields inside those methods.
private static float [] finp;
|
||||
private static short [] sout;
|
||||
private static short [] sinp;
|
||||
private static float [] fout;
|
||||
|
||||
// Entry point: runs the IR framework with tiered compilation disabled and a compile
// threshold scaling of 0.3, then prints "PASSED".
public static void main(String args[]) {
|
||||
TestFramework.runWithFlags("-XX:-TieredCompilation",
|
||||
"-XX:CompileThresholdScaling=0.3");
|
||||
System.out.println("PASSED");
|
||||
}
|
||||
|
||||
// float -> float16: writes Float.floatToFloat16(finp[i]) into sout[i] for every index
// of finp. sout must be at least as long as finp.
@Test
|
||||
@IR(counts = {IRNode.VECTOR_CAST_F2HF, "> 0"}, applyIfCPUFeatureOr = {"avx512f", "true", "f16c", "true"})
|
||||
public void test_float_float16(short[] sout, float[] finp) {
|
||||
for (int i = 0; i < finp.length; i++) {
|
||||
sout[i] = Float.floatToFloat16(finp[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Standalone runner for test_float_float16: fills finp with i * 1.4f and calls the
// test ITERS times. The contents of sout are not inspected afterwards.
@Run(test = {"test_float_float16"}, mode = RunMode.STANDALONE)
|
||||
public void kernel_test_float_float16() {
|
||||
finp = new float[ARRLEN];
|
||||
sout = new short[ARRLEN];
|
||||
|
||||
for (int i = 0; i < ARRLEN; i++) {
|
||||
finp[i] = (float) i * 1.4f;
|
||||
}
|
||||
|
||||
for (int i = 0; i < ITERS; i++) {
|
||||
test_float_float16(sout, finp);
|
||||
}
|
||||
}
|
||||
|
||||
// float16 -> float: writes Float.float16ToFloat(sinp[i]) into fout[i] for every index
// of sinp. fout must be at least as long as sinp.
@Test
|
||||
@IR(counts = {IRNode.VECTOR_CAST_HF2F, "> 0"}, applyIfCPUFeatureOr = {"avx512f", "true", "f16c", "true"})
|
||||
public void test_float16_float(float[] fout, short[] sinp) {
|
||||
for (int i = 0; i < sinp.length; i++) {
|
||||
fout[i] = Float.float16ToFloat(sinp[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Standalone runner for test_float16_float: fills sinp with the bit patterns
// 0..ARRLEN-1 and calls the test ITERS times. The contents of fout are not inspected
// afterwards.
@Run(test = {"test_float16_float"}, mode = RunMode.STANDALONE)
|
||||
public void kernel_test_float16_float() {
|
||||
sinp = new short[ARRLEN];
|
||||
fout = new float[ARRLEN];
|
||||
|
||||
for (int i = 0; i < ARRLEN; i++) {
|
||||
sinp[i] = (short)i;
|
||||
}
|
||||
|
||||
for (int i = 0; i < ITERS; i++) {
|
||||
test_float16_float(fout , sinp);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue