diff --git a/src/hotspot/cpu/x86/assembler_x86.cpp b/src/hotspot/cpu/x86/assembler_x86.cpp index 828d8cfda91..085ae4a6ddd 100644 --- a/src/hotspot/cpu/x86/assembler_x86.cpp +++ b/src/hotspot/cpu/x86/assembler_x86.cpp @@ -3475,6 +3475,22 @@ void Assembler::vmovdqu(XMMRegister dst, XMMRegister src) { emit_int16(0x6F, (0xC0 | encode)); } +void Assembler::vmovw(XMMRegister dst, Register src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_66, VEX_OPCODE_MAP5, &attributes, true); + emit_int16(0x6E, (0xC0 | encode)); +} + +void Assembler::vmovw(Register dst, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(src->encoding(), 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_MAP5, &attributes, true); + emit_int16(0x7E, (0xC0 | encode)); +} + void Assembler::vmovdqu(XMMRegister dst, Address src) { assert(UseAVX > 0, ""); InstructionMark im(this); @@ -8442,6 +8458,70 @@ void Assembler::vpaddq(XMMRegister dst, XMMRegister nds, Address src, int vector emit_operand(dst, src, 0); } +void Assembler::vaddsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x58, (0xC0 | encode)); +} + +void 
Assembler::vsubsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5C, (0xC0 | encode)); +} + +void Assembler::vdivsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5E, (0xC0 | encode)); +} + +void Assembler::vmulsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x59, (0xC0 | encode)); +} + +void Assembler::vmaxsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5F, (0xC0 | encode)); +} + +void Assembler::vminsh(XMMRegister 
dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5D, (0xC0 | encode)); +} + +void Assembler::vsqrtsh(XMMRegister dst, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x51, (0xC0 | encode)); +} + +void Assembler::vfmadd132sh(XMMRegister dst, XMMRegister src1, XMMRegister src2) { + assert(VM_Version::supports_avx512_fp16(), ""); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), src1->encoding(), src2->encoding(), VEX_SIMD_66, VEX_OPCODE_MAP6, &attributes); + emit_int16((unsigned char)0x99, (0xC0 | encode)); +} + void Assembler::vpaddsb(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len) { assert(UseAVX > 0 && (vector_len == Assembler::AVX_512bit || (!needs_evex(dst, nds, src) || VM_Version::supports_avx512vl())), ""); assert(!needs_evex(dst, nds, src) || VM_Version::supports_avx512bw(), ""); diff --git a/src/hotspot/cpu/x86/assembler_x86.hpp b/src/hotspot/cpu/x86/assembler_x86.hpp index 25be0d6a48d..1eb12fb93f0 100644 --- a/src/hotspot/cpu/x86/assembler_x86.hpp +++ b/src/hotspot/cpu/x86/assembler_x86.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2024, Oracle and/or its 
affiliates. All rights reserved. + * Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -585,6 +585,8 @@ class Assembler : public AbstractAssembler { VEX_OPCODE_0F_38 = 0x2, VEX_OPCODE_0F_3A = 0x3, VEX_OPCODE_0F_3C = 0x4, + VEX_OPCODE_MAP5 = 0x5, + VEX_OPCODE_MAP6 = 0x6, VEX_OPCODE_MASK = 0x1F }; @@ -1815,6 +1817,9 @@ private: void movsbl(Register dst, Address src); void movsbl(Register dst, Register src); + void vmovw(XMMRegister dst, Register src); + void vmovw(Register dst, XMMRegister src); + #ifdef _LP64 void movsbq(Register dst, Address src); void movsbq(Register dst, Register src); @@ -2691,6 +2696,16 @@ private: void vpaddd(XMMRegister dst, XMMRegister nds, Address src, int vector_len); void vpaddq(XMMRegister dst, XMMRegister nds, Address src, int vector_len); + // FP16 instructions + void vaddsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vsubsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vmulsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vdivsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vmaxsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vminsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vsqrtsh(XMMRegister dst, XMMRegister src); + void vfmadd132sh(XMMRegister dst, XMMRegister src1, XMMRegister src2); + // Saturating packed insturctions. 
void vpaddsb(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); void vpaddsw(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); diff --git a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp index 87583ddabd5..7356f5a1913 100644 --- a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp +++ b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp @@ -6675,6 +6675,18 @@ void C2_MacroAssembler::vector_rearrange_int_float(BasicType bt, XMMRegister dst } } +void C2_MacroAssembler::efp16sh(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2) { + switch(opcode) { + case Op_AddHF: vaddsh(dst, src1, src2); break; + case Op_SubHF: vsubsh(dst, src1, src2); break; + case Op_MulHF: vmulsh(dst, src1, src2); break; + case Op_DivHF: vdivsh(dst, src1, src2); break; + case Op_MaxHF: vmaxsh(dst, src1, src2); break; + case Op_MinHF: vminsh(dst, src1, src2); break; + default: assert(false, "%s", NodeClassNames[opcode]); break; + } +} + void C2_MacroAssembler::vector_saturating_op(int ideal_opc, BasicType elem_bt, XMMRegister dst, XMMRegister src1, XMMRegister src2, int vlen_enc) { switch(elem_bt) { case T_BYTE: diff --git a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp index 6e49cdefa6c..4fe2cc397b5 100644 --- a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp +++ b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -505,6 +505,7 @@ public: void vector_rearrange_int_float(BasicType bt, XMMRegister dst, XMMRegister shuffle, XMMRegister src, int vlen_enc); + void efp16sh(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2); void vgather_subword(BasicType elem_ty, XMMRegister dst, Register base, Register idx_base, Register offset, Register mask, XMMRegister xtmp1, XMMRegister xtmp2, XMMRegister xtmp3, Register rtmp, diff --git a/src/hotspot/cpu/x86/vm_version_x86.cpp b/src/hotspot/cpu/x86/vm_version_x86.cpp index cc438ce951f..788fe2c6586 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.cpp +++ b/src/hotspot/cpu/x86/vm_version_x86.cpp @@ -1027,6 +1027,7 @@ void VM_Version::get_processor_features() { _features &= ~CPU_AVX512_BITALG; _features &= ~CPU_AVX512_IFMA; _features &= ~CPU_APX_F; + _features &= ~CPU_AVX512_FP16; } // Currently APX support is only enabled for targets supporting AVX512VL feature. @@ -1077,6 +1078,7 @@ void VM_Version::get_processor_features() { _features &= ~CPU_AVX512_BITALG; _features &= ~CPU_AVX512_IFMA; _features &= ~CPU_AVX_IFMA; + _features &= ~CPU_AVX512_FP16; } } @@ -3109,6 +3111,9 @@ uint64_t VM_Version::CpuidInfo::feature_flags() const { } if (sef_cpuid7_edx.bits.serialize != 0) result |= CPU_SERIALIZE; + + if (sef_cpuid7_edx.bits.avx512_fp16 != 0) /* read this CpuidInfo's own leaf-7 EDX, like the serialize check above, not the global _cpuid_info */ + result |= CPU_AVX512_FP16; } // ZX features. diff --git a/src/hotspot/cpu/x86/vm_version_x86.hpp b/src/hotspot/cpu/x86/vm_version_x86.hpp index 004b64ebe6e..9bd3116cf84 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.hpp +++ b/src/hotspot/cpu/x86/vm_version_x86.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -276,7 +276,9 @@ class VM_Version : public Abstract_VM_Version { serialize : 1, : 5, cet_ibt : 1, - : 11; + : 2, + avx512_fp16 : 1, + : 8; } bits; }; @@ -416,8 +418,9 @@ protected: decl(CET_SS, "cet_ss", 57) /* Control Flow Enforcement - Shadow Stack */ \ decl(AVX512_IFMA, "avx512_ifma", 58) /* Integer Vector FMA instructions*/ \ decl(AVX_IFMA, "avx_ifma", 59) /* 256-bit VEX-coded variant of AVX512-IFMA*/ \ - decl(APX_F, "apx_f", 60) /* Intel Advanced Performance Extensions*/\ - decl(SHA512, "sha512", 61) /* SHA512 instructions*/ + decl(APX_F, "apx_f", 60) /* Intel Advanced Performance Extensions*/ \ + decl(SHA512, "sha512", 61) /* SHA512 instructions*/ \ + decl(AVX512_FP16, "avx512_fp16", 62) /* AVX512 FP16 ISA support*/ #define DECLARE_CPU_FEATURE_FLAG(id, name, bit) CPU_##id = (1ULL << bit), CPU_FEATURE_FLAGS(DECLARE_CPU_FEATURE_FLAG) @@ -753,6 +756,7 @@ public: static bool supports_avx512_bitalg() { return (_features & CPU_AVX512_BITALG) != 0; } static bool supports_avx512_vbmi() { return (_features & CPU_AVX512_VBMI) != 0; } static bool supports_avx512_vbmi2() { return (_features & CPU_AVX512_VBMI2) != 0; } + static bool supports_avx512_fp16() { return (_features & CPU_AVX512_FP16) != 0; } static bool supports_hv() { return (_features & CPU_HV) != 0; } static bool supports_serialize() { return (_features & CPU_SERIALIZE) != 0; } static bool supports_f16c() { return (_features & CPU_F16C) != 0; } @@ -840,7 +844,7 @@ public: // For AVX CPUs only. f16c support is disabled if UseAVX == 0. 
static bool supports_float16() { - return supports_f16c() || supports_avx512vl(); + return supports_f16c() || supports_avx512vl() || supports_avx512_fp16(); } // Check intrinsic support diff --git a/src/hotspot/cpu/x86/x86.ad b/src/hotspot/cpu/x86/x86.ad index 95b761ad44e..8b2c5835544 100644 --- a/src/hotspot/cpu/x86/x86.ad +++ b/src/hotspot/cpu/x86/x86.ad @@ -1,5 +1,5 @@ // -// Copyright (c) 2011, 2024, Oracle and/or its affiliates. All rights reserved. +// Copyright (c) 2011, 2025, Oracle and/or its affiliates. All rights reserved. // DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. // // This code is free software; you can redistribute it and/or modify it @@ -1461,6 +1461,20 @@ bool Matcher::match_rule_supported(int opcode) { return false; } break; + case Op_AddHF: + case Op_DivHF: + case Op_FmaHF: + case Op_MaxHF: + case Op_MinHF: + case Op_MulHF: + case Op_ReinterpretS2HF: + case Op_ReinterpretHF2S: + case Op_SubHF: + case Op_SqrtHF: + if (!VM_Version::supports_avx512_fp16()) { + return false; + } + break; case Op_VectorLoadShuffle: case Op_VectorRearrange: case Op_MulReductionVI: @@ -4521,6 +4535,35 @@ instruct vReplS_reg(vec dst, rRegI src) %{ ins_pipe( pipe_slow ); %} +#ifdef _LP64 +instruct ReplHF_imm(vec dst, immH con, rRegI rtmp) %{ + match(Set dst (Replicate con)); + effect(TEMP rtmp); + format %{ "replicateHF $dst, $con \t! using $rtmp as TEMP" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + BasicType bt = Matcher::vector_element_basic_type(this); + assert(VM_Version::supports_avx512_fp16() && bt == T_SHORT, ""); + __ movl($rtmp$$Register, $con$$constant); + __ evpbroadcastw($dst$$XMMRegister, $rtmp$$Register, vlen_enc); + %} + ins_pipe( pipe_slow ); +%} + +instruct ReplHF_reg(vec dst, regF src, rRegI rtmp) %{ + predicate(VM_Version::supports_avx512_fp16() && Matcher::vector_element_basic_type(n) == T_SHORT); + match(Set dst (Replicate src)); + effect(TEMP rtmp); + format %{ "replicateHF $dst, $src \t! 
using $rtmp as TEMP" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + __ vmovw($rtmp$$Register, $src$$XMMRegister); + __ evpbroadcastw($dst$$XMMRegister, $rtmp$$Register, vlen_enc); + %} + ins_pipe( pipe_slow ); +%} +#endif + instruct ReplS_mem(vec dst, memory mem) %{ predicate(UseAVX >= 2 && Matcher::vector_element_basic_type(n) == T_SHORT); match(Set dst (Replicate (LoadS mem))); @@ -10837,3 +10880,80 @@ instruct vector_selectfrom_twovectors_reg_evex(vec index, vec src1, vec src2) %} ins_pipe(pipe_slow); %} + +instruct reinterpretS2HF(regF dst, rRegI src) +%{ + match(Set dst (ReinterpretS2HF src)); + format %{ "vmovw $dst, $src" %} + ins_encode %{ + __ vmovw($dst$$XMMRegister, $src$$Register); + %} + ins_pipe(pipe_slow); +%} + +instruct convF2HFAndS2HF(regF dst, regF src) +%{ + match(Set dst (ReinterpretS2HF (ConvF2HF src))); + format %{ "convF2HFAndS2HF $dst, $src" %} + ins_encode %{ + __ vcvtps2ph($dst$$XMMRegister, $src$$XMMRegister, 0x04, Assembler::AVX_128bit); + %} + ins_pipe(pipe_slow); +%} + +instruct convHF2SAndHF2F(regF dst, regF src) +%{ + match(Set dst (ConvHF2F (ReinterpretHF2S src))); + format %{ "convHF2SAndHF2F $dst, $src" %} + ins_encode %{ + __ vcvtph2ps($dst$$XMMRegister, $src$$XMMRegister, Assembler::AVX_128bit); + %} + ins_pipe(pipe_slow); +%} + +instruct reinterpretHF2S(rRegI dst, regF src) +%{ + match(Set dst (ReinterpretHF2S src)); + format %{ "vmovw $dst, $src" %} + ins_encode %{ + __ vmovw($dst$$Register, $src$$XMMRegister); + %} + ins_pipe(pipe_slow); +%} + +instruct scalar_sqrt_HF_reg(regF dst, regF src) +%{ + match(Set dst (SqrtHF src)); + format %{ "scalar_sqrt_fp16 $dst, $src" %} + ins_encode %{ + __ vsqrtsh($dst$$XMMRegister, $src$$XMMRegister); + %} + ins_pipe(pipe_slow); +%} + +instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2) +%{ + match(Set dst (AddHF src1 src2)); + match(Set dst (DivHF src1 src2)); + match(Set dst (MaxHF src1 src2)); + match(Set dst (MinHF src1 src2)); + match(Set dst (MulHF src1 
src2)); + match(Set dst (SubHF src1 src2)); + format %{ "scalar_binop_fp16 $dst, $src1, $src2" %} + ins_encode %{ + int opcode = this->ideal_Opcode(); + __ efp16sh(opcode, $dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister); + %} + ins_pipe(pipe_slow); +%} + +instruct scalar_fma_HF_reg(regF dst, regF src1, regF src2) +%{ + match(Set dst (FmaHF src2 (Binary dst src1))); + effect(DEF dst); + format %{ "scalar_fma_fp16 $dst, $src1, $src2\t# $dst = $dst * $src1 + $src2 fma packedH" %} + ins_encode %{ + __ vfmadd132sh($dst$$XMMRegister, $src2$$XMMRegister, $src1$$XMMRegister); + %} + ins_pipe( pipe_slow ); +%} diff --git a/src/hotspot/cpu/x86/x86_64.ad b/src/hotspot/cpu/x86/x86_64.ad index 4667922505c..8cc4a970bfd 100644 --- a/src/hotspot/cpu/x86/x86_64.ad +++ b/src/hotspot/cpu/x86/x86_64.ad @@ -1,5 +1,5 @@ // -// Copyright (c) 2003, 2024, Oracle and/or its affiliates. All rights reserved. +// Copyright (c) 2003, 2025, Oracle and/or its affiliates. All rights reserved. // DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
// // This code is free software; you can redistribute it and/or modify it @@ -2382,6 +2382,16 @@ operand immF() interface(CONST_INTER); %} +// Half Float Immediate +operand immH() +%{ + match(ConH); + + op_cost(15); + format %{ %} + interface(CONST_INTER); +%} + // Double Immediate zero operand immD0() %{ @@ -4840,6 +4850,16 @@ instruct loadConF(regF dst, immF con) %{ ins_pipe(pipe_slow); %} +instruct loadConH(regF dst, immH con) %{ + match(Set dst con); + ins_cost(125); + format %{ "movss $dst, [$constantaddress]\t# load from constant table: halffloat=$con" %} + ins_encode %{ + __ movflt($dst$$XMMRegister, $constantaddress($con)); + %} + ins_pipe(pipe_slow); +%} + instruct loadConN0(rRegN dst, immN0 src, rFlagsReg cr) %{ match(Set dst src); effect(KILL cr); @@ -7022,6 +7042,17 @@ instruct castFF(regF dst) ins_pipe(empty); %} +instruct castHH(regF dst) +%{ + match(Set dst (CastHH dst)); + + size(0); + format %{ "# castHH of $dst" %} + ins_encode(/* empty encoding */); + ins_cost(0); + ins_pipe(empty); +%} + instruct castDD(regD dst) %{ match(Set dst (CastDD dst)); diff --git a/src/hotspot/share/adlc/archDesc.cpp b/src/hotspot/share/adlc/archDesc.cpp index f084f506bf5..edb07d2d22c 100644 --- a/src/hotspot/share/adlc/archDesc.cpp +++ b/src/hotspot/share/adlc/archDesc.cpp @@ -1,5 +1,5 @@ // -// Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. +// Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. // DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
// // This code is free software; you can redistribute it and/or modify it @@ -1053,6 +1053,7 @@ const char *ArchDesc::getIdealType(const char *idealOp) { case 'P': return "TypePtr::BOTTOM"; case 'N': return "TypeNarrowOop::BOTTOM"; case 'F': return "Type::FLOAT"; + case 'H': return "Type::HALF_FLOAT"; case 'D': return "Type::DOUBLE"; case 'L': return "TypeLong::LONG"; case 's': return "TypeInt::CC /*flags*/"; @@ -1090,7 +1091,7 @@ void ArchDesc::initBaseOpTypes() { char *ident = (char *)NodeClassNames[j]; if (!strcmp(ident, "ConI") || !strcmp(ident, "ConP") || !strcmp(ident, "ConN") || !strcmp(ident, "ConNKlass") || - !strcmp(ident, "ConF") || !strcmp(ident, "ConD") || + !strcmp(ident, "ConH") || !strcmp(ident, "ConF") || !strcmp(ident, "ConD") || !strcmp(ident, "ConL") || !strcmp(ident, "Con" ) || !strcmp(ident, "Bool")) { constructOperand(ident, true); diff --git a/src/hotspot/share/adlc/forms.cpp b/src/hotspot/share/adlc/forms.cpp index c34a73ea1e1..e2265f70ed9 100644 --- a/src/hotspot/share/adlc/forms.cpp +++ b/src/hotspot/share/adlc/forms.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -220,6 +220,7 @@ Form::DataType Form::ideal_to_const_type(const char *name) const { if (strcmp(name,"ConNKlass")==0) return Form::idealNKlass; if (strcmp(name,"ConL")==0) return Form::idealL; if (strcmp(name,"ConF")==0) return Form::idealF; + if (strcmp(name,"ConH")==0) return Form::idealH; if (strcmp(name,"ConD")==0) return Form::idealD; if (strcmp(name,"Bool")==0) return Form::idealI; diff --git a/src/hotspot/share/adlc/forms.hpp b/src/hotspot/share/adlc/forms.hpp index a82b9bbb338..0b673bf8542 100644 --- a/src/hotspot/share/adlc/forms.hpp +++ b/src/hotspot/share/adlc/forms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -183,7 +183,8 @@ public: idealS = 8, // String type idealN = 9, // Narrow oop types idealNKlass = 10, // Narrow klass types - idealV = 11 // Vector type + idealV = 11, // Vector type + idealH = 12 // HalfFloat type }; // Convert ideal name to a DataType, return DataType::none if not a 'ConX' Form::DataType ideal_to_const_type(const char *ideal_type_name) const; diff --git a/src/hotspot/share/adlc/formssel.cpp b/src/hotspot/share/adlc/formssel.cpp index dfa414ef564..f18e9eddba5 100644 --- a/src/hotspot/share/adlc/formssel.cpp +++ b/src/hotspot/share/adlc/formssel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1998, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1998, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -1088,7 +1088,7 @@ uint InstructForm::reloc(FormDict &globals) { } else if ( oper ) { // floats and doubles loaded out of method's constant pool require reloc info Form::DataType type = oper->is_base_constant(globals); - if ( (type == Form::idealF) || (type == Form::idealD) ) { + if ( (type == Form::idealH) || (type == Form::idealF) || (type == Form::idealD) ) { ++reloc_entries; } } @@ -1099,7 +1099,7 @@ uint InstructForm::reloc(FormDict &globals) { // !!!!! // Check for any component being an immediate float or double. Form::DataType data_type = is_chain_of_constant(globals); - if( data_type==idealD || data_type==idealF ) { + if( data_type==idealH || data_type==idealD || data_type==idealF ) { reloc_entries++; } @@ -2662,6 +2662,7 @@ void OperandForm::format_constant(FILE *fp, uint const_index, uint const_type) { case Form::idealN: fprintf(fp," if (_c%d) _c%d->dump_on(st);\n", const_index, const_index); break; case Form::idealL: fprintf(fp," st->print(\"#\" INT64_FORMAT, (int64_t)_c%d);\n", const_index); break; case Form::idealF: fprintf(fp," st->print(\"#%%f\", _c%d);\n", const_index); break; + case Form::idealH: fprintf(fp," st->print(\"#%%d\", _c%d);\n", const_index); break; case Form::idealD: fprintf(fp," st->print(\"#%%f\", _c%d);\n", const_index); break; default: assert( false, "ShouldNotReachHere()"); @@ -2743,6 +2744,7 @@ void OperandForm::access_constant(FILE *fp, FormDict &globals, case idealP: fprintf(fp,"_c%d->get_con()",const_index); break; case idealL: fprintf(fp,"_c%d", const_index); break; case idealF: fprintf(fp,"_c%d", const_index); break; + case idealH: fprintf(fp,"_c%d", const_index); break; case idealD: fprintf(fp,"_c%d", const_index); break; default: assert( false, "ShouldNotReachHere()"); @@ -3953,11 +3955,12 @@ bool MatchNode::equivalent(FormDict &globals, MatchNode *mNode2) { // which could be swapped. 
void MatchNode::count_commutative_op(int& count) { static const char *commut_op_list[] = { - "AddI","AddL","AddF","AddD", + "AddI","AddL","AddHF","AddF","AddD", "AndI","AndL", - "MaxI","MinI","MaxF","MinF","MaxD","MinD", - "MulI","MulL","MulF","MulD", - "OrI","OrL", "XorI","XorL", + "MaxI","MinI","MaxHF","MinHF","MaxF","MinF","MaxD","MinD", + "MulI","MulL","MulHF","MulF","MulD", + "OrI","OrL", + "XorI","XorL", /* comma required: without it the adjacent literals "XorL" "UMax" merge into one bogus entry */ "UMax","UMin" }; @@ -4193,6 +4196,7 @@ int MatchRule::is_expensive() const { if( strcmp(opType,"AtanD")==0 || strcmp(opType,"DivD")==0 || strcmp(opType,"DivF")==0 || + strcmp(opType,"DivHF")==0 || strcmp(opType,"DivI")==0 || strcmp(opType,"Log10D")==0 || strcmp(opType,"ModD")==0 || @@ -4200,6 +4204,7 @@ int MatchRule::is_expensive() const { strcmp(opType,"ModI")==0 || strcmp(opType,"SqrtD")==0 || strcmp(opType,"SqrtF")==0 || + strcmp(opType,"SqrtHF")==0 || strcmp(opType,"TanD")==0 || strcmp(opType,"ConvD2F")==0 || strcmp(opType,"ConvD2I")==0 || @@ -4219,6 +4224,7 @@ int MatchRule::is_expensive() const { strcmp(opType,"DecodeNKlass")==0 || strcmp(opType,"FmaD") == 0 || strcmp(opType,"FmaF") == 0 || + strcmp(opType,"FmaHF") == 0 || strcmp(opType,"RoundDouble")==0 || strcmp(opType,"RoundDoubleMode")==0 || strcmp(opType,"RoundFloat")==0 || diff --git a/src/hotspot/share/adlc/output_c.cpp b/src/hotspot/share/adlc/output_c.cpp index cc6ed278b49..0620f2f4496 100644 --- a/src/hotspot/share/adlc/output_c.cpp +++ b/src/hotspot/share/adlc/output_c.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1998, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1998, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -2421,6 +2421,8 @@ private: if( _constant_status == LITERAL_NOT_SEEN ) { if ( _constant_type == Form::idealD ) { fprintf(_fp,"->constantD()"); + } else if ( _constant_type == Form::idealH ) { + fprintf(_fp,"->constantH()"); } else if ( _constant_type == Form::idealF ) { fprintf(_fp,"->constantF()"); } else if ( _constant_type == Form::idealL ) { @@ -3789,6 +3791,8 @@ static void path_to_constant(FILE *fp, FormDict &globals, fprintf(fp, "_leaf->bottom_type()->is_narrowoop()"); } else if ( (strcmp(optype,"ConNKlass") == 0) ) { fprintf(fp, "_leaf->bottom_type()->is_narrowklass()"); + } else if ( (strcmp(optype,"ConH") == 0) ) { + fprintf(fp, "_leaf->geth()"); } else if ( (strcmp(optype,"ConF") == 0) ) { fprintf(fp, "_leaf->getf()"); } else if ( (strcmp(optype,"ConD") == 0) ) { diff --git a/src/hotspot/share/adlc/output_h.cpp b/src/hotspot/share/adlc/output_h.cpp index d6767bc1f7e..a4ab29008f0 100644 --- a/src/hotspot/share/adlc/output_h.cpp +++ b/src/hotspot/share/adlc/output_h.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1998, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1998, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -233,6 +233,10 @@ static void declareConstStorage(FILE *fp, FormDict &globals, OperandForm *oper) if (i > 0) fprintf(fp,", "); fprintf(fp," jfloat _c%d;\n", i); } + else if (!strcmp(type, "ConH")) { + if (i > 0) fprintf(fp,", "); + fprintf(fp," jshort _c%d;\n", i); + } else if (!strcmp(type, "ConD")) { if (i > 0) fprintf(fp,", "); fprintf(fp," jdouble _c%d;\n", i); @@ -269,6 +273,10 @@ static void declareConstStorage(FILE *fp, FormDict &globals, OperandForm *oper) fprintf(fp," jlong _c%d;\n", i); i++; } + else if (!strcmp(comp->base_type(globals), "ConH")) { + fprintf(fp," jshort _c%d;\n", i); + i++; + } else if (!strcmp(comp->base_type(globals), "ConF")) { fprintf(fp," jfloat _c%d;\n", i); i++; @@ -314,6 +322,7 @@ static void defineConstructor(FILE *fp, const char *name, uint num_consts, case Form::idealNKlass : { fprintf(fp,"const TypeNarrowKlass *c%d", i); break; } case Form::idealP : { fprintf(fp,"const TypePtr *c%d", i); break; } case Form::idealL : { fprintf(fp,"jlong c%d", i); break; } + case Form::idealH : { fprintf(fp,"jshort c%d", i); break; } case Form::idealF : { fprintf(fp,"jfloat c%d", i); break; } case Form::idealD : { fprintf(fp,"jdouble c%d", i); break; } default: @@ -403,6 +412,11 @@ static uint dump_spec_constant(FILE *fp, const char *ideal_type, uint i, Operand fprintf(fp," st->print(\"/0x%%08x\", _c%d);\n", i); ++i; } + else if (!strcmp(ideal_type, "ConH")) { + fprintf(fp," st->print(\"#%%d\", _c%d);\n", i); + fprintf(fp," st->print(\"/0x%%08x\", _c%d);\n", i); + ++i; + } else if (!strcmp(ideal_type, "ConP")) { fprintf(fp," _c%d->dump_on(st);\n", i); ++i; @@ -1281,6 +1295,7 @@ void ArchDesc::declareClasses(FILE *fp) { case Form::idealF: type = "Type::FLOAT"; break; case Form::idealD: type = "Type::DOUBLE"; break; case Form::idealL: type = "TypeLong::LONG"; break; + case Form::idealH: type = "Type::HALF_FLOAT"; break; case Form::none: // fall through default: assert( 
false, "No support for this type of stackSlot"); @@ -1425,6 +1440,14 @@ void ArchDesc::declareClasses(FILE *fp) { fprintf(fp, " return _c0;"); fprintf(fp, " }\n"); } + else if (!strcmp(oper->ideal_type(_globalNames), "ConH")) { + fprintf(fp," virtual intptr_t constant() const {"); + fprintf(fp, " ShouldNotReachHere(); return 0; "); + fprintf(fp, " }\n"); + fprintf(fp," virtual jshort constantH() const {"); + fprintf(fp, " return (jshort)_c0;"); + fprintf(fp, " }\n"); + } else if (!strcmp(oper->ideal_type(_globalNames), "ConF")) { fprintf(fp," virtual intptr_t constant() const {"); fprintf(fp, " ShouldNotReachHere(); return 0; "); @@ -1897,6 +1920,9 @@ void ArchDesc::declareClasses(FILE *fp) { case Form::idealD: fprintf(fp," return TypeD::make(opnd_array(1)->constantD());\n"); break; + case Form::idealH: + fprintf(fp," return TypeH::make(opnd_array(1)->constantH());\n"); + break; case Form::idealF: fprintf(fp," return TypeF::make(opnd_array(1)->constantF());\n"); break; diff --git a/src/hotspot/share/classfile/vmIntrinsics.hpp b/src/hotspot/share/classfile/vmIntrinsics.hpp index 16f6ff024d0..cf68217ae2c 100644 --- a/src/hotspot/share/classfile/vmIntrinsics.hpp +++ b/src/hotspot/share/classfile/vmIntrinsics.hpp @@ -924,19 +924,34 @@ class methodHandle; do_signature(getAndAddShort_signature, "(Ljava/lang/Object;JS)S" ) \ do_intrinsic(_getAndSetInt, jdk_internal_misc_Unsafe, getAndSetInt_name, getAndSetInt_signature, F_R) \ do_name( getAndSetInt_name, "getAndSetInt") \ - do_alias( getAndSetInt_signature, /*"(Ljava/lang/Object;JI)I"*/ getAndAddInt_signature) \ + do_alias( getAndSetInt_signature, /*"(Ljava/lang/Object;JI)I"*/ getAndAddInt_signature) \ do_intrinsic(_getAndSetLong, jdk_internal_misc_Unsafe, getAndSetLong_name, getAndSetLong_signature, F_R) \ do_name( getAndSetLong_name, "getAndSetLong") \ - do_alias( getAndSetLong_signature, /*"(Ljava/lang/Object;JJ)J"*/ getAndAddLong_signature) \ + do_alias( getAndSetLong_signature, /*"(Ljava/lang/Object;JJ)J"*/ 
getAndAddLong_signature)\ do_intrinsic(_getAndSetByte, jdk_internal_misc_Unsafe, getAndSetByte_name, getAndSetByte_signature, F_R) \ do_name( getAndSetByte_name, "getAndSetByte") \ - do_alias( getAndSetByte_signature, /*"(Ljava/lang/Object;JB)B"*/ getAndAddByte_signature) \ + do_alias( getAndSetByte_signature, /*"(Ljava/lang/Object;JB)B"*/ getAndAddByte_signature)\ do_intrinsic(_getAndSetShort, jdk_internal_misc_Unsafe, getAndSetShort_name, getAndSetShort_signature, F_R) \ - do_name( getAndSetShort_name, "getAndSetShort") \ - do_alias( getAndSetShort_signature, /*"(Ljava/lang/Object;JS)S"*/ getAndAddShort_signature) \ - do_intrinsic(_getAndSetReference, jdk_internal_misc_Unsafe, getAndSetReference_name, getAndSetReference_signature, F_R) \ - do_name( getAndSetReference_name, "getAndSetReference") \ + do_name( getAndSetShort_name, "getAndSetShort") \ + do_alias( getAndSetShort_signature, /*"(Ljava/lang/Object;JS)S"*/ getAndAddShort_signature) \ + do_intrinsic(_getAndSetReference, jdk_internal_misc_Unsafe, getAndSetReference_name, getAndSetReference_signature, F_R) \ + do_name( getAndSetReference_name, "getAndSetReference") \ do_signature(getAndSetReference_signature, "(Ljava/lang/Object;JLjava/lang/Object;)Ljava/lang/Object;" ) \ + \ + /* Float16Math API intrinsification support */ \ + /* Float16 signatures */ \ + do_signature(float16_unary_math_op_sig, "(Ljava/lang/Class;" \ + "Ljava/lang/Object;" \ + "Ljava/util/function/UnaryOperator;)" \ + "Ljava/lang/Object;") \ + do_signature(float16_ternary_math_op_sig, "(Ljava/lang/Class;" \ + "Ljava/lang/Object;" \ + "Ljava/lang/Object;" \ + "Ljava/lang/Object;" \ + "Ljdk/internal/vm/vector/Float16Math$TernaryOperator;)" \ + "Ljava/lang/Object;") \ + do_intrinsic(_sqrt_float16, jdk_internal_vm_vector_Float16Math, sqrt_name, float16_unary_math_op_sig, F_S) \ + do_intrinsic(_fma_float16, jdk_internal_vm_vector_Float16Math, fma_name, float16_ternary_math_op_sig, F_S) \ \ /* Vector API intrinsification support */ \ \ diff --git 
a/src/hotspot/share/classfile/vmSymbols.hpp b/src/hotspot/share/classfile/vmSymbols.hpp index 622fa8640a6..e66066738ef 100644 --- a/src/hotspot/share/classfile/vmSymbols.hpp +++ b/src/hotspot/share/classfile/vmSymbols.hpp @@ -91,7 +91,8 @@ class SerializeClosure; template(java_lang_Long_LongCache, "java/lang/Long$LongCache") \ template(java_lang_Void, "java/lang/Void") \ \ - template(jdk_internal_vm_vector_VectorSupport, "jdk/internal/vm/vector/VectorSupport") \ + template(jdk_internal_vm_vector_VectorSupport, "jdk/internal/vm/vector/VectorSupport") \ + template(jdk_internal_vm_vector_Float16Math, "jdk/internal/vm/vector/Float16Math") \ template(jdk_internal_vm_vector_VectorPayload, "jdk/internal/vm/vector/VectorSupport$VectorPayload") \ template(jdk_internal_vm_vector_Vector, "jdk/internal/vm/vector/VectorSupport$Vector") \ template(jdk_internal_vm_vector_VectorMask, "jdk/internal/vm/vector/VectorSupport$VectorMask") \ diff --git a/src/hotspot/share/opto/addnode.cpp b/src/hotspot/share/opto/addnode.cpp index 928e06ee938..8406fe8b69e 100644 --- a/src/hotspot/share/opto/addnode.cpp +++ b/src/hotspot/share/opto/addnode.cpp @@ -32,6 +32,7 @@ #include "opto/mulnode.hpp" #include "opto/phaseX.hpp" #include "opto/subnode.hpp" +#include "runtime/stubRoutines.hpp" // Portions of code courtesy of Clifford Click @@ -555,6 +556,22 @@ Node *AddFNode::Ideal(PhaseGVN *phase, bool can_reshape) { return commute(phase, this) ? this : nullptr; } +//============================================================================= +//------------------------------add_of_identity-------------------------------- +// Check for addition of the identity +const Type* AddHFNode::add_of_identity(const Type* t1, const Type* t2) const { + return nullptr; +} + +// Supplied function returns the sum of the inputs. +// This also type-checks the inputs for sanity. Guaranteed never to +// be passed a TOP or BOTTOM type, these are filtered out by pre-check. 
+const Type* AddHFNode::add_ring(const Type* t0, const Type* t1) const { + if (!t0->isa_half_float_constant() || !t1->isa_half_float_constant()) { + return bottom_type(); + } + return TypeH::make(t0->getf() + t1->getf()); +} //============================================================================= //------------------------------add_of_identity-------------------------------- @@ -1504,6 +1521,33 @@ Node* MaxNode::Identity(PhaseGVN* phase) { return AddNode::Identity(phase); } +//------------------------------add_ring--------------------------------------- +const Type* MinHFNode::add_ring(const Type* t0, const Type* t1) const { + const TypeH* r0 = t0->isa_half_float_constant(); + const TypeH* r1 = t1->isa_half_float_constant(); + if (r0 == nullptr || r1 == nullptr) { + return bottom_type(); + } + + if (r0->is_nan()) { + return r0; + } + if (r1->is_nan()) { + return r1; + } + + float f0 = r0->getf(); + float f1 = r1->getf(); + if (f0 != 0.0f || f1 != 0.0f) { + return f0 < f1 ? r0 : r1; + } + + // As per IEEE 754 specification, floating point comparison consider +ve and -ve + // zeros as equals. Thus, performing signed integral comparison for min value + // detection. + return (jint_cast(f0) < jint_cast(f1)) ? r0 : r1; +} + //------------------------------add_ring--------------------------------------- const Type* MinFNode::add_ring(const Type* t0, const Type* t1 ) const { const TypeF* r0 = t0->isa_float_constant(); @@ -1554,6 +1598,34 @@ const Type* MinDNode::add_ring(const Type* t0, const Type* t1) const { return (jlong_cast(d0) < jlong_cast(d1)) ? 
r0 : r1; } +//------------------------------add_ring--------------------------------------- +const Type* MaxHFNode::add_ring(const Type* t0, const Type* t1) const { + const TypeH* r0 = t0->isa_half_float_constant(); + const TypeH* r1 = t1->isa_half_float_constant(); + if (r0 == nullptr || r1 == nullptr) { + return bottom_type(); + } + + if (r0->is_nan()) { + return r0; + } + if (r1->is_nan()) { + return r1; + } + + float f0 = r0->getf(); + float f1 = r1->getf(); + if (f0 != 0.0f || f1 != 0.0f) { + return f0 > f1 ? r0 : r1; + } + + // As per IEEE 754 specification, floating point comparison consider +ve and -ve + // zeros as equals. Thus, performing signed integral comparison for max value + // detection. + return (jint_cast(f0) > jint_cast(f1)) ? r0 : r1; +} + + //------------------------------add_ring--------------------------------------- const Type* MaxFNode::add_ring(const Type* t0, const Type* t1) const { const TypeF* r0 = t0->isa_float_constant(); diff --git a/src/hotspot/share/opto/addnode.hpp b/src/hotspot/share/opto/addnode.hpp index c409fb8cea8..456a8d9f9a0 100644 --- a/src/hotspot/share/opto/addnode.hpp +++ b/src/hotspot/share/opto/addnode.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -155,6 +155,22 @@ public: virtual uint ideal_reg() const { return Op_RegD; } }; +//------------------------------AddHFNode--------------------------------------- +// Add 2 half-precision floats +class AddHFNode : public AddNode { +public: + AddHFNode(Node* in1, Node* in2) : AddNode(in1,in2) {} + virtual int Opcode() const; + virtual const Type* add_of_identity(const Type* t1, const Type* t2) const; + virtual const Type* add_ring(const Type*, const Type*) const; + virtual const Type* add_id() const { return TypeH::ZERO; } + virtual const Type* bottom_type() const { return Type::HALF_FLOAT; } + int max_opcode() const { return Op_MaxHF; } + int min_opcode() const { return Op_MinHF; } + virtual Node* Identity(PhaseGVN* phase) { return this; } + virtual uint ideal_reg() const { return Op_RegF; } +}; + //------------------------------AddPNode--------------------------------------- // Add pointer plus integer to get pointer. NOT commutative, really. // So not really an AddNode. Lives here, because people associate it with @@ -402,6 +418,34 @@ public: int min_opcode() const { return Op_MinF; } }; +//------------------------------MaxHFNode-------------------------------------- +// Maximum of 2 half floats. +class MaxHFNode : public MaxNode { +public: + MaxHFNode(Node* in1, Node* in2) : MaxNode(in1, in2) {} + virtual int Opcode() const; + virtual const Type* add_ring(const Type*, const Type*) const; + virtual const Type* add_id() const { return TypeH::NEG_INF; } + virtual const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } + int max_opcode() const { return Op_MaxHF; } + int min_opcode() const { return Op_MinHF; } +}; + +//------------------------------MinHFNode--------------------------------------- +// Minimum of 2 half floats. 
+class MinHFNode : public MaxNode { +public: + MinHFNode(Node* in1, Node* in2) : MaxNode(in1, in2) {} + virtual int Opcode() const; + virtual const Type* add_ring(const Type*, const Type*) const; + virtual const Type* add_id() const { return TypeH::POS_INF; } + virtual const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } + int max_opcode() const { return Op_MaxHF; } + int min_opcode() const { return Op_MinHF; } +}; + //------------------------------MaxDNode--------------------------------------- // Maximum of 2 doubles. class MaxDNode : public MaxNode { diff --git a/src/hotspot/share/opto/c2compiler.cpp b/src/hotspot/share/opto/c2compiler.cpp index e126697ed05..c5b6d5a4895 100644 --- a/src/hotspot/share/opto/c2compiler.cpp +++ b/src/hotspot/share/opto/c2compiler.cpp @@ -352,6 +352,12 @@ bool C2Compiler::is_intrinsic_supported(vmIntrinsics::ID id) { case vmIntrinsics::_floatToFloat16: if (!Matcher::match_rule_supported(Op_ConvF2HF)) return false; break; + case vmIntrinsics::_sqrt_float16: + if (!Matcher::match_rule_supported(Op_SqrtHF)) return false; + break; + case vmIntrinsics::_fma_float16: + if (!Matcher::match_rule_supported(Op_FmaHF)) return false; + break; /* CompareAndSet, Object: */ case vmIntrinsics::_compareAndSetReference: diff --git a/src/hotspot/share/opto/castnode.cpp b/src/hotspot/share/opto/castnode.cpp index 1644b997fb8..68ba291986b 100644 --- a/src/hotspot/share/opto/castnode.cpp +++ b/src/hotspot/share/opto/castnode.cpp @@ -475,6 +475,8 @@ Node* ConstraintCastNode::make_cast_for_type(Node* c, Node* in, const Type* type return new CastIINode(c, in, type, dependency, false, types); } else if (type->isa_long()) { return new CastLLNode(c, in, type, dependency, types); + } else if (type->isa_half_float()) { + return new CastHHNode(c, in, type, dependency, types); } else if (type->isa_float()) { return new CastFFNode(c, in, type, dependency, types); } else if (type->isa_double()) { diff --git 
a/src/hotspot/share/opto/castnode.hpp b/src/hotspot/share/opto/castnode.hpp index e6ef7b06065..1b848e5efdf 100644 --- a/src/hotspot/share/opto/castnode.hpp +++ b/src/hotspot/share/opto/castnode.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2014, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2014, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -143,6 +143,17 @@ public: virtual uint ideal_reg() const { return Op_RegL; } }; +class CastHHNode: public ConstraintCastNode { +public: + CastHHNode(Node* ctrl, Node* n, const Type* t, DependencyType dependency = RegularDependency, const TypeTuple* types = nullptr) + : ConstraintCastNode(ctrl, n, t, dependency, types) { + assert(ctrl != nullptr, "control must be set"); + init_class_id(Class_CastHH); + } + virtual int Opcode() const; + virtual uint ideal_reg() const { return in(1)->ideal_reg(); } +}; + class CastFFNode: public ConstraintCastNode { public: CastFFNode(Node* ctrl, Node* n, const Type* t, DependencyType dependency = RegularDependency, const TypeTuple* types = nullptr) diff --git a/src/hotspot/share/opto/classes.hpp b/src/hotspot/share/opto/classes.hpp index 60ee3e01137..918d8156b5f 100644 --- a/src/hotspot/share/opto/classes.hpp +++ b/src/hotspot/share/opto/classes.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -36,6 +36,7 @@ macro(AddF) macro(AddI) macro(AddL) macro(AddP) +macro(AddHF) macro(Allocate) macro(AllocateArray) macro(AndI) @@ -64,6 +65,7 @@ macro(CallLeafVector) macro(CallRuntime) macro(CallStaticJava) macro(CastDD) +macro(CastHH) macro(CastFF) macro(CastII) macro(CastLL) @@ -132,6 +134,7 @@ macro(Con) macro(ConN) macro(ConNKlass) macro(ConD) +macro(ConH) macro(ConF) macro(ConI) macro(ConL) @@ -166,6 +169,7 @@ macro(CountTrailingZerosV) macro(CreateEx) macro(DecodeN) macro(DecodeNKlass) +macro(DivHF) macro(DivD) macro(DivF) macro(DivI) @@ -184,6 +188,7 @@ macro(FastLock) macro(FastUnlock) macro(FmaD) macro(FmaF) +macro(FmaHF) macro(ForwardException) macro(Goto) macro(Halt) @@ -222,6 +227,7 @@ macro(MachProj) macro(MulAddS2I) macro(MaxI) macro(MaxL) +macro(MaxHF) macro(MaxD) macro(MaxF) macro(MemBarAcquire) @@ -237,6 +243,7 @@ macro(MemBarStoreStore) macro(MergeMem) macro(MinI) macro(MinL) +macro(MinHF) macro(MinF) macro(MinD) macro(ModD) @@ -253,6 +260,7 @@ macro(IsInfiniteF) macro(IsFiniteF) macro(IsInfiniteD) macro(IsFiniteD) +macro(MulHF) macro(MulD) macro(MulF) macro(MulHiL) @@ -338,6 +346,7 @@ macro(SignumVF) macro(SignumVD) macro(SqrtD) macro(SqrtF) +macro(SqrtHF) macro(RoundF) macro(RoundD) macro(Start) @@ -357,6 +366,7 @@ macro(StrEquals) macro(StrIndexOf) macro(StrIndexOfChar) macro(StrInflatedCopy) +macro(SubHF) macro(SubD) macro(SubF) macro(SubI) @@ -485,6 +495,8 @@ macro(ExtractF) macro(ExtractD) macro(Digit) macro(LowerCase) +macro(ReinterpretS2HF) +macro(ReinterpretHF2S) macro(UpperCase) macro(Whitespace) macro(SelectFromTwoVector) diff --git a/src/hotspot/share/opto/connode.cpp b/src/hotspot/share/opto/connode.cpp index db3875dd95d..3538342b5d4 100644 --- a/src/hotspot/share/opto/connode.cpp +++ b/src/hotspot/share/opto/connode.cpp @@ -43,6 +43,9 @@ uint ConNode::hash() const { //------------------------------make------------------------------------------- ConNode 
*ConNode::make(const Type *t) { + if (t->isa_half_float_constant()) { + return new ConHNode( t->is_half_float_constant() ); + } switch( t->basic_type() ) { case T_INT: return new ConINode( t->is_int() ); case T_LONG: return new ConLNode( t->is_long() ); diff --git a/src/hotspot/share/opto/connode.hpp b/src/hotspot/share/opto/connode.hpp index 618326ec527..3b7657320e2 100644 --- a/src/hotspot/share/opto/connode.hpp +++ b/src/hotspot/share/opto/connode.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -115,6 +115,19 @@ public: }; +//------------------------------ConHNode--------------------------------------- +// Simple half float constants +class ConHNode : public ConNode { +public: + ConHNode(const TypeH* t) : ConNode(t) {} + virtual int Opcode() const; + + // Factory method: + static ConHNode* make(float con) { + return new ConHNode(TypeH::make(con)); + } +}; + //------------------------------ConFNode--------------------------------------- // Simple float constants class ConFNode : public ConNode { diff --git a/src/hotspot/share/opto/constantTable.cpp b/src/hotspot/share/opto/constantTable.cpp index ddab9e95e36..aba71e45e1c 100644 --- a/src/hotspot/share/opto/constantTable.cpp +++ b/src/hotspot/share/opto/constantTable.cpp @@ -49,6 +49,7 @@ bool ConstantTable::Constant::operator==(const Constant& other) { // For floating point values we compare the bit pattern. 
switch (type()) { + case T_SHORT: return (_v._value.i == other._v._value.i); case T_INT: return (_v._value.i == other._v._value.i); case T_FLOAT: return jint_cast(_v._value.f) == jint_cast(other._v._value.f); case T_LONG: return (_v._value.j == other._v._value.j); @@ -102,6 +103,7 @@ static int constant_size(ConstantTable::Constant* con) { return con->get_array()->length(); } switch (con->type()) { + case T_SHORT: return sizeof(jint ); case T_INT: return sizeof(jint ); case T_LONG: return sizeof(jlong ); case T_FLOAT: return sizeof(jfloat ); @@ -167,6 +169,7 @@ bool ConstantTable::emit(C2_MacroAssembler* masm) const { constant_addr = masm->array_constant(con.get_array(), con.alignment()); } else { switch (con.type()) { + case T_SHORT: constant_addr = masm->int_constant( con.get_jint() ); break; case T_INT: constant_addr = masm->int_constant( con.get_jint() ); break; case T_LONG: constant_addr = masm->long_constant( con.get_jlong() ); break; case T_FLOAT: constant_addr = masm->float_constant( con.get_jfloat() ); break; @@ -281,6 +284,7 @@ ConstantTable::Constant ConstantTable::add(MachConstantNode* n, MachOper* oper) BasicType type = oper->type()->basic_type(); switch (type) { case T_LONG: value.j = oper->constantL(); break; + case T_SHORT: value.i = oper->constantH(); break; case T_INT: value.i = oper->constant(); break; case T_FLOAT: value.f = oper->constantF(); break; case T_DOUBLE: value.d = oper->constantD(); break; diff --git a/src/hotspot/share/opto/convertnode.cpp b/src/hotspot/share/opto/convertnode.cpp index 07956d56f1c..bdfb211da83 100644 --- a/src/hotspot/share/opto/convertnode.cpp +++ b/src/hotspot/share/opto/convertnode.cpp @@ -26,8 +26,10 @@ #include "opto/castnode.hpp" #include "opto/connode.hpp" #include "opto/convertnode.hpp" +#include "opto/divnode.hpp" #include "opto/matcher.hpp" #include "opto/movenode.hpp" +#include "opto/mulnode.hpp" #include "opto/phaseX.hpp" #include "opto/subnode.hpp" #include "runtime/stubRoutines.hpp" @@ -248,6 +250,37 @@ 
const Type* ConvF2HFNode::Value(PhaseGVN* phase) const { return TypeInt::make( StubRoutines::f2hf(tf->getf()) ); } +//------------------------------Ideal------------------------------------------ +Node* ConvF2HFNode::Ideal(PhaseGVN* phase, bool can_reshape) { + // Float16 instance encapsulates a short field holding IEEE 754 + // binary16 value. On unboxing, this short field is loaded into a + // GPR register while FP operation operates over floating point + // registers. ConvHF2F converts incoming short value to a FP32 value + // to perform operation at FP32 granularity. However, if the target + // supports the FP16 ISA we can save this redundant up casting and + // optimize the graph using the following transformation. + // + // ConvF2HF(FP32BinOp(ConvHF2F(x), ConvHF2F(y))) => + // ReinterpretHF2S(FP16BinOp(ReinterpretS2HF(x), ReinterpretS2HF(y))) + // + // Please note we need to inject appropriate reinterpretation + // IR to move the values between GPR and floating point register + // before and after FP16 operation.
+ + if (Float16NodeFactory::is_float32_binary_oper(in(1)->Opcode()) && + in(1)->in(1)->Opcode() == Op_ConvHF2F && + in(1)->in(2)->Opcode() == Op_ConvHF2F) { + if (Matcher::match_rule_supported(Float16NodeFactory::get_float16_binary_oper(in(1)->Opcode())) && + Matcher::match_rule_supported(Op_ReinterpretS2HF) && + Matcher::match_rule_supported(Op_ReinterpretHF2S)) { + Node* in1 = phase->transform(new ReinterpretS2HFNode(in(1)->in(1)->in(1))); + Node* in2 = phase->transform(new ReinterpretS2HFNode(in(1)->in(2)->in(1))); + Node* binop = phase->transform(Float16NodeFactory::make(in(1)->Opcode(), in(1)->in(0), in1, in2)); + return new ReinterpretHF2SNode(binop); + } + } + return nullptr; +} //============================================================================= //------------------------------Value------------------------------------------ const Type* ConvF2INode::Value(PhaseGVN* phase) const { @@ -896,3 +929,75 @@ const Type* RoundDoubleModeNode::Value(PhaseGVN* phase) const { return Type::DOUBLE; } //============================================================================= + +const Type* ReinterpretS2HFNode::Value(PhaseGVN* phase) const { + const Type* type = phase->type(in(1)); + // Convert short constant value to a Half Float constant value + if ((type->isa_int() && type->is_int()->is_con())) { + jshort hfval = type->is_int()->get_con(); + return TypeH::make(hfval); + } + return Type::HALF_FLOAT; +} + +Node* ReinterpretS2HFNode::Identity(PhaseGVN* phase) { + if (in(1)->Opcode() == Op_ReinterpretHF2S) { + assert(in(1)->in(1)->bottom_type()->isa_half_float(), ""); + return in(1)->in(1); + } + return this; +} + +const Type* ReinterpretHF2SNode::Value(PhaseGVN* phase) const { + const Type* type = phase->type(in(1)); + // Convert Half float constant value to short constant value. 
+ if (type->isa_half_float_constant()) { + jshort hfval = type->is_half_float_constant()->_f; + return TypeInt::make(hfval); + } + return TypeInt::SHORT; +} + +bool Float16NodeFactory::is_float32_binary_oper(int opc) { + switch(opc) { + case Op_AddF: + case Op_SubF: + case Op_MulF: + case Op_DivF: + case Op_MaxF: + case Op_MinF: + return true; + default: + return false; + } +} + +int Float16NodeFactory::get_float16_binary_oper(int opc) { + switch(opc) { + case Op_AddF: + return Op_AddHF; + case Op_SubF: + return Op_SubHF; + case Op_MulF: + return Op_MulHF; + case Op_DivF: + return Op_DivHF; + case Op_MaxF: + return Op_MaxHF; + case Op_MinF: + return Op_MinHF; + default: ShouldNotReachHere(); + } +} + +Node* Float16NodeFactory::make(int opc, Node* c, Node* in1, Node* in2) { + switch(opc) { + case Op_AddF: return new AddHFNode(in1, in2); + case Op_SubF: return new SubHFNode(in1, in2); + case Op_MulF: return new MulHFNode(in1, in2); + case Op_DivF: return new DivHFNode(c, in1, in2); + case Op_MaxF: return new MaxHFNode(in1, in2); + case Op_MinF: return new MinHFNode(in1, in2); + default: ShouldNotReachHere(); + } +} diff --git a/src/hotspot/share/opto/convertnode.hpp b/src/hotspot/share/opto/convertnode.hpp index 9438176a9f9..64b2d2571b2 100644 --- a/src/hotspot/share/opto/convertnode.hpp +++ b/src/hotspot/share/opto/convertnode.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2014, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2014, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -112,6 +112,7 @@ class ConvF2HFNode : public ConvertNode { virtual int Opcode() const; virtual const Type* in_type() const { return TypeInt::FLOAT; } virtual const Type* Value(PhaseGVN* phase) const; + virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); }; //------------------------------ConvF2INode------------------------------------ @@ -213,6 +214,30 @@ class ConvL2INode : public ConvertNode { virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); }; + +//-----------------------------ReinterpretS2HFNode --------------------------- +// Reinterpret Short to Half Float +class ReinterpretS2HFNode : public Node { + public: + ReinterpretS2HFNode(Node* in1) : Node(nullptr, in1) {} + virtual int Opcode() const; + virtual const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual const Type* Value(PhaseGVN* phase) const; + virtual Node* Identity(PhaseGVN* phase); + virtual uint ideal_reg() const { return Op_RegF; } +}; + +//-----------------------------ReinterpretHF2SNode --------------------------- +// Reinterpret Half Float to Short +class ReinterpretHF2SNode : public Node { + public: + ReinterpretHF2SNode(Node* in1) : Node(nullptr, in1) {} + virtual int Opcode() const; + virtual const Type* Value(PhaseGVN* phase) const; + virtual const Type* bottom_type() const { return TypeInt::SHORT; } + virtual uint ideal_reg() const { return Op_RegI; } +}; + class RoundDNode : public Node { public: RoundDNode(Node* in1) : Node(nullptr, in1) {} @@ -269,5 +294,11 @@ class RoundDoubleModeNode: public Node { virtual const Type* Value(PhaseGVN* phase) const; }; +class Float16NodeFactory { + public: + static bool is_float32_binary_oper(int opc); + static int get_float16_binary_oper(int opc); + static Node* make(int opc, Node* c, Node* in1, Node* in2); +}; #endif // SHARE_OPTO_CONVERTNODE_HPP diff --git a/src/hotspot/share/opto/divnode.cpp b/src/hotspot/share/opto/divnode.cpp index
ef27a3d7a14..bb66ad47ed7 100644 --- a/src/hotspot/share/opto/divnode.cpp +++ b/src/hotspot/share/opto/divnode.cpp @@ -805,6 +805,115 @@ Node *DivFNode::Ideal(PhaseGVN *phase, bool can_reshape) { // return multiplication by the reciprocal return (new MulFNode(in(1), phase->makecon(TypeF::make(reciprocal)))); } +//============================================================================= +//------------------------------Value------------------------------------------ +// A DivHFNode divides its inputs. Its control input is used to +// prevent hoisting the divide above an unsafe test. +const Type* DivHFNode::Value(PhaseGVN* phase) const { + // Either input is TOP ==> the result is TOP + const Type* t1 = phase->type(in(1)); + const Type* t2 = phase->type(in(2)); + if(t1 == Type::TOP) { return Type::TOP; } + if(t2 == Type::TOP) { return Type::TOP; } + + // Either input is BOTTOM ==> the result is the local BOTTOM + const Type* bot = bottom_type(); + if((t1 == bot) || (t2 == bot) || + (t1 == Type::BOTTOM) || (t2 == Type::BOTTOM)) { + return bot; + } + + // x/x == 1, we ignore 0/0.
+ // Note: if t1 and t2 are zero then result is NaN (JVMS page 213) + // Does not work for variables because of NaN's + if (in(1) == in(2) && t1->base() == Type::HalfFloatCon && + !g_isnan(t1->getf()) && g_isfinite(t1->getf()) && t1->getf() != 0.0) { // could be negative ZERO or NaN + return TypeH::ONE; + } + + if (t2 == TypeH::ONE) { + return t1; + } + + // If divisor is a constant and not zero, divide the numbers + if (t1->base() == Type::HalfFloatCon && + t2->base() == Type::HalfFloatCon && + t2->getf() != 0.0) { + // could be negative zero + return TypeH::make(t1->getf() / t2->getf()); + } + + // If the dividend is a constant zero + // Note: if t1 and t2 are zero then result is NaN (JVMS page 213) + // Test TypeH::ZERO is not sufficient as it could be negative zero + + if (t1 == TypeH::ZERO && !g_isnan(t2->getf()) && t2->getf() != 0.0) { + return TypeH::ZERO; + } + + // If divisor or dividend is nan then result is nan. + if (g_isnan(t1->getf()) || g_isnan(t2->getf())) { + return TypeH::make(NAN); + } + + // Otherwise we give up all hope + return Type::HALF_FLOAT; +} + +//----------------------------------------------------------------------------- +// Dividing by self is 1 (constant case folded in Value(), not here). +// If the divisor is 1, we are an identity on the dividend. +Node* DivHFNode::Identity(PhaseGVN* phase) { + return (phase->type( in(2) ) == TypeH::ONE) ? in(1) : this; +} + + +//------------------------------Idealize--------------------------------------- +Node* DivHFNode::Ideal(PhaseGVN* phase, bool can_reshape) { + if (in(0) != nullptr && remove_dead_region(phase, can_reshape)) return this; + // Don't bother trying to transform a dead node + if (in(0) != nullptr && in(0)->is_top()) { return nullptr; } + + const Type* t2 = phase->type(in(2)); + if (t2 == TypeH::ONE) { // Identity?
+ return nullptr; // Skip it + } + const TypeH* tf = t2->isa_half_float_constant(); + if(tf == nullptr) { return nullptr; } + if(tf->base() != Type::HalfFloatCon) { return nullptr; } + + // Check for out of range values + if(tf->is_nan() || !tf->is_finite()) { return nullptr; } + + // Get the value + float f = tf->getf(); + int exp; + + // Consider the following geometric progression series of POT(power of two) numbers. + // 0.5 x 2^0 = 0.5, 0.5 x 2^1 = 1.0, 0.5 x 2^2 = 2.0, 0.5 x 2^3 = 4.0 ... 0.5 x 2^n, + // In all the above cases, normalized mantissa returned by frexp routine will + // be exactly equal to 0.5 while exponent will be 0,1,2,3...n + // Perform division to multiplication transform only if divisor is a POT value. + if(frexp((double)f, &exp) != 0.5) { return nullptr; } + + // Limit the range of acceptable exponents + if(exp < -14 || exp > 15) { return nullptr; } + + // Since the divisor is a POT number, its reciprocal will never + // overflow the 11 bits precision range of a Float16 + // value if the exponent returned by the frexp routine lies strictly + // within the exponent range of normal min(0x1.0P-14) and + // normal max(0x1.ffcP+15) values. + // Thus we can safely compute the reciprocal of divisor without + // any concerns about the precision loss and transform the division + // into a multiplication operation.
+ float reciprocal = ((float)1.0) / f; + + assert(frexp((double)reciprocal, &exp) == 0.5, "reciprocal should be power of 2"); + + // return multiplication by the reciprocal + return (new MulHFNode(in(1), phase->makecon(TypeH::make(reciprocal)))); +} //============================================================================= //------------------------------Value------------------------------------------ diff --git a/src/hotspot/share/opto/divnode.hpp b/src/hotspot/share/opto/divnode.hpp index b3eb97d3996..a926f4ab122 100644 --- a/src/hotspot/share/opto/divnode.hpp +++ b/src/hotspot/share/opto/divnode.hpp @@ -78,6 +78,20 @@ public: virtual uint ideal_reg() const { return Op_RegF; } }; + +//------------------------------DivHFNode-------------------------------------- +// Half float division +class DivHFNode : public Node { +public: + DivHFNode(Node* c, Node* dividend, Node* divisor) : Node(c, dividend, divisor) {} + virtual int Opcode() const; + virtual Node* Identity(PhaseGVN* phase); + virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); + virtual const Type* Value(PhaseGVN* phase) const; + virtual const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } +}; + //------------------------------DivDNode--------------------------------------- // Double division class DivDNode : public Node { diff --git a/src/hotspot/share/opto/escape.cpp b/src/hotspot/share/opto/escape.cpp index 2d74564533e..0ab911c18be 100644 --- a/src/hotspot/share/opto/escape.cpp +++ b/src/hotspot/share/opto/escape.cpp @@ -4611,6 +4611,7 @@ void ConnectionGraph::split_unique_types(GrowableArray &alloc_worklist, op == Op_StrEquals || op == Op_VectorizedHashCode || op == Op_StrIndexOf || op == Op_StrIndexOfChar || op == Op_SubTypeCheck || + op == Op_ReinterpretS2HF || BarrierSet::barrier_set()->barrier_set_c2()->is_gc_barrier_node(use))) { n->dump(); use->dump(); diff --git a/src/hotspot/share/opto/library_call.cpp 
b/src/hotspot/share/opto/library_call.cpp index 1e503ca6bef..6d4990787d7 100644 --- a/src/hotspot/share/opto/library_call.cpp +++ b/src/hotspot/share/opto/library_call.cpp @@ -24,6 +24,7 @@ #include "asm/macroAssembler.hpp" #include "ci/ciUtilities.inline.hpp" +#include "ci/ciSymbols.hpp" #include "classfile/vmIntrinsics.hpp" #include "compiler/compileBroker.hpp" #include "compiler/compileLog.hpp" @@ -530,7 +531,8 @@ bool LibraryCallKit::try_to_inline(int predicate) { case vmIntrinsics::_longBitsToDouble: case vmIntrinsics::_floatToFloat16: case vmIntrinsics::_float16ToFloat: return inline_fp_conversions(intrinsic_id()); - + case vmIntrinsics::_sqrt_float16: return inline_fp16_operations(intrinsic_id(), 1); + case vmIntrinsics::_fma_float16: return inline_fp16_operations(intrinsic_id(), 3); case vmIntrinsics::_floatIsFinite: case vmIntrinsics::_floatIsInfinite: case vmIntrinsics::_doubleIsFinite: @@ -8606,3 +8608,112 @@ bool LibraryCallKit::inline_blackhole() { return true; } + +Node* LibraryCallKit::unbox_fp16_value(const TypeInstPtr* float16_box_type, ciField* field, Node* box) { + const TypeInstPtr* box_type = _gvn.type(box)->isa_instptr(); + if (box_type == nullptr || box_type->instance_klass() != float16_box_type->instance_klass()) { + return nullptr; // box klass is not Float16 + } + + // Null check; get notnull casted pointer + Node* null_ctl = top(); + Node* not_null_box = null_check_oop(box, &null_ctl, true); + // If not_null_box is dead, only null-path is taken + if (stopped()) { + set_control(null_ctl); + return nullptr; + } + assert(not_null_box->bottom_type()->is_instptr()->maybe_null() == false, ""); + const TypePtr* adr_type = C->alias_type(field)->adr_type(); + Node* adr = basic_plus_adr(not_null_box, field->offset_in_bytes()); + return access_load_at(not_null_box, adr, adr_type, TypeInt::SHORT, T_SHORT, IN_HEAP); +} + +Node* LibraryCallKit::box_fp16_value(const TypeInstPtr* float16_box_type, ciField* field, Node* value) { + PreserveReexecuteState 
preexecs(this); + jvms()->set_should_reexecute(true); + + const TypeKlassPtr* klass_type = float16_box_type->as_klass_type(); + Node* klass_node = makecon(klass_type); + Node* box = new_instance(klass_node); + + Node* value_field = basic_plus_adr(box, field->offset_in_bytes()); + const TypePtr* value_adr_type = value_field->bottom_type()->is_ptr(); + + Node* field_store = _gvn.transform(access_store_at(box, + value_field, + value_adr_type, + value, + TypeInt::SHORT, + T_SHORT, + IN_HEAP)); + set_memory(field_store, value_adr_type); + return box; +} + +bool LibraryCallKit::inline_fp16_operations(vmIntrinsics::ID id, int num_args) { + if (!Matcher::match_rule_supported(Op_ReinterpretS2HF) || + !Matcher::match_rule_supported(Op_ReinterpretHF2S)) { + return false; + } + + const TypeInstPtr* box_type = _gvn.type(argument(0))->isa_instptr(); + if (box_type == nullptr || box_type->const_oop() == nullptr) { + return false; + } + + ciInstanceKlass* float16_klass = box_type->const_oop()->as_instance()->java_lang_Class_klass()->as_instance_klass(); + const TypeInstPtr* float16_box_type = TypeInstPtr::make_exact(TypePtr::NotNull, float16_klass); + ciField* field = float16_klass->get_field_by_name(ciSymbols::value_name(), + ciSymbols::short_signature(), + false); + assert(field != nullptr, ""); + + // Transformed nodes + Node* fld1 = nullptr; + Node* fld2 = nullptr; + Node* fld3 = nullptr; + switch(num_args) { + case 3: + fld3 = unbox_fp16_value(float16_box_type, field, argument(3)); + if (fld3 == nullptr) { + return false; + } + fld3 = _gvn.transform(new ReinterpretS2HFNode(fld3)); + // fall-through + case 2: + fld2 = unbox_fp16_value(float16_box_type, field, argument(2)); + if (fld2 == nullptr) { + return false; + } + fld2 = _gvn.transform(new ReinterpretS2HFNode(fld2)); + // fall-through + case 1: + fld1 = unbox_fp16_value(float16_box_type, field, argument(1)); + if (fld1 == nullptr) { + return false; + } + fld1 = _gvn.transform(new ReinterpretS2HFNode(fld1)); + break; + 
default: fatal("Unsupported number of arguments %d", num_args); + } + + Node* result = nullptr; + switch (id) { + // Unary operations + case vmIntrinsics::_sqrt_float16: + result = _gvn.transform(new SqrtHFNode(C, control(), fld1)); + break; + // Ternary operations + case vmIntrinsics::_fma_float16: + result = _gvn.transform(new FmaHFNode(fld1, fld2, fld3)); + break; + default: + fatal_unexpected_iid(id); + break; + } + result = _gvn.transform(new ReinterpretHF2SNode(result)); + set_result(box_fp16_value(float16_box_type, field, result)); + return true; +} + diff --git a/src/hotspot/share/opto/library_call.hpp b/src/hotspot/share/opto/library_call.hpp index df9c858c5de..51dda136bc1 100644 --- a/src/hotspot/share/opto/library_call.hpp +++ b/src/hotspot/share/opto/library_call.hpp @@ -295,6 +295,9 @@ class LibraryCallKit : public GraphKit { bool inline_onspinwait(); bool inline_fp_conversions(vmIntrinsics::ID id); bool inline_fp_range_check(vmIntrinsics::ID id); + bool inline_fp16_operations(vmIntrinsics::ID id, int num_args); + Node* unbox_fp16_value(const TypeInstPtr* box_class, ciField* field, Node* box); + Node* box_fp16_value(const TypeInstPtr* box_class, ciField* field, Node* value); bool inline_number_methods(vmIntrinsics::ID id); bool inline_bitshuffle_methods(vmIntrinsics::ID id); bool inline_compare_unsigned(vmIntrinsics::ID id); diff --git a/src/hotspot/share/opto/machnode.cpp b/src/hotspot/share/opto/machnode.cpp index 5ecff618a90..2ae2c5de381 100644 --- a/src/hotspot/share/opto/machnode.cpp +++ b/src/hotspot/share/opto/machnode.cpp @@ -46,6 +46,7 @@ intptr_t MachOper::constant() const { return 0x00; } relocInfo::relocType MachOper::constant_reloc() const { return relocInfo::none; } jdouble MachOper::constantD() const { ShouldNotReachHere(); } jfloat MachOper::constantF() const { ShouldNotReachHere(); } +jshort MachOper::constantH() const { ShouldNotReachHere(); } jlong MachOper::constantL() const { ShouldNotReachHere(); } TypeOopPtr *MachOper::oop() 
const { return nullptr; } int MachOper::ccode() const { return 0x00; } diff --git a/src/hotspot/share/opto/machnode.hpp b/src/hotspot/share/opto/machnode.hpp index 4ac91175f78..ac60e7b7312 100644 --- a/src/hotspot/share/opto/machnode.hpp +++ b/src/hotspot/share/opto/machnode.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -156,6 +156,7 @@ public: virtual jdouble constantD() const; virtual jfloat constantF() const; virtual jlong constantL() const; + virtual jshort constantH() const; virtual TypeOopPtr *oop() const; virtual int ccode() const; // A zero, default, indicates this value is not needed. diff --git a/src/hotspot/share/opto/matcher.cpp b/src/hotspot/share/opto/matcher.cpp index 3f111c4b0d1..1c9afae6510 100644 --- a/src/hotspot/share/opto/matcher.cpp +++ b/src/hotspot/share/opto/matcher.cpp @@ -2304,6 +2304,7 @@ bool Matcher::find_shared_visit(MStack& mstack, Node* n, uint opcode, bool& mem_ case Op_EncodeISOArray: case Op_FmaD: case Op_FmaF: + case Op_FmaHF: case Op_FmaVD: case Op_FmaVF: case Op_MacroLogicV: @@ -2476,6 +2477,7 @@ void Matcher::find_shared_post_visit(Node* n, uint opcode) { } case Op_FmaD: case Op_FmaF: + case Op_FmaHF: case Op_FmaVD: case Op_FmaVF: { // Restructure into a binary tree for Matching. diff --git a/src/hotspot/share/opto/mulnode.cpp b/src/hotspot/share/opto/mulnode.cpp index 4720f59d5af..58439ce3773 100644 --- a/src/hotspot/share/opto/mulnode.cpp +++ b/src/hotspot/share/opto/mulnode.cpp @@ -66,7 +66,8 @@ Node *MulNode::Ideal(PhaseGVN *phase, bool can_reshape) { // only valid for the actual Mul nodes. 
uint op = Opcode(); bool real_mul = (op == Op_MulI) || (op == Op_MulL) || - (op == Op_MulF) || (op == Op_MulD); + (op == Op_MulF) || (op == Op_MulD) || + (op == Op_MulHF); // Convert "(-a)*(-b)" into "a*b". if (real_mul && in1->is_Sub() && in2->is_Sub()) { @@ -121,7 +122,8 @@ Node *MulNode::Ideal(PhaseGVN *phase, bool can_reshape) { // constant, flatten the expression tree. if( t2->singleton() && // Right input is a constant? op != Op_MulF && // Float & double cannot reassociate - op != Op_MulD ) { + op != Op_MulD && + op != Op_MulHF) { if( t2 == Type::TOP ) return nullptr; Node *mul1 = in(1); #ifdef ASSERT @@ -535,10 +537,31 @@ Node* MulFNode::Ideal(PhaseGVN* phase, bool can_reshape) { Node* base = in(1); return new AddFNode(base, base); } - return MulNode::Ideal(phase, can_reshape); } +//============================================================================= +//------------------------------Ideal------------------------------------------ +// Check to see if we are multiplying by a constant 2 and convert to add, then try the regular MulNode::Ideal +Node* MulHFNode::Ideal(PhaseGVN* phase, bool can_reshape) { + const TypeH* t2 = phase->type(in(2))->isa_half_float_constant(); + + // x * 2 -> x + x + if (t2 != nullptr && t2->getf() == 2) { + Node* base = in(1); + return new AddHFNode(base, base); + } + return MulNode::Ideal(phase, can_reshape); +} + +// Compute the product type of two half float ranges into this node. +const Type* MulHFNode::mul_ring(const Type* t0, const Type* t1) const { + if (t0 == Type::HALF_FLOAT || t1 == Type::HALF_FLOAT) { + return Type::HALF_FLOAT; + } + return TypeH::make(t0->getf() * t1->getf()); +} + //============================================================================= //------------------------------mul_ring--------------------------------------- // Compute the product type of two double ranges into this node. 
@@ -1900,6 +1923,28 @@ const Type* FmaFNode::Value(PhaseGVN* phase) const { #endif } +//============================================================================= +//------------------------------Value------------------------------------------ +const Type* FmaHFNode::Value(PhaseGVN* phase) const { + const Type* t1 = phase->type(in(1)); + if (t1 == Type::TOP) { return Type::TOP; } + if (t1->base() != Type::HalfFloatCon) { return Type::HALF_FLOAT; } + const Type* t2 = phase->type(in(2)); + if (t2 == Type::TOP) { return Type::TOP; } + if (t2->base() != Type::HalfFloatCon) { return Type::HALF_FLOAT; } + const Type* t3 = phase->type(in(3)); + if (t3 == Type::TOP) { return Type::TOP; } + if (t3->base() != Type::HalfFloatCon) { return Type::HALF_FLOAT; } +#ifndef __STDC_IEC_559__ + return Type::HALF_FLOAT; +#else + float f1 = t1->getf(); + float f2 = t2->getf(); + float f3 = t3->getf(); + return TypeH::make(fma(f1, f2, f3)); +#endif +} + //============================================================================= //------------------------------hash------------------------------------------- // Hash function for MulAddS2INode. Operation is commutative with commutative pairs. 
diff --git a/src/hotspot/share/opto/mulnode.hpp b/src/hotspot/share/opto/mulnode.hpp index a34f41d2362..bb572b9d9a2 100644 --- a/src/hotspot/share/opto/mulnode.hpp +++ b/src/hotspot/share/opto/mulnode.hpp @@ -143,6 +143,24 @@ public: virtual uint ideal_reg() const { return Op_RegF; } }; +//------------------------------MulHFNode--------------------------------------- +// Multiply 2 half floats +class MulHFNode : public MulNode { +public: + MulHFNode(Node* in1, Node* in2) : MulNode(in1, in2) {} + virtual int Opcode() const; + virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); + virtual const Type* mul_ring(const Type*, const Type*) const; + const Type* mul_id() const { return TypeH::ONE; } + const Type* add_id() const { return TypeH::ZERO; } + int add_opcode() const { return Op_AddHF; } + int mul_opcode() const { return Op_MulHF; } + int max_opcode() const { return Op_MaxHF; } + int min_opcode() const { return Op_MinHF; } + const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } +}; + //------------------------------MulDNode--------------------------------------- // Multiply 2 doubles class MulDNode : public MulNode { @@ -416,6 +434,17 @@ public: virtual const Type* Value(PhaseGVN* phase) const; }; +//------------------------------FmaHFNode------------------------------------- +// fused-multiply-add half-precision float +class FmaHFNode : public FmaNode { +public: + FmaHFNode(Node* in1, Node* in2, Node* in3) : FmaNode(in1, in2, in3) {} + virtual int Opcode() const; + const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } + virtual const Type* Value(PhaseGVN* phase) const; +}; + //------------------------------MulAddS2INode---------------------------------- // Multiply shorts into integers and add them. 
// Semantics: I_OUT = S1 * S2 + S3 * S4 diff --git a/src/hotspot/share/opto/node.cpp b/src/hotspot/share/opto/node.cpp index 583d0ba7601..0c185c682b3 100644 --- a/src/hotspot/share/opto/node.cpp +++ b/src/hotspot/share/opto/node.cpp @@ -1591,6 +1591,13 @@ jfloat Node::getf() const { return ((ConFNode*)this)->type()->is_float_constant()->getf(); } +// Get a half float constant from a ConstNode. +// Returns the raw half float bits (jshort); asserts that this node is an Op_ConH +jshort Node::geth() const { + assert( Opcode() == Op_ConH, "" ); + return ((ConHNode*)this)->type()->is_half_float_constant()->geth(); +} + #ifndef PRODUCT // Call this from debugger: diff --git a/src/hotspot/share/opto/node.hpp b/src/hotspot/share/opto/node.hpp index 101006d9afc..a2ea0b4939f 100644 --- a/src/hotspot/share/opto/node.hpp +++ b/src/hotspot/share/opto/node.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2025, Oracle and/or its affiliates. All rights reserved. * Copyright (c) 2024, Alibaba Group Holding Limited. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* @@ -58,6 +58,7 @@ class CallNode; class CallRuntimeNode; class CallStaticJavaNode; class CastFFNode; +class CastHHNode; class CastDDNode; class CastVVNode; class CastIINode; @@ -724,6 +725,7 @@ public: DEFINE_CLASS_ID(CastDD, ConstraintCast, 4) DEFINE_CLASS_ID(CastVV, ConstraintCast, 5) DEFINE_CLASS_ID(CastPP, ConstraintCast, 6) + DEFINE_CLASS_ID(CastHH, ConstraintCast, 7) DEFINE_CLASS_ID(CMove, Type, 3) DEFINE_CLASS_ID(SafePointScalarObject, Type, 4) DEFINE_CLASS_ID(DecodeNarrowPtr, Type, 5) @@ -908,6 +910,7 @@ public: DEFINE_CLASS_QUERY(CheckCastPP) DEFINE_CLASS_QUERY(CastII) DEFINE_CLASS_QUERY(CastLL) + DEFINE_CLASS_QUERY(CastFF) DEFINE_CLASS_QUERY(ConI) DEFINE_CLASS_QUERY(CastPP) DEFINE_CLASS_QUERY(ConstraintCast) @@ -1256,6 +1259,7 @@ public: intptr_t get_narrowcon() const; jdouble getd() const; jfloat getf() const; + jshort geth() const; // Nodes which are pinned into basic blocks virtual bool pinned() const { return false; } diff --git a/src/hotspot/share/opto/subnode.cpp b/src/hotspot/share/opto/subnode.cpp index 5c7ad22e221..de0d2f7c8d0 100644 --- a/src/hotspot/share/opto/subnode.cpp +++ b/src/hotspot/share/opto/subnode.cpp @@ -552,6 +552,24 @@ const Type* SubFPNode::Value(PhaseGVN* phase) const { //============================================================================= +//------------------------------sub-------------------------------------------- +// A subtract node differences its two inputs. 
+const Type* SubHFNode::sub(const Type* t1, const Type* t2) const { + // no folding if one of operands is infinity or NaN, do not do constant folding + if(g_isfinite(t1->getf()) && g_isfinite(t2->getf())) { + return TypeH::make(t1->getf() - t2->getf()); + } + else if(g_isnan(t1->getf())) { + return t1; + } + else if(g_isnan(t2->getf())) { + return t2; + } + else { + return Type::HALF_FLOAT; + } +} + //------------------------------Ideal------------------------------------------ Node *SubFNode::Ideal(PhaseGVN *phase, bool can_reshape) { const Type *t2 = phase->type( in(2) ); @@ -1989,6 +2007,15 @@ const Type* SqrtFNode::Value(PhaseGVN* phase) const { return TypeF::make( (float)sqrt( (double)f ) ); } +const Type* SqrtHFNode::Value(PhaseGVN* phase) const { + const Type* t1 = phase->type(in(1)); + if (t1 == Type::TOP) { return Type::TOP; } + if (t1->base() != Type::HalfFloatCon) { return Type::HALF_FLOAT; } + float f = t1->getf(); + if (f < 0.0f) return Type::HALF_FLOAT; + return TypeH::make((float)sqrt((double)f)); +} + const Type* ReverseINode::Value(PhaseGVN* phase) const { const Type *t1 = phase->type( in(1) ); if (t1 == Type::TOP) { diff --git a/src/hotspot/share/opto/subnode.hpp b/src/hotspot/share/opto/subnode.hpp index ca21e628676..1d8aa058424 100644 --- a/src/hotspot/share/opto/subnode.hpp +++ b/src/hotspot/share/opto/subnode.hpp @@ -130,6 +130,18 @@ public: virtual uint ideal_reg() const { return Op_RegD; } }; +//------------------------------SubHFNode-------------------------------------- +// Subtract 2 half floats +class SubHFNode : public SubFPNode { +public: + SubHFNode(Node* in1, Node* in2) : SubFPNode(in1, in2) {} + virtual int Opcode() const; + virtual const Type* sub(const Type*, const Type*) const; + const Type* add_id() const { return TypeH::ZERO; } + const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } +}; + //------------------------------CmpNode--------------------------------------- // 
Compare 2 values, returning condition codes (-1, 0 or 1). class CmpNode : public SubNode { @@ -528,6 +540,20 @@ public: virtual const Type* Value(PhaseGVN* phase) const; }; +//------------------------------SqrtHFNode------------------------------------- +// square root of a half-precision float +class SqrtHFNode : public Node { +public: + SqrtHFNode(Compile* C, Node* c, Node* in1) : Node(c, in1) { + init_flags(Flag_is_expensive); + C->add_expensive_node(this); + } + virtual int Opcode() const; + const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } + virtual const Type* Value(PhaseGVN* phase) const; +}; + //-------------------------------ReverseBytesINode-------------------------------- // reverse bytes of an integer class ReverseBytesINode : public Node { diff --git a/src/hotspot/share/opto/superword.cpp b/src/hotspot/share/opto/superword.cpp index ae95c2bb6d8..31aa1507d23 100644 --- a/src/hotspot/share/opto/superword.cpp +++ b/src/hotspot/share/opto/superword.cpp @@ -2607,7 +2607,7 @@ const Type* VLoopTypes::container_type(Node* n) const { // Float to half float conversion may be succeeded by a conversion from // half float to float, in such a case back propagation of narrow type (SHORT) // may not be possible. 
- if (n->Opcode() == Op_ConvF2HF) { + if (n->Opcode() == Op_ConvF2HF || n->Opcode() == Op_ReinterpretHF2S) { return TypeInt::SHORT; } // A narrow type of arithmetic operations will be determined by diff --git a/src/hotspot/share/opto/type.cpp b/src/hotspot/share/opto/type.cpp index 364654ec42f..db6070428ce 100644 --- a/src/hotspot/share/opto/type.cpp +++ b/src/hotspot/share/opto/type.cpp @@ -26,6 +26,7 @@ #include "ci/ciTypeFlow.hpp" #include "classfile/javaClasses.hpp" #include "classfile/symbolTable.hpp" +#include "classfile/vmSymbols.hpp" #include "compiler/compileLog.hpp" #include "libadt/dict.hpp" #include "memory/oopFactory.hpp" @@ -44,6 +45,7 @@ #include "utilities/checkedCast.hpp" #include "utilities/powerOfTwo.hpp" #include "utilities/stringUtils.hpp" +#include "runtime/stubRoutines.hpp" // Portions of code courtesy of Clifford Click @@ -104,6 +106,9 @@ const Type::TypeInfo Type::_type_info[Type::lastype] = { { Abio, T_ILLEGAL, "abIO", false, 0, relocInfo::none }, // Abio { Return_Address, T_ADDRESS, "return_address",false, Op_RegP, relocInfo::none }, // Return_Address { Memory, T_ILLEGAL, "memory", false, 0, relocInfo::none }, // Memory + { HalfFloatBot, T_SHORT, "halffloat_top", false, Op_RegF, relocInfo::none }, // HalfFloatTop + { HalfFloatCon, T_SHORT, "hfcon:", false, Op_RegF, relocInfo::none }, // HalfFloatCon + { HalfFloatTop, T_SHORT, "short", false, Op_RegF, relocInfo::none }, // HalfFloatBot { FloatBot, T_FLOAT, "float_top", false, Op_RegF, relocInfo::none }, // FloatTop { FloatCon, T_FLOAT, "ftcon:", false, Op_RegF, relocInfo::none }, // FloatCon { FloatTop, T_FLOAT, "float", false, Op_RegF, relocInfo::none }, // FloatBot @@ -133,6 +138,7 @@ const Type *Type::ABIO; // State-of-machine only const Type *Type::BOTTOM; // All values const Type *Type::CONTROL; // Control only const Type *Type::DOUBLE; // All doubles +const Type *Type::HALF_FLOAT; // All half floats const Type *Type::FLOAT; // All floats const Type *Type::HALF; // Placeholder half of 
doublewide type const Type *Type::MEMORY; // Abstract store only @@ -453,6 +459,7 @@ void Type::Initialize_shared(Compile* current) { ABIO = make(Abio); // State-of-machine only RETURN_ADDRESS=make(Return_Address); FLOAT = make(FloatBot); // All floats + HALF_FLOAT = make(HalfFloatBot); // All half floats DOUBLE = make(DoubleBot); // All doubles BOTTOM = make(Bottom); // Everything HALF = make(Half); // Placeholder half of doublewide type @@ -464,6 +471,13 @@ void Type::Initialize_shared(Compile* current) { TypeF::POS_INF = TypeF::make(jfloat_cast(POSITIVE_INFINITE_F)); TypeF::NEG_INF = TypeF::make(-jfloat_cast(POSITIVE_INFINITE_F)); + TypeH::MAX = TypeH::make(max_jfloat16); // HalfFloat MAX + TypeH::MIN = TypeH::make(min_jfloat16); // HalfFloat MIN + TypeH::ZERO = TypeH::make((jshort)0); // HalfFloat 0 (positive zero) + TypeH::ONE = TypeH::make(one_jfloat16); // HalfFloat 1 + TypeH::POS_INF = TypeH::make(pos_inf_jfloat16); + TypeH::NEG_INF = TypeH::make(neg_inf_jfloat16); + TypeD::MAX = TypeD::make(max_jdouble); // Double MAX TypeD::MIN = TypeD::make(min_jdouble); // Double MIN TypeD::ZERO = TypeD::make(0.0); // Double 0 (positive zero) @@ -1039,6 +1053,7 @@ const Type *Type::xmeet( const Type *t ) const { // Cut in half the number of cases I must handle. Only need cases for when // the given enum "t->type" is less than or equal to the local enum "type". 
+ case HalfFloatCon: case FloatCon: case DoubleCon: case Int: @@ -1074,19 +1089,30 @@ const Type *Type::xmeet( const Type *t ) const { case Bottom: // Ye Olde Default return t; + case HalfFloatTop: + if (_base == HalfFloatTop) { return this; } + case HalfFloatBot: // Half Float + if (_base == HalfFloatBot || _base == HalfFloatTop) { return HALF_FLOAT; } + if (_base == FloatBot || _base == FloatTop) { return Type::BOTTOM; } + if (_base == DoubleTop || _base == DoubleBot) { return Type::BOTTOM; } + typerr(t); + return Type::BOTTOM; + case FloatTop: - if( _base == FloatTop ) return this; + if (_base == FloatTop ) { return this; } case FloatBot: // Float - if( _base == FloatBot || _base == FloatTop ) return FLOAT; - if( _base == DoubleTop || _base == DoubleBot ) return Type::BOTTOM; + if (_base == FloatBot || _base == FloatTop) { return FLOAT; } + if (_base == HalfFloatTop || _base == HalfFloatBot) { return Type::BOTTOM; } + if (_base == DoubleTop || _base == DoubleBot) { return Type::BOTTOM; } typerr(t); return Type::BOTTOM; case DoubleTop: - if( _base == DoubleTop ) return this; + if (_base == DoubleTop) { return this; } case DoubleBot: // Double - if( _base == DoubleBot || _base == DoubleTop ) return DOUBLE; - if( _base == FloatTop || _base == FloatBot ) return Type::BOTTOM; + if (_base == DoubleBot || _base == DoubleTop) { return DOUBLE; } + if (_base == HalfFloatTop || _base == HalfFloatBot) { return Type::BOTTOM; } + if (_base == FloatTop || _base == FloatBot) { return Type::BOTTOM; } typerr(t); return Type::BOTTOM; @@ -1094,7 +1120,7 @@ const Type *Type::xmeet( const Type *t ) const { case Control: // Control of code case Abio: // State of world outside of program case Memory: - if( _base == t->_base ) return this; + if (_base == t->_base) { return this; } typerr(t); return Type::BOTTOM; @@ -1174,6 +1200,7 @@ bool Type::empty(void) const { switch (_base) { case DoubleTop: case FloatTop: + case HalfFloatTop: case Top: return true; @@ -1182,6 +1209,7 @@ bool 
Type::empty(void) const { case Return_Address: case Memory: case Bottom: + case HalfFloatBot: case FloatBot: case DoubleBot: return false; // never a singleton, therefore never empty @@ -1229,6 +1257,9 @@ Type::Category Type::category() const { case Type::AryKlassPtr: case Type::Function: case Type::Return_Address: + case Type::HalfFloatTop: + case Type::HalfFloatCon: + case Type::HalfFloatBot: case Type::FloatTop: case Type::FloatCon: case Type::FloatBot: @@ -1334,6 +1365,9 @@ const Type *TypeF::xmeet( const Type *t ) const { case NarrowKlass: case Int: case Long: + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case DoubleTop: case DoubleCon: case DoubleBot: @@ -1412,6 +1446,138 @@ bool TypeF::empty(void) const { return false; // always exactly a singleton } +//============================================================================= +// Convenience common pre-built types. +const TypeH* TypeH::MAX; // Half float max +const TypeH* TypeH::MIN; // Half float min +const TypeH* TypeH::ZERO; // Half float zero +const TypeH* TypeH::ONE; // Half float one +const TypeH* TypeH::POS_INF; // Half float positive infinity +const TypeH* TypeH::NEG_INF; // Half float negative infinity + +//------------------------------make------------------------------------------- +// Create a halffloat constant +const TypeH* TypeH::make(short f) { + return (TypeH*)(new TypeH(f))->hashcons(); +} + +const TypeH* TypeH::make(float f) { + assert(StubRoutines::f2hf_adr() != nullptr, ""); + short hf = StubRoutines::f2hf(f); + return (TypeH*)(new TypeH(hf))->hashcons(); +} + +//------------------------------xmeet------------------------------------------- +// Compute the MEET of two types. It returns a new Type object. +const Type* TypeH::xmeet(const Type* t) const { + // Perform a fast test for common case; meeting the same types together. + if (this == t) return this; // Meeting same type-rep? 
+ + // Current "this->_base" is HalfFloatCon + switch (t->base()) { // Switch on original type + case AnyPtr: // Mixing with oops happens when javac + case RawPtr: // reuses local variables + case OopPtr: + case InstPtr: + case AryPtr: + case MetadataPtr: + case KlassPtr: + case InstKlassPtr: + case AryKlassPtr: + case NarrowOop: + case NarrowKlass: + case Int: + case Long: + case FloatTop: + case FloatCon: + case FloatBot: + case DoubleTop: + case DoubleCon: + case DoubleBot: + case Bottom: // Ye Olde Default + return Type::BOTTOM; + + case HalfFloatBot: + return t; + + default: // All else is a mistake + typerr(t); + + case HalfFloatCon: // Half float-constant vs Half float-constant? + if (_f != t->geth()) { // unequal constants? + // _f holds the raw 16-bit encoding, so this comparison is bitwise: + // +0.0, -0.0 and differing NaN encodings are distinct constants + return HALF_FLOAT; // Return generic half float + } // Equal constants + case Top: + case HalfFloatTop: + break; // Return the Half float constant + } + return this; // Return the Half float constant +} + +//------------------------------xdual------------------------------------------ +// Dual: symmetric (a half float constant is its own dual) +const Type* TypeH::xdual() const { + return this; +} + +//------------------------------eq--------------------------------------------- +// Structural equality check for Type representations +bool TypeH::eq(const Type* t) const { + // Bitwise comparison to distinguish between +/-0. These values must be treated + // as different to be consistent with C1 and the interpreter. + return (_f == t->geth()); +} + +//------------------------------hash------------------------------------------- +// Type-specific hashing function. 
+uint TypeH::hash(void) const { + return *(jshort*)(&_f); +} + +//------------------------------is_finite-------------------------------------- +// Has a finite value (tested on the value widened to float via hf2f) +bool TypeH::is_finite() const { + assert(StubRoutines::hf2f_adr() != nullptr, ""); + float f = StubRoutines::hf2f(geth()); + return g_isfinite(f) != 0; +} + +float TypeH::getf() const { + assert(StubRoutines::hf2f_adr() != nullptr, ""); + return StubRoutines::hf2f(geth()); +} + +//------------------------------is_nan----------------------------------------- +// Is not a number (NaN) +bool TypeH::is_nan() const { + assert(StubRoutines::hf2f_adr() != nullptr, ""); + float f = StubRoutines::hf2f(geth()); + return g_isnan(f) != 0; +} + +//------------------------------dump2------------------------------------------ +// Dump half float constant Type (value is printed as the float from getf()) +#ifndef PRODUCT +void TypeH::dump2(Dict &d, uint depth, outputStream* st) const { + Type::dump2(d,depth, st); + st->print("%f", getf()); +} +#endif + +//------------------------------singleton-------------------------------------- +// TRUE if Type is a singleton type, FALSE otherwise. Singletons are simple +// constants (Ldi nodes). Singletons are integer, half float, float or double constants +// or a single symbol. +bool TypeH::singleton(void) const { + return true; // Always a singleton +} + +bool TypeH::empty(void) const { + return false; // always exactly a singleton +} + +//============================================================================= // Convenience common pre-built types. 
const TypeD *TypeD::MAX; // Floating point max @@ -1447,6 +1613,9 @@ const Type *TypeD::xmeet( const Type *t ) const { case NarrowKlass: case Int: case Long: + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -1643,6 +1812,9 @@ const Type *TypeInt::xmeet( const Type *t ) const { case NarrowOop: case NarrowKlass: case Long: + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -1906,6 +2078,9 @@ const Type *TypeLong::xmeet( const Type *t ) const { case NarrowOop: case NarrowKlass: case Int: + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -2700,6 +2875,9 @@ const Type *TypePtr::xmeet_helper(const Type *t) const { switch (t->base()) { // switch on original type case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -3639,6 +3817,9 @@ const Type *TypeOopPtr::xmeet_helper(const Type *t) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -4207,6 +4388,9 @@ const Type *TypeInstPtr::xmeet_helper(const Type *t) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -4884,6 +5068,9 @@ const Type *TypeAryPtr::xmeet_helper(const Type *t) const { // Mixing ints & oops happens when javac reuses local variables case Int: case Long: + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -5313,6 +5500,9 @@ const Type *TypeNarrowPtr::xmeet( const Type *t ) const { case Int: // Mixing ints & oops happens when 
javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -5468,6 +5658,9 @@ const Type *TypeMetadataPtr::xmeet( const Type *t ) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -5842,6 +6035,9 @@ const Type *TypeInstKlassPtr::xmeet( const Type *t ) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -6266,6 +6462,9 @@ const Type *TypeAryKlassPtr::xmeet( const Type *t ) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: diff --git a/src/hotspot/share/opto/type.hpp b/src/hotspot/share/opto/type.hpp index 4b540345d27..67e339d3d2e 100644 --- a/src/hotspot/share/opto/type.hpp +++ b/src/hotspot/share/opto/type.hpp @@ -45,6 +45,7 @@ class Dict; class Type; class TypeD; class TypeF; +class TypeH; class TypeInteger; class TypeInt; class TypeLong; @@ -120,6 +121,9 @@ public: Abio, // Abstract I/O Return_Address, // Subroutine return address Memory, // Abstract store + HalfFloatTop, // No half float value + HalfFloatCon, // Half float constant + HalfFloatBot, // Any half float value FloatTop, // No float value FloatCon, // Floating point constant FloatBot, // Any float value @@ -277,7 +281,8 @@ public: bool is_ptr_to_narrowklass() const; // Convenience access - float getf() const; + short geth() const; + virtual float getf() const; double getd() const; const TypeInt *is_int() const; @@ -289,6 +294,9 @@ public: const TypeD *isa_double() const; // Returns null if not a Double{Top,Con,Bot} const TypeD *is_double_constant() const; // 
Asserts it is a DoubleCon const TypeD *isa_double_constant() const; // Returns null if not a DoubleCon + const TypeH *isa_half_float() const; // Returns null if not a HalfFloat{Top,Con,Bot} + const TypeH *is_half_float_constant() const; // Asserts it is a HalfFloatCon + const TypeH *isa_half_float_constant() const; // Returns null if not a HalfFloatCon const TypeF *isa_float() const; // Returns null if not a Float{Top,Con,Bot} const TypeF *is_float_constant() const; // Asserts it is a FloatCon const TypeF *isa_float_constant() const; // Returns null if not a FloatCon @@ -431,6 +439,7 @@ public: static const Type *CONTROL; static const Type *DOUBLE; static const Type *FLOAT; + static const Type *HALF_FLOAT; static const Type *HALF; static const Type *MEMORY; static const Type *MULTI; @@ -521,6 +530,38 @@ public: #endif }; +// Class of Half Float-Constant Types. +class TypeH : public Type { + TypeH(short f) : Type(HalfFloatCon), _f(f) {}; +public: + virtual bool eq(const Type* t) const; + virtual uint hash() const; // Type specific hashing + virtual bool singleton(void) const; // TRUE if type is a singleton + virtual bool empty(void) const; // TRUE if type is vacuous +public: + const short _f; // Half float constant, kept as its raw 16-bit encoding + + static const TypeH* make(float f); // Narrows f to half float bits via StubRoutines::f2hf + static const TypeH* make(short f); // f is already the raw half float bits + + virtual bool is_finite() const; // Has a finite value + virtual bool is_nan() const; // Is not a number (NaN) + + virtual float getf() const; // Value widened to float via StubRoutines::hf2f + virtual const Type* xmeet(const Type* t) const; + virtual const Type* xdual() const; // Compute dual right now. + // Convenience common pre-built types. 
+ static const TypeH* MAX; + static const TypeH* MIN; + static const TypeH* ZERO; // positive zero only + static const TypeH* ONE; + static const TypeH* POS_INF; + static const TypeH* NEG_INF; +#ifndef PRODUCT + virtual void dump2(Dict &d, uint depth, outputStream* st) const; +#endif +}; + //------------------------------TypeD------------------------------------------ // Class of Double-Constant Types. class TypeD : public Type { @@ -1943,6 +1984,11 @@ inline float Type::getf() const { return ((TypeF*)this)->_f; } +inline short Type::geth() const { + assert(_base == HalfFloatCon, "Not a HalfFloatCon"); + return ((TypeH*)this)->_f; +} + inline double Type::getd() const { assert( _base == DoubleCon, "Not a DoubleCon" ); return ((TypeD*)this)->_d; @@ -1975,6 +2021,21 @@ inline const TypeLong *Type::isa_long() const { return ( _base == Long ? (TypeLong*)this : nullptr); } +inline const TypeH* Type::isa_half_float() const { + return ((_base == HalfFloatTop || + _base == HalfFloatCon || + _base == HalfFloatBot) ? (TypeH*)this : nullptr); +} + +inline const TypeH* Type::is_half_float_constant() const { + assert( _base == HalfFloatCon, "Not a HalfFloat" ); + return (TypeH*)this; +} + +inline const TypeH* Type::isa_half_float_constant() const { + return (_base == HalfFloatCon ? 
(TypeH*)this : nullptr); +} + inline const TypeF *Type::isa_float() const { return ((_base == FloatTop || _base == FloatCon || @@ -2164,7 +2225,8 @@ inline const TypeNarrowKlass* Type::make_narrowklass() const { } inline bool Type::is_floatingpoint() const { - if( (_base == FloatCon) || (_base == FloatBot) || + if( (_base == HalfFloatCon) || (_base == HalfFloatBot) || + (_base == FloatCon) || (_base == FloatBot) || (_base == DoubleCon) || (_base == DoubleBot) ) return true; return false; diff --git a/src/hotspot/share/utilities/globalDefinitions.hpp b/src/hotspot/share/utilities/globalDefinitions.hpp index 0f8dee4e739..4dc8ebfef4a 100644 --- a/src/hotspot/share/utilities/globalDefinitions.hpp +++ b/src/hotspot/share/utilities/globalDefinitions.hpp @@ -546,6 +546,11 @@ const jfloat min_jfloat = jfloat_cast(min_jintFloat); const jint max_jintFloat = (jint)(0x7f7fffff); const jfloat max_jfloat = jfloat_cast(max_jintFloat); +const jshort max_jfloat16 = 31743; +const jshort min_jfloat16 = 1; +const jshort one_jfloat16 = 15360; +const jshort pos_inf_jfloat16 = 31744; +const jshort neg_inf_jfloat16 = -1024; // A named constant for the integral representation of a Java null. const intptr_t NULL_WORD = 0; @@ -904,6 +909,7 @@ class JavaValue { void set_jfloat(jfloat f) { _value.f = f;} void set_jdouble(jdouble d) { _value.d = d;} void set_jint(jint i) { _value.i = i;} + void set_jshort(jshort i) { _value.i = i;} void set_jlong(jlong l) { _value.l = l;} void set_jobject(jobject h) { _value.h = h;} void set_oop(oopDesc* o) { _value.o = o;} diff --git a/src/java.base/share/classes/jdk/internal/vm/vector/Float16Math.java b/src/java.base/share/classes/jdk/internal/vm/vector/Float16Math.java new file mode 100644 index 00000000000..fc385975c18 --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/vm/vector/Float16Math.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. 
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ */
+package jdk.internal.vm.vector;
+
+import jdk.internal.vm.annotation.IntrinsicCandidate;
+import java.util.function.UnaryOperator;
+
+public class Float16Math {
+
+    @FunctionalInterface
+    public interface TernaryOperator<T> {
+        T apply(T a, T b, T c);
+    }
+
+    @IntrinsicCandidate
+    public static <T> T sqrt(Class<?> box_class, T oa, UnaryOperator<T> defaultImpl) {
+        assert isNonCapturingLambda(defaultImpl) : defaultImpl;
+        return defaultImpl.apply(oa);
+    }
+
+    @IntrinsicCandidate
+    public static <T> T fma(Class<?> box_class, T oa, T ob, T oc, TernaryOperator<T> defaultImpl) {
+        assert isNonCapturingLambda(defaultImpl) : defaultImpl;
+        return defaultImpl.apply(oa, ob, oc);
+    }
+
+    public static boolean isNonCapturingLambda(Object o) {
+        return o.getClass().getDeclaredFields().length == 0;
+    }
+}
diff --git a/src/jdk.incubator.vector/share/classes/jdk/incubator/vector/Float16.java b/src/jdk.incubator.vector/share/classes/jdk/incubator/vector/Float16.java
index f5f5a5a4e7e..f918878324d 100644
--- a/src/jdk.incubator.vector/share/classes/jdk/incubator/vector/Float16.java
+++ b/src/jdk.incubator.vector/share/classes/jdk/incubator/vector/Float16.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
* * This code is free software; you can redistribute it and/or modify it @@ -39,6 +39,8 @@ import static java.lang.Float.float16ToFloat; import static java.lang.Float.floatToFloat16; import static java.lang.Integer.numberOfLeadingZeros; import static java.lang.Math.multiplyHigh; +import jdk.internal.vm.annotation.ForceInline; +import jdk.internal.vm.vector.Float16Math; /** * The {@code Float16} is a class holding 16-bit data @@ -321,6 +323,7 @@ public final class Float16 * * @param f a {@code float} */ + @ForceInline public static Float16 valueOf(float f) { return new Float16(floatToFloat16(f)); } @@ -764,6 +767,7 @@ public final class Float16 * @jls 5.1.3 Narrowing Primitive Conversion */ @Override + @ForceInline public byte byteValue() { return (byte)floatValue(); } @@ -785,6 +789,7 @@ public final class Float16 * @jls 5.1.3 Narrowing Primitive Conversion */ @Override + @ForceInline public short shortValue() { return (short)floatValue(); } @@ -800,6 +805,7 @@ public final class Float16 * @jls 5.1.3 Narrowing Primitive Conversion */ @Override + @ForceInline public int intValue() { return (int)floatValue(); } @@ -830,6 +836,7 @@ public final class Float16 * @jls 5.1.2 Widening Primitive Conversion */ @Override + @ForceInline public float floatValue() { return float16ToFloat(value); } @@ -845,6 +852,7 @@ public final class Float16 * @jls 5.1.2 Widening Primitive Conversion */ @Override + @ForceInline public double doubleValue() { return (double)floatValue(); } @@ -1191,12 +1199,16 @@ public final class Float16 * @see Math#sqrt(double) */ public static Float16 sqrt(Float16 radicand) { - // Rounding path of sqrt(Float16 -> double) -> Float16 is fine - // for preserving the correct final value. The conversion - // Float16 -> double preserves the exact numerical value. The - // conversion of double -> Float16 also benefits from the - // 2p+2 property of IEEE 754 arithmetic. 
- return valueOf(Math.sqrt(radicand.doubleValue())); + return Float16Math.sqrt(Float16.class, radicand, + (_radicand) -> { + // Rounding path of sqrt(Float16 -> double) -> Float16 is fine + // for preserving the correct final value. The conversion + // Float16 -> double preserves the exact numerical value. The + // conversion of double -> Float16 also benefits from the + // 2p+2 property of IEEE 754 arithmetic. + return valueOf(Math.sqrt(_radicand.doubleValue())); + } + ); } /** @@ -1398,11 +1410,14 @@ public final class Float16 * harmless. */ - // product is numerically exact in float before the cast to - // double; not necessary to widen to double before the - // multiply. - double product = (double)(a.floatValue() * b.floatValue()); - return valueOf(product + c.doubleValue()); + return Float16Math.fma(Float16.class, a, b, c, + (_a, _b, _c) -> { + // product is numerically exact in float before the cast to + // double; not necessary to widen to double before the + // multiply. + double product = (double)(_a.floatValue() * _b.floatValue()); + return valueOf(product + _c.doubleValue()); + }); } /** diff --git a/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java b/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java index 622102a3a67..d25f7e71990 100644 --- a/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java +++ b/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2009, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2009, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
*
  * This code is free software; you can redistribute it and/or modify it
@@ -257,6 +257,7 @@ public class AMD64 extends Architecture {
         AVX_IFMA,
         APX_F,
         SHA512,
+        AVX512_FP16,
     }
 
     private final EnumSet<CPUFeature> features;
diff --git a/test/hotspot/jtreg/compiler/c2/irTests/ConvF2HFIdealizationTests.java b/test/hotspot/jtreg/compiler/c2/irTests/ConvF2HFIdealizationTests.java
new file mode 100644
index 00000000000..1dbfcd44eb4
--- /dev/null
+++ b/test/hotspot/jtreg/compiler/c2/irTests/ConvF2HFIdealizationTests.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2025, Arm Limited. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+package compiler.c2.irTests;
+
+import compiler.lib.ir_framework.*;
+import jdk.incubator.vector.Float16;
+import static jdk.incubator.vector.Float16.*;
+import jdk.test.lib.Asserts;
+
+/*
+ * @test
+ * @bug 8338061
+ * @summary Test that Ideal transformations of ConvF2HF are being performed as expected.
+ * @modules jdk.incubator.vector + * @library /test/lib / + * @run driver compiler.c2.irTests.ConvF2HFIdealizationTests + */ +public class ConvF2HFIdealizationTests { + private short[] sin; + private short[] sout; + private static final int SIZE = 65504; + public ConvF2HFIdealizationTests() { + sin = new short[SIZE]; + sout = new short[SIZE]; + for (int i = 0; i < SIZE; i++) { + sin[i] = Float.floatToFloat16((float)i); + } + } + public static void main(String[] args) { + TestFramework.runWithFlags("--add-modules=jdk.incubator.vector", "-XX:-UseSuperWord"); + } + + @Test + @IR(counts = {IRNode.REINTERPRET_S2HF, ">=1", IRNode.REINTERPRET_HF2S, ">=1", IRNode.ADD_HF, ">=1" }, + failOn = {IRNode.ADD_F, IRNode.CONV_HF2F, IRNode.CONV_F2HF}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + // Test pattern - ConvHF2F -> AddF -> ConvF2HF is optimized to ReinterpretS2HF -> AddHF -> ReinterpretHF2S + public void test1() { + for (int i = 0; i < SIZE; i++) { + sout[i] = Float.floatToFloat16(Float.float16ToFloat(sin[i]) + Float.float16ToFloat(sin[i])); + } + } + + @Check(test="test1") + public void checkResult() { + for (int i = 0; i < SIZE; i++) { + short expected = Float16.float16ToRawShortBits(Float16.add(Float16.shortBitsToFloat16(sin[i]), Float16.shortBitsToFloat16(sin[i]))); + if (expected != sout[i]) { + throw new RuntimeException("Invalid result: sout[" + i + "] = " + sout[i] + " != " + expected); + } + } + } +} diff --git a/test/hotspot/jtreg/compiler/c2/irTests/MulHFNodeIdealizationTests.java b/test/hotspot/jtreg/compiler/c2/irTests/MulHFNodeIdealizationTests.java new file mode 100644 index 00000000000..dd98c80d629 --- /dev/null +++ b/test/hotspot/jtreg/compiler/c2/irTests/MulHFNodeIdealizationTests.java @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2025, Arm Limited. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
+ * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package compiler.c2.irTests; + +import compiler.lib.ir_framework.*; +import jdk.incubator.vector.Float16; +import static jdk.incubator.vector.Float16.*; +import java.util.Random; +import jdk.test.lib.Asserts; + +/* + * @test + * @bug 8336406 + * @summary Test that Ideal transformations of MulHFNode are being performed as expected. 
+ * @modules jdk.incubator.vector + * @library /test/lib / + * @run driver compiler.c2.irTests.MulHFNodeIdealizationTests + */ +public class MulHFNodeIdealizationTests { + + private Float16 src; + private Float16 dst; + private Random rng; + + public static void main(String[] args) { + TestFramework.runWithFlags("--add-modules=jdk.incubator.vector"); + } + + public MulHFNodeIdealizationTests() { + rng = new Random(25); + src = valueOf(rng.nextFloat()); + dst = valueOf(rng.nextFloat()); + } + + @Test + @IR(counts = {IRNode.ADD_HF, "1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}, + failOn = {IRNode.MUL_HF}) + public void test1() { + dst = multiply(src, valueOf(2.0f)); + } + + @Check(test="test1") + public void checkTest1() { + Float16 expected = valueOf(src.floatValue() * 2.0f); + if (float16ToRawShortBits(expected) != float16ToRawShortBits(dst)) { + throw new RuntimeException("Invalid result: dst = " + float16ToRawShortBits(dst) + " != " + float16ToRawShortBits(expected)); + } + } +} diff --git a/test/hotspot/jtreg/compiler/c2/irTests/TestFloat16ScalarOperations.java b/test/hotspot/jtreg/compiler/c2/irTests/TestFloat16ScalarOperations.java new file mode 100644 index 00000000000..17a3e4b4c56 --- /dev/null +++ b/test/hotspot/jtreg/compiler/c2/irTests/TestFloat16ScalarOperations.java @@ -0,0 +1,585 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2025, Arm Limited. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/** +* @test +* @bug 8308363 8336406 +* @summary Validate compiler IR for various Float16 scalar operations. +* @modules jdk.incubator.vector +* @requires vm.compiler2.enabled +* @library /test/lib / +* @run driver TestFloat16ScalarOperations +*/ +import compiler.lib.ir_framework.*; +import jdk.incubator.vector.Float16; +import static jdk.incubator.vector.Float16.*; +import java.util.Random; + +public class TestFloat16ScalarOperations { + private static final int count = 1024; + + private short[] src; + private short[] dst; + private short res; + + private static final Float16 ONE = valueOf(1.0f); + private static final Float16 MONE = valueOf(-1.0f); + private static final Float16 POSITIVE_ZERO = valueOf(0.0f); + private static final Float16 NEGATIVE_ZERO = valueOf(-0.0f); + private static final Float16 MIN_NORMAL = valueOf(0x1.0P-14f); + private static final Float16 NEGATIVE_MAX_VALUE = valueOf(-0x1.ffcP+15f); + private static final Float16 LT_MAX_HALF_ULP = Float16.valueOf(14.0f); + private static final Float16 MAX_HALF_ULP = Float16.valueOf(16.0f); + private static final Float16 SIGNALING_NAN = shortBitsToFloat16((short)31807); + + private static Random r = jdk.test.lib.Utils.getRandomInstance(); + + private static final Float16 RANDOM1 = Float16.valueOf(r.nextFloat() * MAX_VALUE.floatValue()); + private static final Float16 RANDOM2 = Float16.valueOf(r.nextFloat() * MAX_VALUE.floatValue()); + private static final Float16 RANDOM3 = 
Float16.valueOf(r.nextFloat() * MAX_VALUE.floatValue());
+    private static final Float16 RANDOM4 = Float16.valueOf(r.nextFloat() * MAX_VALUE.floatValue());
+    private static final Float16 RANDOM5 = Float16.valueOf(r.nextFloat() * MAX_VALUE.floatValue());
+
+    private static Float16 RANDOM1_VAR = RANDOM1;
+    private static Float16 RANDOM2_VAR = RANDOM2;
+    private static Float16 RANDOM3_VAR = RANDOM3;
+    private static Float16 RANDOM4_VAR = RANDOM4;
+    private static Float16 RANDOM5_VAR = RANDOM5;
+
+    public static void main(String args[]) {
+        Scenario s0 = new Scenario(0, "--add-modules=jdk.incubator.vector", "-Xint");
+        Scenario s1 = new Scenario(1, "--add-modules=jdk.incubator.vector");
+        new TestFramework().addScenarios(s0, s1).start();
+    }
+
+    public TestFloat16ScalarOperations() {
+        src = new short[count];
+        dst = new short[count];
+        for (int i = 0; i < count; i++) {
+            src[i] = Float.floatToFloat16(r.nextFloat() * MAX_VALUE.floatValue());
+        }
+    }
+
+    static void assertResult(float actual, float expected, String msg) {
+        if (actual != expected) {
+            if (!Float.isNaN(actual) || !Float.isNaN(expected)) {
+                String error = "TEST : " + msg + ": actual(" + actual + ") != expected(" + expected + ")";
+                throw new AssertionError(error);
+            }
+        }
+    }
+
+    static void assertResult(float actual, float expected, String msg, int iter) {
+        if (actual != expected) {
+            if (!Float.isNaN(actual) || !Float.isNaN(expected)) {
+                String error = "TEST (" + iter + "): " + msg + ": actual(" + actual + ") != expected(" + expected + ")";
+                throw new AssertionError(error);
+            }
+        }
+    }
+
+    @Test
+    @IR(counts = {"convHF2SAndHF2F", " >0 "}, phase = {CompilePhase.FINAL_CODE},
+        applyIfCPUFeature = {"avx512_fp16", "true"})
+    public void testEliminateIntermediateHF2S() {
+        Float16 res = shortBitsToFloat16((short)0);
+        for (int i = 0; i < count; i++) {
+            // Intermediate HF2S + S2HF is eliminated in following transformation
+            // AddHF S2HF(HF2S (AddHF S2HF(src[i]), S2HF(0))), S2HF(src[i]) => AddHF (AddHF
S2HF(src[i]), S2HF(0)), S2HF(src[i]) + res = add(add(res, shortBitsToFloat16(src[i])), shortBitsToFloat16(src[i])); + dst[i] = (short)res.floatValue(); + } + } + + @Test + @IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testAdd1() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.add(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(failOn = {IRNode.ADD_HF, IRNode.REINTERPRET_S2HF, IRNode.REINTERPRET_HF2S}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testAdd2() { + Float16 hf0 = shortBitsToFloat16((short)0); + Float16 hf1 = shortBitsToFloat16((short)15360); + Float16 hf2 = shortBitsToFloat16((short)16384); + Float16 hf3 = shortBitsToFloat16((short)16896); + Float16 hf4 = shortBitsToFloat16((short)17408); + res = float16ToRawShortBits(Float16.add(Float16.add(Float16.add(Float16.add(hf0, hf1), hf2), hf3), hf4)); + } + + @Test + @IR(counts = {IRNode.SUB_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testSub() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.subtract(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.MUL_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMul() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.multiply(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.DIV_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testDiv() { 
+ Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.divide(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.DIV_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testDivByOne() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.divide(shortBitsToFloat16(src[i]), ONE); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.MAX_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMax() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.max(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.MIN_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMin() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.min(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.SQRT_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testSqrt() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.sqrt(shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.FMA_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testFma() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + Float16 in = shortBitsToFloat16(src[i]); + res = 
Float16.fma(in, in, in);
+            dst[i] = float16ToRawShortBits(res);
+        }
+    }
+
+    @Test
+    @IR(counts = {IRNode.MUL_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "},
+        applyIfCPUFeature = {"avx512_fp16", "true"})
+    public void testDivByPOT() {
+        Float16 res = valueOf(0.0f);
+        for (int i = 0; i < 50; i++) {
+            Float16 divisor = valueOf(8.0f);
+            Float16 dividend = shortBitsToFloat16(src[i]);
+            res = add(res, divide(dividend, divisor));
+            divisor = valueOf(16.0f);
+            res = add(res, divide(dividend, divisor));
+            divisor = valueOf(32.0f);
+            res = add(res, divide(dividend, divisor));
+        }
+        dst[0] = float16ToRawShortBits(res);
+    }
+
+    @Test
+    @IR(counts = {IRNode.MUL_HF, " 0 ", IRNode.ADD_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "},
+        applyIfCPUFeature = {"avx512_fp16", "true"})
+    public void testMulByTWO() {
+        Float16 res = valueOf(0.0f);
+        Float16 multiplier = valueOf(2.0f);
+        for (int i = 0; i < 20; i++) {
+            Float16 multiplicand = valueOf((float)i);
+            res = add(res, multiply(multiplicand, multiplier));
+        }
+        assertResult(res.floatValue(), (float)((20 * (20 - 1))/2) * 2.0f, "testMulByTWO");
+    }
+
+
+    //
+    // Test points for various Float16 constant folding transforms. The following figure shows various
+    // special IEEE 754 binary16 values on a number line
+    //
+    //   -Inf                               -0.0                               Inf
+    //   -------|-----------------------------|----------------------------|------
+    //     -MAX_VALUE                        0.0                       MAX_VALUE
+    //
+    // Numbers whose exponents lie between -14 and 15, both values inclusive, belong to the normal value range.
+    // The IEEE 754 binary16 specification allows graceful degradation of numbers with exponents less than -14
+    // into a sub-normal value range, i.e. their exponents may extend down to -24; this is because the format
+    // supports 10 mantissa bits which can be used to represent a number with exponents less than -14.
+    //
+    // A number below the sub-normal value range is considered as 0.0. With regards to overflowing
+    // semantics, a value equal to or greater than MAX_VALUE + half ulp (MAX_VALUE) is considered as
+    // an Infinite value on both sides of the axis.
+    //
+    // In addition, the format specifies special bit representations for +Inf, -Inf and NaN values.
+    //
+    // Tests also cover special cases for various operations as per the Java SE specification.
+    //
+
+
+    @Test
+    @IR(counts = {IRNode.ADD_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
+        applyIfCPUFeature = {"avx512_fp16", "true"})
+    public void testAddConstantFolding() {
+        // If either value is NaN, then the result is NaN.
+        assertResult(add(Float16.NaN, valueOf(2.0f)).floatValue(), Float.NaN, "testAddConstantFolding");
+        assertResult(add(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testAddConstantFolding");
+        assertResult(add(Float16.NaN, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testAddConstantFolding");
+
+        // The sum of two infinities of opposite sign is NaN.
+        assertResult(add(Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NaN, "testAddConstantFolding");
+
+        // The sum of two infinities of the same sign is the infinity of that sign.
+        assertResult(add(Float16.POSITIVE_INFINITY, Float16.POSITIVE_INFINITY).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding");
+        assertResult(add(Float16.NEGATIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NEGATIVE_INFINITY, "testAddConstantFolding");
+
+        // The sum of an infinity and a finite value is equal to the infinite operand.
+        assertResult(add(Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding");
+        assertResult(add(Float16.NEGATIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testAddConstantFolding");
+
+        // The sum of two zeros of opposite sign is positive zero.
+        assertResult(add(NEGATIVE_ZERO, POSITIVE_ZERO).floatValue(), 0.0f, "testAddConstantFolding");
+
+        // The sum of two zeros of the same sign is the zero of that sign.
+        assertResult(add(NEGATIVE_ZERO, NEGATIVE_ZERO).floatValue(), -0.0f, "testAddConstantFolding");
+
+        // The sum of a zero and a nonzero finite value is equal to the nonzero operand.
+        assertResult(add(POSITIVE_ZERO, valueOf(2.0f)).floatValue(), 2.0f, "testAddConstantFolding");
+        assertResult(add(NEGATIVE_ZERO, valueOf(2.0f)).floatValue(), 2.0f, "testAddConstantFolding");
+
+        // Number equal to MAX_VALUE when added to half ulp of MAX_VALUE results into Inf.
+        assertResult(add(Float16.MAX_VALUE, MAX_HALF_ULP).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding");
+
+        // If the magnitude of the sum is too large to represent, we say the operation
+        // overflows; the result is then an infinity of appropriate sign.
+        assertResult(add(Float16.MAX_VALUE, Float16.MAX_VALUE).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding");
+
+        // Number equal to MAX_VALUE when added to a number less than half ulp of MAX_VALUE results into MAX_VALUE.
+        assertResult(add(Float16.MAX_VALUE, LT_MAX_HALF_ULP).floatValue(), Float16.MAX_VALUE.floatValue(), "testAddConstantFolding");
+
+        assertResult(add(valueOf(1.0f), valueOf(2.0f)).floatValue(), 3.0f, "testAddConstantFolding");
+    }
+
+    @Test
+    @IR(counts = {IRNode.SUB_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
+        applyIfCPUFeature = {"avx512_fp16", "true"})
+    public void testSubConstantFolding() {
+        // If either value is NaN, then the result is NaN.
+        assertResult(subtract(Float16.NaN, valueOf(2.0f)).floatValue(), Float.NaN, "testSubConstantFolding");
+        assertResult(subtract(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testSubConstantFolding");
+        assertResult(subtract(Float16.NaN, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testSubConstantFolding");
+
+        // The difference of two infinities of opposite sign is the infinity with the sign of the minuend.
+        assertResult(subtract(Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.POSITIVE_INFINITY, "testSubConstantFolding");
+
+        // The difference of two infinities of the same sign is NaN.
+        assertResult(subtract(Float16.POSITIVE_INFINITY, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testSubConstantFolding");
+        assertResult(subtract(Float16.NEGATIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NaN, "testSubConstantFolding");
+
+        // The difference of an infinity and a finite value is equal to the infinite operand.
+        assertResult(subtract(Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.POSITIVE_INFINITY, "testSubConstantFolding");
+        assertResult(subtract(Float16.NEGATIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testSubConstantFolding");
+
+        // Negative zero minus positive zero is negative zero (assertResult's float comparison treats -0.0f and 0.0f as equal).
+        assertResult(subtract(NEGATIVE_ZERO, POSITIVE_ZERO).floatValue(), -0.0f, "testSubConstantFolding");
+
+        // Number equal to -MAX_VALUE when subtracted by half ulp of MAX_VALUE results into -Inf.
+        assertResult(subtract(NEGATIVE_MAX_VALUE, MAX_HALF_ULP).floatValue(), Float.NEGATIVE_INFINITY, "testSubConstantFolding");
+
+        // Number equal to -MAX_VALUE when subtracted by a number less than half ulp of MAX_VALUE results into -MAX_VALUE.
+        assertResult(subtract(NEGATIVE_MAX_VALUE, LT_MAX_HALF_ULP).floatValue(), NEGATIVE_MAX_VALUE.floatValue(), "testSubConstantFolding");
+
+        assertResult(subtract(valueOf(1.0f), valueOf(2.0f)).floatValue(), -1.0f, "testSubConstantFolding");
+    }
+
+    @Test
+    @Warmup(value = 10000)
+    @IR(counts = {IRNode.MAX_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "},
+        applyIfCPUFeature = {"avx512_fp16", "true"})
+    public void testMaxConstantFolding() {
+        // If either value is NaN, then the result is NaN.
+ assertResult(max(valueOf(2.0f), Float16.NaN).floatValue(), Float.NaN, "testMaxConstantFolding"); + assertResult(max(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testMaxConstantFolding"); + + // This operation considers negative zero to be strictly smaller than positive zero + assertResult(max(POSITIVE_ZERO, NEGATIVE_ZERO).floatValue(), 0.0f, "testMaxConstantFolding"); + + // Other cases. + assertResult(max(Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.POSITIVE_INFINITY, "testMaxConstantFolding"); + assertResult(max(valueOf(1.0f), valueOf(2.0f)).floatValue(), 2.0f, "testMaxConstantFolding"); + assertResult(max(Float16.MAX_VALUE, Float16.MIN_VALUE).floatValue(), Float16.MAX_VALUE.floatValue(), "testMaxConstantFolding"); + } + + + @Test + @IR(counts = {IRNode.MIN_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMinConstantFolding() { + // If either value is NaN, then the result is NaN. + assertResult(min(valueOf(2.0f), Float16.NaN).floatValue(), Float.NaN, "testMinConstantFolding"); + assertResult(min(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testMinConstantFolding"); + + // This operation considers negative zero to be strictly smaller than positive zero + assertResult(min(POSITIVE_ZERO, NEGATIVE_ZERO).floatValue(), -0.0f, "testMinConstantFolding"); + + // Other cases. 
+ assertResult(min(Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NEGATIVE_INFINITY, "testMinConstantFolding"); + assertResult(min(valueOf(1.0f), valueOf(2.0f)).floatValue(), 1.0f, "testMinConstantFolding"); + assertResult(min(Float16.MAX_VALUE, Float16.MIN_VALUE).floatValue(), Float16.MIN_VALUE.floatValue(), "testMinConstantFolding"); + } + + @Test + @IR(counts = {IRNode.DIV_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testDivConstantFolding() { + // If either value is NaN, then the result is NaN. + assertResult(divide(Float16.NaN, POSITIVE_ZERO).floatValue(), Float.NaN, "testDivConstantFolding"); + assertResult(divide(NEGATIVE_ZERO, Float16.NaN).floatValue(), Float.NaN, "testDivConstantFolding"); + + // Division of an infinity by an infinity results in NaN. + assertResult(divide(Float16.NEGATIVE_INFINITY, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testDivConstantFolding"); + + // Division of an infinity by a finite value results in a signed infinity. Sign of the result is positive if both operands have + // the same sign, and negative if the operands have different signs + assertResult(divide(Float16.NEGATIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testDivConstantFolding"); + assertResult(divide(Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.POSITIVE_INFINITY, "testDivConstantFolding"); + + // Division of a finite value by an infinity results in a signed zero. The sign is + // determined by the above rule. + assertResult(divide(valueOf(2.0f), Float16.POSITIVE_INFINITY).floatValue(), 0.0f, "testDivConstantFolding"); + assertResult(divide(valueOf(2.0f), Float16.NEGATIVE_INFINITY).floatValue(), -0.0f, "testDivConstantFolding"); + + // Division of a zero by a zero results in NaN; division of zero by any other finite + // value results in a signed zero. The sign is determined by the rule stated above. 
+ assertResult(divide(POSITIVE_ZERO, NEGATIVE_ZERO).floatValue(), Float.NaN, "testDivConstantFolding"); + assertResult(divide(POSITIVE_ZERO, Float16.MAX_VALUE).floatValue(), 0.0f, "testDivConstantFolding"); + assertResult(divide(NEGATIVE_ZERO, Float16.MAX_VALUE).floatValue(), -0.0f, "testDivConstantFolding"); + + // Division of a nonzero finite value by a zero results in a signed infinity. The sign + // is determined by the rule stated above + assertResult(divide(valueOf(2.0f), NEGATIVE_ZERO).floatValue(), Float.NEGATIVE_INFINITY, "testDivConstantFolding"); + assertResult(divide(valueOf(2.0f), POSITIVE_ZERO).floatValue(), Float.POSITIVE_INFINITY, "testDivConstantFolding"); + + // If the magnitude of the quotient is too large to represent, we say the operation + // overflows; the result is then an infinity of appropriate sign. + assertResult(divide(Float16.MAX_VALUE, Float16.MIN_NORMAL).floatValue(), Float.POSITIVE_INFINITY, "testDivConstantFolding"); + assertResult(divide(Float16.MAX_VALUE, valueOf(-0x1.0P-14f)).floatValue(), Float.NEGATIVE_INFINITY, "testDivConstantFolding"); + + assertResult(divide(valueOf(2.0f), valueOf(2.0f)).floatValue(), 1.0f, "testDivConstantFolding"); + } + + @Test + @IR(counts = {IRNode.MUL_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMulConstantFolding() { + // If any operand is NaN, the result is NaN. + assertResult(multiply(Float16.NaN, valueOf(4.0f)).floatValue(), Float.NaN, "testMulConstantFolding"); + assertResult(multiply(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testMulConstantFolding"); + + // Multiplication of an infinity by a zero results in NaN. + assertResult(multiply(Float16.POSITIVE_INFINITY, POSITIVE_ZERO).floatValue(), Float.NaN, "testMulConstantFolding"); + + // Multiplication of an infinity by a finite value results in a signed infinity. 
+ assertResult(multiply(Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.POSITIVE_INFINITY, "testMulConstantFolding"); + assertResult(multiply(Float16.NEGATIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testMulConstantFolding"); + + // If the magnitude of the product is too large to represent, we say the operation + // overflows; the result is then an infinity of appropriate sign + assertResult(multiply(Float16.MAX_VALUE, Float16.MAX_VALUE).floatValue(), Float.POSITIVE_INFINITY, "testMulConstantFolding"); + assertResult(multiply(NEGATIVE_MAX_VALUE, Float16.MAX_VALUE).floatValue(), Float.NEGATIVE_INFINITY, "testMulConstantFolding"); + + assertResult(multiply(multiply(multiply(valueOf(1.0f), valueOf(2.0f)), valueOf(3.0f)), valueOf(4.0f)).floatValue(), 1.0f * 2.0f * 3.0f * 4.0f, "testMulConstantFolding"); + } + + @Test + @IR(counts = {IRNode.SQRT_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testSqrtConstantFolding() { + // If the argument is NaN or less than zero, then the result is NaN. + assertResult(sqrt(Float16.NaN).floatValue(), Float.NaN, "testSqrtConstantFolding"); + assertResult(sqrt(SIGNALING_NAN).floatValue(), Float.NaN, "testSqrtConstantFolding"); + + // If the argument is positive infinity, then the result is positive infinity. + assertResult(sqrt(Float16.POSITIVE_INFINITY).floatValue(), Float.POSITIVE_INFINITY, "testSqrtConstantFolding"); + + // If the argument is positive zero or negative zero, then the result is the same as the argument. + assertResult(sqrt(POSITIVE_ZERO).floatValue(), 0.0f, "testSqrtConstantFolding"); + assertResult(sqrt(NEGATIVE_ZERO).floatValue(), -0.0f, "testSqrtConstantFolding"); + + // Other cases. 
+ assertResult(Math.round(sqrt(valueOf(0x1.ffcP+14f)).floatValue()), Math.round(Math.sqrt(0x1.ffcP+14f)), "testSqrtConstantFolding"); + } + + @Test + @IR(counts = {IRNode.FMA_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testFMAConstantFolding() { + // If any argument is NaN, the result is NaN. + assertResult(fma(Float16.NaN, valueOf(2.0f), valueOf(3.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(SIGNALING_NAN, valueOf(2.0f), valueOf(3.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(valueOf(2.0f), Float16.NaN, valueOf(3.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + + assertResult(fma(shortBitsToFloat16(Float.floatToFloat16(2.0f)), + shortBitsToFloat16(Float.floatToFloat16(3.0f)), + Float16.NaN).floatValue(), Float.NaN, "testFMAConstantFolding"); + + // If one of the first two arguments is infinite and the other is zero, the result is NaN. + assertResult(fma(Float16.POSITIVE_INFINITY, POSITIVE_ZERO, valueOf(2.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(Float16.POSITIVE_INFINITY, NEGATIVE_ZERO, valueOf(2.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(NEGATIVE_ZERO, Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(POSITIVE_ZERO, Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + + // If the exact product of the first two arguments is infinite (in other words, at least one of the arguments is infinite + // and the other is neither zero nor NaN) and the third argument is an infinity of the opposite sign, the result is NaN. 
+ assertResult(fma(valueOf(2.0f), Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(valueOf(2.0f), Float16.NEGATIVE_INFINITY, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(Float16.POSITIVE_INFINITY, valueOf(2.0f), Float16.NEGATIVE_INFINITY).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(Float16.NEGATIVE_INFINITY, valueOf(2.0f), Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testFMAConstantFolding"); + + // Signed bits. + assertResult(fma(NEGATIVE_ZERO, POSITIVE_ZERO, POSITIVE_ZERO).floatValue(), 0.0f, "testFMAConstantFolding"); + assertResult(fma(NEGATIVE_ZERO, POSITIVE_ZERO, NEGATIVE_ZERO).floatValue(), -0.0f, "testFMAConstantFolding"); + + assertResult(fma(Float16.POSITIVE_INFINITY, valueOf(2.0f), valueOf(3.0f)).floatValue(), Float.POSITIVE_INFINITY, "testFMAConstantFolding"); + assertResult(fma(Float16.NEGATIVE_INFINITY, valueOf(2.0f), valueOf(3.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testFMAConstantFolding"); + assertResult(fma(valueOf(1.0f), valueOf(2.0f), valueOf(3.0f)).floatValue(), 1.0f * 2.0f + 3.0f, "testFMAConstantFolding"); + } + + @Test + @IR(failOn = {IRNode.ADD_HF, IRNode.SUB_HF, IRNode.MUL_HF, IRNode.DIV_HF, IRNode.SQRT_HF, IRNode.FMA_HF}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testRounding1() { + dst[0] = float16ToRawShortBits(add(RANDOM1, RANDOM2)); + dst[1] = float16ToRawShortBits(subtract(RANDOM2, RANDOM3)); + dst[2] = float16ToRawShortBits(multiply(RANDOM4, RANDOM5)); + dst[3] = float16ToRawShortBits(sqrt(RANDOM2)); + dst[4] = float16ToRawShortBits(fma(RANDOM3, RANDOM4, RANDOM5)); + dst[5] = float16ToRawShortBits(divide(RANDOM5, RANDOM4)); + } + + @Check(test = "testRounding1", when = CheckAt.COMPILED) + public void checkRounding1() { + assertResult(dst[0], Float.floatToFloat16(RANDOM1.floatValue() + RANDOM2.floatValue()), + "testRounding1 case1a"); + 
assertResult(dst[0], float16ToRawShortBits(add(RANDOM1, RANDOM2)), "testRounding1 case1b"); + + assertResult(dst[1], Float.floatToFloat16(RANDOM2.floatValue() - RANDOM3.floatValue()), + "testRounding1 case2a"); + assertResult(dst[1], float16ToRawShortBits(subtract(RANDOM2, RANDOM3)), "testRounding1 case2b"); + + assertResult(dst[2], Float.floatToFloat16(RANDOM4.floatValue() * RANDOM5.floatValue()), + "testRounding1 case3a"); + assertResult(dst[2], float16ToRawShortBits(multiply(RANDOM4, RANDOM5)), "testRounding1 case3b"); + + assertResult(dst[3], Float.floatToFloat16((float)Math.sqrt(RANDOM2.floatValue())), "testRounding1 case4a"); + assertResult(dst[3], float16ToRawShortBits(sqrt(RANDOM2)), "testRounding1 case4b"); + + assertResult(dst[4], Float.floatToFloat16(Math.fma(RANDOM3.floatValue(), RANDOM4.floatValue(), + RANDOM5.floatValue())), "testRounding1 case5a"); + assertResult(dst[4], float16ToRawShortBits(fma(RANDOM3, RANDOM4, RANDOM5)), "testRounding1 case5b"); + + assertResult(dst[5], Float.floatToFloat16(RANDOM5.floatValue() / RANDOM4.floatValue()), + "testRounding1 case6a"); + assertResult(dst[5], float16ToRawShortBits(divide(RANDOM5, RANDOM4)), "testRounding1 case6b"); + } + + @Test + @IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.SUB_HF, " >0 ", IRNode.MUL_HF, " >0 ", + IRNode.DIV_HF, " >0 ", IRNode.SQRT_HF, " >0 ", IRNode.FMA_HF, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testRounding2() { + dst[0] = float16ToRawShortBits(add(RANDOM1_VAR, RANDOM2_VAR)); + dst[1] = float16ToRawShortBits(subtract(RANDOM2_VAR, RANDOM3_VAR)); + dst[2] = float16ToRawShortBits(multiply(RANDOM4_VAR, RANDOM5_VAR)); + dst[3] = float16ToRawShortBits(sqrt(RANDOM2_VAR)); + dst[4] = float16ToRawShortBits(fma(RANDOM3_VAR, RANDOM4_VAR, RANDOM5_VAR)); + dst[5] = float16ToRawShortBits(divide(RANDOM5_VAR, RANDOM4_VAR)); + } + + @Check(test = "testRounding2", when = CheckAt.COMPILED) + public void checkRounding2() { + assertResult(dst[0], 
Float.floatToFloat16(RANDOM1_VAR.floatValue() + RANDOM2_VAR.floatValue()), + "testRounding2 case1a"); + assertResult(dst[0], float16ToRawShortBits(add(RANDOM1_VAR, RANDOM2_VAR)), "testRounding2 case1b"); + + assertResult(dst[1], Float.floatToFloat16(RANDOM2_VAR.floatValue() - RANDOM3_VAR.floatValue()), + "testRounding2 case2a"); + assertResult(dst[1], float16ToRawShortBits(subtract(RANDOM2_VAR, RANDOM3_VAR)), "testRounding2 case2b"); + + assertResult(dst[2], Float.floatToFloat16(RANDOM4_VAR.floatValue() * RANDOM5_VAR.floatValue()), + "testRounding2 case3a"); + assertResult(dst[2], float16ToRawShortBits(multiply(RANDOM4_VAR, RANDOM5_VAR)), "testRounding2 case3b"); + + assertResult(dst[3], Float.floatToFloat16((float)Math.sqrt(RANDOM2_VAR.floatValue())), "testRounding2 case4a"); + assertResult(dst[3], float16ToRawShortBits(sqrt(RANDOM2_VAR)), "testRounding2 case4b"); + + assertResult(dst[4], Float.floatToFloat16(Math.fma(RANDOM3_VAR.floatValue(), RANDOM4_VAR.floatValue(), + RANDOM5_VAR.floatValue())), "testRounding2 case5a"); + assertResult(dst[4], float16ToRawShortBits(fma(RANDOM3_VAR, RANDOM4_VAR, RANDOM5_VAR)), "testRounding2 case5b"); + + assertResult(dst[5], Float.floatToFloat16(RANDOM5_VAR.floatValue() / RANDOM4_VAR.floatValue()), + "testRounding2 case6a"); + assertResult(dst[5], float16ToRawShortBits(divide(RANDOM5_VAR, RANDOM4_VAR)), "testRounding2 case6b"); + } +} diff --git a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java index bfbfd80139c..8f28294a986 100644 --- a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java +++ b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java @@ -209,6 +209,11 @@ public class IRNode { beforeMatchingNameRegex(ADD, "Add(I|L|F|D|P)"); } + public static final String ADD_F = PREFIX + "ADD_F" + POSTFIX; + static { + beforeMatchingNameRegex(ADD_F, "AddF"); + } + public static final String ADD_I = PREFIX + "ADD_I" + POSTFIX; static { 
beforeMatchingNameRegex(ADD_I, "AddI"); @@ -219,6 +224,11 @@ public class IRNode { beforeMatchingNameRegex(ADD_L, "AddL"); } + public static final String ADD_HF = PREFIX + "ADD_HF" + POSTFIX; + static { + beforeMatchingNameRegex(ADD_HF, "AddHF"); + } + public static final String ADD_P = PREFIX + "ADD_P" + POSTFIX; static { beforeMatchingNameRegex(ADD_P, "AddP"); @@ -528,6 +538,11 @@ public class IRNode { beforeMatchingNameRegex(CONV, "Conv"); } + public static final String CONV_F2HF = PREFIX + "CONV_F2HF" + POSTFIX; + static { + beforeMatchingNameRegex(CONV_F2HF, "ConvF2HF"); + } + public static final String CONV_I2L = PREFIX + "CONV_I2L" + POSTFIX; static { beforeMatchingNameRegex(CONV_I2L, "ConvI2L"); @@ -538,6 +553,11 @@ public class IRNode { beforeMatchingNameRegex(CONV_L2I, "ConvL2I"); } + public static final String CONV_HF2F = PREFIX + "CONV_HF2F" + POSTFIX; + static { + beforeMatchingNameRegex(CONV_HF2F, "ConvHF2F"); + } + public static final String CON_I = PREFIX + "CON_I" + POSTFIX; static { beforeMatchingNameRegex(CON_I, "ConI"); @@ -652,6 +672,11 @@ public class IRNode { vectorNode(FMA_VD, "FmaVD", TYPE_DOUBLE); } + public static final String FMA_HF = PREFIX + "FMA_HF" + POSTFIX; + static { + beforeMatchingNameRegex(FMA_HF, "FmaHF"); + } + public static final String G1_COMPARE_AND_EXCHANGE_N_WITH_BARRIER_FLAG = COMPOSITE_PREFIX + "G1_COMPARE_AND_EXCHANGE_N_WITH_BARRIER_FLAG" + POSTFIX; static { String regex = START + "g1CompareAndExchangeN\\S*" + MID + "barrier\\(\\s*" + IS_REPLACED + "\\s*\\)" + END; @@ -1154,6 +1179,16 @@ public class IRNode { beforeMatchingNameRegex(MIN_L, "MinL"); } + public static final String MIN_HF = PREFIX + "MIN_HF" + POSTFIX; + static { + beforeMatchingNameRegex(MIN_HF, "MinHF"); + } + + public static final String MAX_HF = PREFIX + "MAX_HF" + POSTFIX; + static { + beforeMatchingNameRegex(MAX_HF, "MaxHF"); + } + public static final String MIN_VI = VECTOR_PREFIX + "MIN_VI" + POSTFIX; static { vectorNode(MIN_VI, "MinV", TYPE_INT); 
@@ -1235,6 +1270,11 @@ public class IRNode { beforeMatchingNameRegex(MUL_F, "MulF"); } + public static final String MUL_HF = PREFIX + "MUL_HF" + POSTFIX; + static { + beforeMatchingNameRegex(MUL_HF, "MulHF"); + } + public static final String MUL_I = PREFIX + "MUL_I" + POSTFIX; static { beforeMatchingNameRegex(MUL_I, "MulI"); @@ -1435,6 +1475,16 @@ public class IRNode { trapNodes(RANGE_CHECK_TRAP, "range_check"); } + public static final String REINTERPRET_S2HF = PREFIX + "REINTERPRET_S2HF" + POSTFIX; + static { + beforeMatchingNameRegex(REINTERPRET_S2HF, "ReinterpretS2HF"); + } + + public static final String REINTERPRET_HF2S = PREFIX + "REINTERPRET_HF2S" + POSTFIX; + static { + beforeMatchingNameRegex(REINTERPRET_HF2S, "ReinterpretHF2S"); + } + public static final String REPLICATE_B = VECTOR_PREFIX + "REPLICATE_B" + POSTFIX; static { vectorNode(REPLICATE_B, "Replicate", TYPE_BYTE); @@ -1601,6 +1651,16 @@ public class IRNode { vectorNode(SIGNUM_VF, "SignumVF", TYPE_FLOAT); } + public static final String SQRT_HF = PREFIX + "SQRT_HF" + POSTFIX; + static { + beforeMatchingNameRegex(SQRT_HF, "SqrtHF"); + } + + public static final String SQRT_F = PREFIX + "SQRT_F" + POSTFIX; + static { + beforeMatchingNameRegex(SQRT_F, "SqrtF"); + } + public static final String SQRT_VF = VECTOR_PREFIX + "SQRT_VF" + POSTFIX; static { vectorNode(SQRT_VF, "SqrtVF", TYPE_FLOAT); @@ -1742,6 +1802,11 @@ public class IRNode { beforeMatchingNameRegex(SUB_F, "SubF"); } + public static final String SUB_HF = PREFIX + "SUB_HF" + POSTFIX; + static { + beforeMatchingNameRegex(SUB_HF, "SubHF"); + } + public static final String SUB_I = PREFIX + "SUB_I" + POSTFIX; static { beforeMatchingNameRegex(SUB_I, "SubI"); @@ -1792,6 +1857,11 @@ public class IRNode { trapNodes(TRAP, "reason"); } + public static final String DIV_HF = PREFIX + "DIV_HF" + POSTFIX; + static { + beforeMatchingNameRegex(DIV_HF, "DivHF"); + } + public static final String UDIV_I = PREFIX + "UDIV_I" + POSTFIX; static { 
beforeMatchingNameRegex(UDIV_I, "UDivI"); diff --git a/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java b/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java index 9f58f709702..0b7dcbae9d9 100644 --- a/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java +++ b/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -102,6 +102,7 @@ public class IREncodingPrinter { "avx512dq", "avx512vl", "avx512f", + "avx512_fp16", "avx512_vnni", // AArch64 "sha3", diff --git a/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorConvChain.java b/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorConvChain.java index 20401a98938..4cf656620bc 100644 --- a/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorConvChain.java +++ b/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorConvChain.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 
* * This code is free software; you can redistribute it and/or modify it @@ -41,7 +41,12 @@ import java.util.Arrays; public class TestFloat16VectorConvChain { @Test - @IR(applyIfCPUFeatureOr = {"f16c", "true", "avx512vl", "true", "zvfh", "true"}, counts = {IRNode.VECTOR_CAST_HF2F, IRNode.VECTOR_SIZE_ANY, ">= 1", IRNode.VECTOR_CAST_F2HF, IRNode.VECTOR_SIZE_ANY, " >= 1"}) + @IR(applyIfCPUFeatureAnd = {"avx512_fp16", "false", "avx512vl", "true"}, + counts = {IRNode.VECTOR_CAST_HF2F, IRNode.VECTOR_SIZE_ANY, ">= 1", IRNode.VECTOR_CAST_F2HF, IRNode.VECTOR_SIZE_ANY, " >= 1"}) + @IR(applyIfCPUFeatureAnd = {"avx512_fp16", "false", "f16c", "true"}, + counts = {IRNode.VECTOR_CAST_HF2F, IRNode.VECTOR_SIZE_ANY, ">= 1", IRNode.VECTOR_CAST_F2HF, IRNode.VECTOR_SIZE_ANY, " >= 1"}) + @IR(applyIfCPUFeature = {"zvfh", "true"}, + counts = {IRNode.VECTOR_CAST_HF2F, IRNode.VECTOR_SIZE_ANY, ">= 1", IRNode.VECTOR_CAST_F2HF, IRNode.VECTOR_SIZE_ANY, " >= 1"}) public static void test(short [] res, short [] src1, short [] src2) { for (int i = 0; i < res.length; i++) { res[i] = (short)Float.float16ToFloat(Float.floatToFloat16(Float.float16ToFloat(src1[i]) + Float.float16ToFloat(src2[i]))); diff --git a/test/jdk/jdk/incubator/vector/ScalarFloat16OperationsTest.java b/test/jdk/jdk/incubator/vector/ScalarFloat16OperationsTest.java new file mode 100644 index 00000000000..e28ba401197 --- /dev/null +++ b/test/jdk/jdk/incubator/vector/ScalarFloat16OperationsTest.java @@ -0,0 +1,347 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. 
+ * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8342103 + * @summary C2 compiler support for Float16 type and associated operations + * @modules jdk.incubator.vector + * @library /test/lib + * @compile ScalarFloat16OperationsTest.java + * @run testng/othervm/timeout=300 -ea -esa -Xbatch -XX:-TieredCompilation -XX:-UseSuperWord ScalarFloat16OperationsTest + * @run testng/othervm/timeout=300 -ea -esa -Xbatch -XX:-TieredCompilation -XX:+UseSuperWord ScalarFloat16OperationsTest + */ + +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.Random; +import java.util.stream.IntStream; +import jdk.incubator.vector.Float16; +import static jdk.incubator.vector.Float16.*; + +public class ScalarFloat16OperationsTest { + static final int SIZE = 65504; + static Random r = jdk.test.lib.Utils.getRandomInstance(); + static final int INVOC_COUNT = Integer.getInteger("jdk.incubator.vector.test.loop-iterations", 100); + + @DataProvider + public static Object[][] unaryOpProvider() { + Float16 [] input = new Float16[SIZE]; + Float16 [] special_input = { + Float16.MAX_VALUE, Float16.MIN_VALUE, Float16.MIN_NORMAL, Float16.POSITIVE_INFINITY, + Float16.NEGATIVE_INFINITY, Float16.valueOf(0.0f), Float16.valueOf(-0.0f), 
Float16.NaN + }; + + // Input array covers entire Float16 value range + IntStream.range(0, input.length).forEach(i -> {input[i] = valueOf(i);}); + + return new Object[][] { + {input}, + {special_input} + }; + } + + @DataProvider + public static Object[][] binaryOpProvider() { + Float16 [] input1 = new Float16[SIZE]; + Float16 [] input2 = new Float16[SIZE]; + Float16 [] special_input = { + Float16.MAX_VALUE, Float16.MIN_VALUE, Float16.MIN_NORMAL, Float16.POSITIVE_INFINITY, + Float16.NEGATIVE_INFINITY, Float16.valueOf(0.0f), Float16.valueOf(-0.0f), Float16.NaN + }; + + // Input arrays covers entire Float16 value range interspersed with special values. + IntStream.range(0, input1.length).forEach(i -> {input1[i] = valueOf(i);}); + IntStream.range(0, input2.length).forEach(i -> {input2[i] = valueOf(i);}); + + for (int i = 0; i < special_input.length; i++) { + input1[r.nextInt(input1.length)] = special_input[i]; + input2[r.nextInt(input2.length)] = special_input[i]; + } + + return new Object[][] { + {input1, input2}, + {special_input, special_input}, + }; + } + + @DataProvider + public static Object[][] ternaryOpProvider() { + Float16 [] input1 = new Float16[SIZE]; + Float16 [] input2 = new Float16[SIZE]; + Float16 [] input3 = new Float16[SIZE]; + Float16 [] special_input = { + Float16.MAX_VALUE, Float16.MIN_VALUE, Float16.MIN_NORMAL, Float16.POSITIVE_INFINITY, + Float16.NEGATIVE_INFINITY, Float16.valueOf(0.0f), Float16.valueOf(-0.0f), Float16.NaN + }; + + // Input arrays covers entire Float16 value range interspersed with special values. 
+ IntStream.range(0, input1.length).forEach(i -> {input1[i] = valueOf(i);}); + IntStream.range(0, input2.length).forEach(i -> {input2[i] = valueOf(i);}); + IntStream.range(0, input3.length).forEach(i -> {input3[i] = valueOf(i);}); + for (int i = 0; i < special_input.length; i++) { + input1[r.nextInt(input1.length)] = special_input[i]; + input2[r.nextInt(input2.length)] = special_input[i]; + input3[r.nextInt(input3.length)] = special_input[i]; + } + + return new Object[][] { + {input1, input2, input3}, + {special_input, special_input, special_input}, + }; + } + + interface FUnOp1 { + Float16 apply(Float16 a); + } + + interface FUnOp2 { + boolean apply(Float16 a); + } + + static void assertArraysEquals(Float16[] r, Float16[] a, FUnOp1 f) { + int i = 0; + try { + for (; i < a.length; i++) { + Assert.assertEquals(r[i], f.apply(a[i])); + } + } catch (AssertionError e) { + Assert.assertEquals(r[i], f.apply(a[i]), "at index #" + i + ", input = " + a[i]); + } + } + + static void assertArraysEquals(boolean[] r, Float16[] a, FUnOp2 f) { + int i = 0; + try { + for (; i < a.length; i++) { + Assert.assertEquals(r[i], f.apply(a[i])); + } + } catch (AssertionError e) { + Assert.assertEquals(r[i], f.apply(a[i]), "at index #" + i + ", input = " + a[i]); + } + } + + interface FBinOp { + Float16 apply(Float16 a, Float16 b); + } + + static void assertArraysEquals(Float16[] r, Float16[] a, Float16[] b, FBinOp f) { + int i = 0; + try { + for (; i < r.length; i++) { + Assert.assertEquals(r[i], f.apply(a[i], b[i])); + } + } catch (AssertionError e) { + Assert.assertEquals(r[i], f.apply(a[i], b[i]), "at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i]); + } + } + + interface FTernOp { + Float16 apply(Float16 a, Float16 b, Float16 c); + } + + static void assertArraysEquals(Float16[] r, Float16[] a, Float16[] b, Float16[] c, FTernOp f) { + int i = 0; + try { + for (; i < r.length; i++) { + Assert.assertEquals(r[i], f.apply(a[i], b[i], c[i])); + } + } catch (AssertionError e) { 
+ Assert.assertEquals(r[i], f.apply(a[i], b[i], c[i]), "at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i] + ", input3 = " + c[i]); + } + } + + + @Test(dataProvider = "unaryOpProvider") + public static void absTest(Object input) { + Float16 [] farr = (Float16[])input; + Float16 [] res = new Float16[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = abs(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> valueOf(Math.abs(fp16.floatValue()))); + } + + @Test(dataProvider = "unaryOpProvider") + public static void negTest(Object input) { + Float16 [] farr = (Float16[])input; + Float16 [] res = new Float16[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = negate(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> shortBitsToFloat16((short)(float16ToRawShortBits(fp16) ^ (short)0x0000_8000))); + } + + @Test(dataProvider = "unaryOpProvider") + public static void sqrtTest(Object input) { + Float16 [] farr = (Float16[])input; + Float16 [] res = new Float16[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = sqrt(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> valueOf(Math.sqrt(fp16.floatValue()))); + } + + @Test(dataProvider = "unaryOpProvider") + public static void isInfiniteTest(Object input) { + Float16 [] farr = (Float16[])input; + boolean [] res = new boolean[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = isInfinite(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> Float.isInfinite(fp16.floatValue())); + } + + @Test(dataProvider = "unaryOpProvider") + public static void isFiniteTest(Object input) { + Float16 [] farr = (Float16[])input; + boolean [] res = new boolean[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = 
isFinite(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> Float.isFinite(fp16.floatValue())); + } + + @Test(dataProvider = "unaryOpProvider") + public static void isNaNTest(Object input) { + Float16 [] farr = (Float16[])input; + boolean [] res = new boolean[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = isNaN(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> Float.isNaN(fp16.floatValue())); + } + + @Test(dataProvider = "binaryOpProvider") + public static void addTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = add(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(fp16_val1.floatValue() + fp16_val2.floatValue())); + } + + @Test(dataProvider = "binaryOpProvider") + public static void subtractTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = subtract(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(fp16_val1.floatValue() - fp16_val2.floatValue())); + } + + @Test(dataProvider = "binaryOpProvider") + public static void multiplyTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = multiply(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(fp16_val1.floatValue() * fp16_val2.floatValue())); + } + + 
@Test(dataProvider = "binaryOpProvider") + public static void divideTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = divide(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(fp16_val1.floatValue() / fp16_val2.floatValue())); + } + + @Test(dataProvider = "binaryOpProvider") + public static void maxTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = max(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(Float.max(fp16_val1.floatValue(), fp16_val2.floatValue()))); + } + + @Test(dataProvider = "binaryOpProvider") + public static void minTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = min(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(Float.min(fp16_val1.floatValue(), fp16_val2.floatValue()))); + } + + @Test(dataProvider = "ternaryOpProvider") + public static void fmaTest(Object input1, Object input2, Object input3) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + Float16 [] farr3 = (Float16[])input3; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = fma(farr1[i], farr2[i], farr3[i]); + } + } + assertArraysEquals(res, farr1, farr2, farr3, (fp16_val1, 
fp16_val2, fp16_val3) -> valueOf(Math.fma(fp16_val1.floatValue(), fp16_val2.floatValue(), fp16_val3.floatValue()))); + } +} diff --git a/test/micro/org/openjdk/bench/jdk/incubator/vector/Float16OperationsBenchmark.java b/test/micro/org/openjdk/bench/jdk/incubator/vector/Float16OperationsBenchmark.java new file mode 100644 index 00000000000..ca8c0edfc74 --- /dev/null +++ b/test/micro/org/openjdk/bench/jdk/incubator/vector/Float16OperationsBenchmark.java @@ -0,0 +1,296 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ */ +package org.openjdk.bench.jdk.incubator.vector; + +import java.util.stream.IntStream; +import java.util.concurrent.TimeUnit; +import jdk.incubator.vector.*; +import org.openjdk.jmh.annotations.*; +import static jdk.incubator.vector.Float16.*; +import static java.lang.Float.*; + +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Thread) +@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector", "-Xbatch", "-XX:-TieredCompilation"}) +public class Float16OperationsBenchmark { + @Param({"256", "512", "1024", "2048"}) + int vectorDim; + + int [] rexp; + short [] vectorRes; + short [] vector1; + short [] vector2; + short [] vector3; + boolean [] vectorPredicate; + + static final short f16_one = Float.floatToFloat16(1.0f); + static final short f16_two = Float.floatToFloat16(2.0f); + + @Setup(Level.Trial) + public void BmSetup() { + rexp = new int[vectorDim]; + vectorRes = new short[vectorDim]; + vector1 = new short[vectorDim]; + vector2 = new short[vectorDim]; + vector3 = new short[vectorDim]; + vectorPredicate = new boolean[vectorDim]; + + IntStream.range(0, vectorDim).forEach(i -> {vector1[i] = Float.floatToFloat16((float)i);}); + IntStream.range(0, vectorDim).forEach(i -> {vector2[i] = Float.floatToFloat16((float)i);}); + IntStream.range(0, vectorDim).forEach(i -> {vector3[i] = Float.floatToFloat16((float)i);}); + + // Special Values + Float16 [] specialValues = {Float16.NaN, Float16.NEGATIVE_INFINITY, Float16.valueOf(0.0), Float16.valueOf(-0.0), Float16.POSITIVE_INFINITY}; + IntStream.range(0, vectorDim).forEach( + i -> { + if ((i % 64) == 0) { + int idx1 = i % specialValues.length; + int idx2 = (i + 1) % specialValues.length; + int idx3 = (i + 2) % specialValues.length; + vector1[i] = float16ToRawShortBits(specialValues[idx1]); + vector2[i] = float16ToRawShortBits(specialValues[idx2]); + vector3[i] = float16ToRawShortBits(specialValues[idx3]); + } + } + ); + } + + @Benchmark + public void addBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = 
float16ToRawShortBits(add(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void subBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(subtract(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void mulBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(multiply(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void divBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(divide(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void fmaBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(fma(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]), shortBitsToFloat16(vector3[i]))); + } + } + + @Benchmark + public boolean isInfiniteBenchmark() { + boolean res = true; + for (int i = 0; i < vectorDim; i++) { + res &= isInfinite(shortBitsToFloat16(vector1[i])); + } + return res; + } + + @Benchmark + public boolean isFiniteBenchmark() { + boolean res = true; + for (int i = 0; i < vectorDim; i++) { + res &= isFinite(shortBitsToFloat16(vector1[i])); + } + return res; + } + + @Benchmark + public boolean isNaNBenchmark() { + boolean res = true; + for (int i = 0; i < vectorDim; i++) { + res &= isNaN(shortBitsToFloat16(vector1[i])); + } + return res; + } + + @Benchmark + public void isNaNStoreBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorPredicate[i] = Float16.isNaN(shortBitsToFloat16(vector1[i])); + } + } + + + @Benchmark + public void isNaNCMovBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = Float16.isNaN(shortBitsToFloat16(vector1[i])) ? 
f16_one : f16_two; + } + } + + + @Benchmark + public void isInfiniteStoreBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorPredicate[i] = Float16.isInfinite(shortBitsToFloat16(vector1[i])); + } + } + + + @Benchmark + public void isInfiniteCMovBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = Float16.isInfinite(shortBitsToFloat16(vector1[i])) ? f16_one : f16_two; + } + } + + + @Benchmark + public void isFiniteStoreBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorPredicate[i] = Float16.isFinite(shortBitsToFloat16(vector1[i])); + } + } + + + @Benchmark + public void isFiniteCMovBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = Float16.isFinite(shortBitsToFloat16(vector1[i])) ? f16_one : f16_two; + } + } + + @Benchmark + public void maxBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(max(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void minBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(min(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void sqrtBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(sqrt(shortBitsToFloat16(vector1[i]))); + } + } + + @Benchmark + public void negateBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(negate(shortBitsToFloat16(vector1[i]))); + } + } + + @Benchmark + public void absBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(abs(shortBitsToFloat16(vector1[i]))); + } + } + + @Benchmark + public void getExponentBenchmark() { + for (int i = 0; i < vectorDim; i++) { + rexp[i] = getExponent(shortBitsToFloat16(vector1[i])); + } + } + + @Benchmark + public short cosineSimilarityDoubleRoundingFP16() { + short macRes = floatToFloat16(0.0f); + short vector1Square = 
floatToFloat16(0.0f); + short vector2Square = floatToFloat16(0.0f); + for (int i = 0; i < vectorDim; i++) { + // Explicit add and multiply operation ensures double rounding. + Float16 vec1 = shortBitsToFloat16(vector1[i]); + Float16 vec2 = shortBitsToFloat16(vector2[i]); + macRes = float16ToRawShortBits(add(multiply(vec1, vec2), shortBitsToFloat16(macRes))); + vector1Square = float16ToRawShortBits(add(multiply(vec1, vec1), shortBitsToFloat16(vector1Square))); + vector2Square = float16ToRawShortBits(add(multiply(vec2, vec2), shortBitsToFloat16(vector2Square))); + } + return float16ToRawShortBits(divide(shortBitsToFloat16(macRes), add(shortBitsToFloat16(vector1Square), shortBitsToFloat16(vector2Square)))); + } + + @Benchmark + public short cosineSimilaritySingleRoundingFP16() { + short macRes = floatToFloat16(0.0f); + short vector1Square = floatToFloat16(0.0f); + short vector2Square = floatToFloat16(0.0f); + for (int i = 0; i < vectorDim; i++) { + Float16 vec1 = shortBitsToFloat16(vector1[i]); + Float16 vec2 = shortBitsToFloat16(vector2[i]); + macRes = float16ToRawShortBits(fma(vec1, vec2, shortBitsToFloat16(macRes))); + vector1Square = float16ToRawShortBits(fma(vec1, vec1, shortBitsToFloat16(vector1Square))); + vector2Square = float16ToRawShortBits(fma(vec2, vec2, shortBitsToFloat16(vector2Square))); + } + return float16ToRawShortBits(divide(shortBitsToFloat16(macRes), add(shortBitsToFloat16(vector1Square), shortBitsToFloat16(vector2Square)))); + } + + @Benchmark + public short cosineSimilarityDequantizedFP16() { + float macRes = 0.0f; + float vector1Square = 0.0f; + float vector2Square = 0.0f; + for (int i = 0; i < vectorDim; i++) { + float vec1 = float16ToFloat(vector1[i]); + float vec2 = float16ToFloat(vector2[i]); + macRes = Math.fma(vec1, vec2, macRes); + vector1Square = Math.fma(vec1, vec1, vector1Square); + vector2Square = Math.fma(vec2, vec2, vector2Square); + } + return floatToFloat16(macRes / (vector1Square + vector2Square)); + } + + @Benchmark + public 
short euclideanDistanceFP16() { + short distRes = floatToFloat16(0.0f); + short squareRes = floatToFloat16(0.0f); + for (int i = 0; i < vectorDim; i++) { + squareRes = float16ToRawShortBits(subtract(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + distRes = float16ToRawShortBits(fma(shortBitsToFloat16(squareRes), shortBitsToFloat16(squareRes), shortBitsToFloat16(distRes))); + } + return float16ToRawShortBits(sqrt(shortBitsToFloat16(distRes))); + } + + @Benchmark + public short euclideanDistanceDequantizedFP16() { + float distRes = 0.0f; + float squareRes = 0.0f; + for (int i = 0; i < vectorDim; i++) { + squareRes = float16ToFloat(vector1[i]) - float16ToFloat(vector2[i]); + distRes = distRes + squareRes * squareRes; + } + return float16ToRawShortBits(sqrt(shortBitsToFloat16(floatToFloat16(distRes)))); + } +}