[VectorAPI]anyTrue/allTrue got wrong results when using UseAVX=3

王卓(卓仁) zhuoren.wz at alibaba-inc.com
Wed Jan 2 09:47:04 UTC 2019


Hello, I found that anyTrue/allTrue return wrong results when running with -XX:UseAVX=3.

Let me describe the bug. With UseAVX=3, the following assembly code may be generated for anyTrue/allTrue:

  0x00002ad5dca584f6: vpxord %xmm22,%xmm22,%xmm22
  0x00002ad5dca584fc: vpsubb %xmm0,%xmm22,%xmm22
  0x00002ad5dca58502: vpmovsxbd %xmm22,%ymm22
  0x00002ad5dca58508: vptest %ymm1,%ymm6
  0x00002ad5dca5850d: setb   %r10b

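The second operand of vptest should be ymm22. However, vptest has no EVEX-encoded form, so a register above ymm15 cannot be encoded; only the low four bits of the register number survive (22 -> 6), the instruction is emitted with ymm6, and we get wrong results.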

Currently I am using the following fix for this bug: move the source operands into xmm14/xmm15 (or ymm14/ymm15 for 256-bit vectors) and use those registers as the operands of vptest. Please advise on this fix.
diff -r 636478e1ee75 src/hotspot/cpu/x86/x86.ad
--- a/src/hotspot/cpu/x86/x86.ad        Thu Dec 13 16:40:28 2018 -0800
+++ b/src/hotspot/cpu/x86/x86.ad        Sat Dec 29 11:16:55 2018 +0800
@@ -21930,7 +21930,7 @@
 %}

 instruct vptest4ieq(rRegI dst, vecX src1, vecX src2) %{
-  predicate(UseAVX > 0 && static_cast<const VectorTestNode*>(n)->get_predicate() == Assembler::notZero);
+  predicate(UseAVX < 3 && UseAVX > 0 && static_cast<const VectorTestNode*>(n)->get_predicate() == Assembler::notZero);
   match(Set dst (VectorTest src1 src2 ));
   format %{ "vptest  $src1,$src2\n\t"
             "setb  $dst\t!" %}
@@ -21943,8 +21943,54 @@
   ins_pipe( pipe_slow );
 %}

+instruct vptest4inaeavx3(rRegI dst, vecX src1, vecX src2, rxmm14 tmp14, rxmm15 tmp15) %{
+  predicate(UseAVX > 2 && static_cast<const VectorTestNode*>(n)->get_predicate() == Assembler::carrySet);  // UseAVX=3 variant of the vecX carrySet VectorTest
+  match(Set dst (VectorTest src1 src2 ));
+  effect(TEMP tmp14, TEMP tmp15);  // scratch registers; presumably pinned to xmm14/xmm15 by the rxmm14/rxmm15 operand types (not defined in this patch - confirm)
+  format %{ "movdqu      $tmp14,$src1\n\t"
+            "movdqu      $tmp15,$src2\n\t"
+            "vptest  $tmp14,$tmp15\n\t"
+            "setb  $dst\t!" %}
+  ins_encode %{  // vptest has no EVEX form, so both inputs are first copied into the low temporaries
+    int vector_len = 0;  // presumably selects the 128-bit encoding - confirm against the Assembler vector-length enum
+    if ($tmp14$$XMMRegister != $src1$$XMMRegister) {  // skip the copy when src1 already lives in tmp14
+      __ movdqu($tmp14$$XMMRegister, $src1$$XMMRegister);
+    }
+    if ($tmp15$$XMMRegister != $src2$$XMMRegister) {  // skip the copy when src2 already lives in tmp15
+      __ movdqu($tmp15$$XMMRegister, $src2$$XMMRegister);
+    }
+    __ vptest($tmp14$$XMMRegister, $tmp15$$XMMRegister, vector_len);  // test the copies, never the original (possibly high) registers
+    __ setb(Assembler::carrySet, $dst$$Register);  // materialize the carrySet condition in dst
+    __ movzbl($dst$$Register, $dst$$Register);  // widen the byte result in place
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vptest4ieqavx3(rRegI dst, vecX src1, vecX src2, rxmm14 tmp14, rxmm15 tmp15) %{
+  predicate(UseAVX > 2 && static_cast<const VectorTestNode*>(n)->get_predicate() == Assembler::notZero);  // UseAVX=3 only: must not overlap vptest4ieq, which handles 0 < UseAVX < 3
+  match(Set dst (VectorTest src1 src2 ));
+  effect(TEMP tmp14, TEMP tmp15);  // scratch registers; presumably pinned to xmm14/xmm15 by the rxmm14/rxmm15 operand types (not defined in this patch - confirm)
+  format %{ "movdqu      $tmp14,$src1\n\t"
+            "movdqu      $tmp15,$src2\n\t"
+            "vptest  $tmp14,$tmp15\n\t"
+            "setb  $dst\t!" %}
+  ins_encode %{  // vptest has no EVEX form, so both inputs are first copied into the low temporaries
+    int vector_len = 0;  // presumably selects the 128-bit encoding - confirm against the Assembler vector-length enum
+    if ($tmp14$$XMMRegister != $src1$$XMMRegister) {  // skip the copy when src1 already lives in tmp14
+      __ movdqu($tmp14$$XMMRegister, $src1$$XMMRegister);
+    }
+    if ($tmp15$$XMMRegister != $src2$$XMMRegister) {  // skip the copy when src2 already lives in tmp15
+      __ movdqu($tmp15$$XMMRegister, $src2$$XMMRegister);
+    }
+    __ vptest($tmp14$$XMMRegister, $tmp15$$XMMRegister, vector_len);  // test the copies, never the original (possibly high) registers
+    __ setb(Assembler::notZero, $dst$$Register);  // materialize the notZero condition in dst
+    __ movzbl($dst$$Register, $dst$$Register);  // widen the byte result in place
+  %}
+  ins_pipe( pipe_slow );
+%}
+
 instruct vptest8inae(rRegI dst, vecY src1, vecY src2) %{
-  predicate(UseAVX > 0 && static_cast<const VectorTestNode*>(n)->get_predicate() == Assembler::carrySet);
+  predicate(UseAVX < 3 && UseAVX > 0 && static_cast<const VectorTestNode*>(n)->get_predicate() == Assembler::carrySet);
   match(Set dst (VectorTest src1 src2 ));
   format %{ "vptest  $src1,$src2\n\t"
             "setb  $dst\t!" %}
@@ -21958,7 +22004,7 @@
 %}

 instruct vptest8ieq(rRegI dst, vecY src1, vecY src2) %{
-  predicate(UseAVX > 0 && static_cast<const VectorTestNode*>(n)->get_predicate() == Assembler::notZero);
+  predicate(UseAVX < 3 && UseAVX > 0 && static_cast<const VectorTestNode*>(n)->get_predicate() == Assembler::notZero);
   match(Set dst (VectorTest src1 src2 ));
   format %{ "vptest  $src1,$src2\n\t"
             "setb  $dst\t!" %}
@@ -21971,6 +22017,52 @@
   ins_pipe( pipe_slow );
 %}
+instruct vptest8inaeavx3(rRegI dst, vecY src1, vecY src2, rymm14 tmp14, rymm15 tmp15) %{
+  predicate(UseAVX > 2 && static_cast<const VectorTestNode*>(n)->get_predicate() == Assembler::carrySet);  // UseAVX=3 variant of the vecY carrySet VectorTest
+  match(Set dst (VectorTest src1 src2 ));
+  effect(TEMP tmp14, TEMP tmp15);  // scratch registers constrained to ymm14/ymm15 by the rymm14/rymm15 operands
+  format %{ "vmovdqu     $tmp14,$src1\n\t"
+            "vmovdqu     $tmp15,$src2\n\t"
+            "vptest  $tmp14,$tmp15\n\t"
+            "setb  $dst\t!" %}
+  ins_encode %{  // vptest has no EVEX form, so both inputs are first copied into the low temporaries
+    int vector_len = 1;  // presumably selects the 256-bit encoding - confirm against the Assembler vector-length enum
+    if ($tmp14$$XMMRegister != $src1$$XMMRegister) {  // skip the copy when src1 already lives in tmp14
+      __ vmovdqu($tmp14$$XMMRegister, $src1$$XMMRegister);
+    }
+    if ($tmp15$$XMMRegister != $src2$$XMMRegister) {  // skip the copy when src2 already lives in tmp15
+      __ vmovdqu($tmp15$$XMMRegister, $src2$$XMMRegister);
+    }
+    __ vptest($tmp14$$XMMRegister, $tmp15$$XMMRegister, vector_len);  // test the copies, never the original (possibly high) registers
+    __ setb(Assembler::carrySet, $dst$$Register);  // materialize the carrySet condition in dst
+    __ movzbl($dst$$Register, $dst$$Register);  // widen the byte result in place
+  %}
+  ins_pipe( pipe_slow );
+%}
+
+instruct vptest8ieqavx3(rRegI dst, vecY src1, vecY src2, rymm14 tmp14, rymm15 tmp15) %{
+  predicate(UseAVX > 2 && static_cast<const VectorTestNode*>(n)->get_predicate() == Assembler::notZero);  // UseAVX=3 variant of the vecY notZero VectorTest
+  match(Set dst (VectorTest src1 src2 ));
+  effect(TEMP tmp14, TEMP tmp15);  // scratch registers constrained to ymm14/ymm15 by the rymm14/rymm15 operands
+  format %{ "vmovdqu     $tmp14,$src1\n\t"
+            "vmovdqu     $tmp15,$src2\n\t"
+            "vptest  $tmp14,$tmp15\n\t"
+            "setb  $dst\t!" %}
+  ins_encode %{  // vptest has no EVEX form, so both inputs are first copied into the low temporaries
+    int vector_len = 1;  // presumably selects the 256-bit encoding - confirm against the Assembler vector-length enum
+    if ($tmp14$$XMMRegister != $src1$$XMMRegister) {  // skip the copy when src1 already lives in tmp14
+      __ vmovdqu($tmp14$$XMMRegister, $src1$$XMMRegister);
+    }
+    if ($tmp15$$XMMRegister != $src2$$XMMRegister) {  // skip the copy when src2 already lives in tmp15
+      __ vmovdqu($tmp15$$XMMRegister, $src2$$XMMRegister);
+    }
+    __ vptest($tmp14$$XMMRegister, $tmp15$$XMMRegister, vector_len);  // test the copies, never the original (possibly high) registers
+    __ setb(Assembler::notZero, $dst$$Register);  // materialize the notZero condition in dst
+    __ movzbl($dst$$Register, $dst$$Register);  // widen the byte result in place
+  %}
+  ins_pipe( pipe_slow );
+%}
+
 instruct loadmask8b(vecD dst, vecD src) %{
   predicate(UseSSE >= 2 && n->as_Vector()->length() == 8 && n->bottom_type()->is_vect()->element_basic_type() == T_BYTE);
   match(Set dst (VectorLoadMask src));
diff -r 636478e1ee75 src/hotspot/cpu/x86/x86_64.ad
--- a/src/hotspot/cpu/x86/x86_64.ad     Thu Dec 13 16:40:28 2018 -0800
+++ b/src/hotspot/cpu/x86/x86_64.ad     Sat Dec 29 11:16:55 2018 +0800
@@ -4426,6 +4426,17 @@
   predicate(UseAVX == 3);  format%{%}  interface(REG_INTER);
 %}

+
+operand rymm14() %{  // VecY operand restricted to the ymm14_reg register class; used as a fixed temporary by the vptest8*avx3 instructions
+  constraint(ALLOC_IN_RC(ymm14_reg));  match(VecY);
+  predicate((UseSSE > 0) && (UseAVX <= 3));  format%{%}  interface(REG_INTER);  // NOTE(review): a 256-bit operand gated on UseSSE looks loose - confirm the intended condition
+%}
+operand rymm15() %{  // VecY operand restricted to the ymm15_reg register class; used as a fixed temporary by the vptest8*avx3 instructions
+  constraint(ALLOC_IN_RC(ymm15_reg));  match(VecY);
+  predicate((UseSSE > 0) && (UseAVX <= 3));  format%{%}  interface(REG_INTER);  // NOTE(review): a 256-bit operand gated on UseSSE looks loose - confirm the intended condition
+%}
+
+
 //----------OPERAND CLASSES----------------------------------------------------
 // Operand Classes are groups of operands that are used as to simplify
 // instruction definitions by not requiring the AD writer to specify separate


The test below reproduces the bug. Make sure to run it with -XX:UseAVX=3.

import jdk.incubator.vector.*;
import java.util.Arrays;
import java.util.Random;
import java.lang.reflect.Field;
import java.io.IOException;
import jdk.incubator.vector.Vector.Mask;
import jdk.incubator.vector.Vector.Shape;
/**
 * Reproducer for wrong {@code anyTrue()} / {@code allTrue()} mask results.
 *
 * All mask inputs are {@code false}, so every {@code anyTrue()} and
 * {@code allTrue()} call must return {@code false}. The result arrays are
 * pre-filled with {@code true} before the checked run, so any entry that is
 * still {@code true} afterwards is a wrong result.
 */
public class VectorTrueTest
{

    static Random random = new Random();
    static final IntVector.IntSpecies Species = IntVector.species(Vector.Shape.S_256_BIT);
    public static int size = 1024 * 16;
    public static int length = Species.length();
    /** Number of vector-sized steps needed to cover {@link #size} elements. */
    public static int resultSize = size / length;
    static boolean[] anyResultV = new boolean[resultSize];
    static boolean[] allResultV = new boolean[resultSize];
    static boolean[] anyInput = new boolean[size];
    static boolean[] allInput = new boolean[size];

    /**
     * Warms up {@link #vecTest}, runs it once more against poisoned result
     * arrays, and verifies every result slot.
     *
     * @throws RuntimeException if any anyTrue/allTrue result is not {@code false}
     */
    public static void main(String[] args) throws  NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException, InstantiationException {
        Arrays.fill(anyInput, false);
        Arrays.fill(allInput, false);

        // Warm-up: call repeatedly so that vecTest is likely to run as
        // JIT-compiled code, which is where the wrong instruction is emitted.
        for (int i = 0; i < 20000; i++) {
            vecTest(Species);
        }

        // Poison the results so that a slot left untouched or computed wrongly shows up as true.
        Arrays.fill(anyResultV, true);
        Arrays.fill(allResultV, true);
        vecTest(Species);

        // Check every slot, including the last one: vecTest writes exactly resultSize entries.
        for (int i = 0; i < resultSize; i++) {
            if (anyResultV[i] != false) throw new RuntimeException("Wrong anyTrue result! Should be all false, index " + i);
            if (allResultV[i] != false) throw new RuntimeException("Wrong allTrue result! Should be all false, index " + i);
        }
    }

    /**
     * Loads one mask from the start of each input array and records its
     * {@code allTrue()} / {@code anyTrue()} value once per vector-sized step
     * over {@link #size} elements.
     *
     * @param Speciesint species used to build the masks and to size the steps
     */
    static void vecTest(IntVector.IntSpecies Speciesint ) {
        Mask maskAny = Speciesint.maskFromArray(anyInput, 0);
        Mask maskAll = Speciesint.maskFromArray(allInput, 0);
        int j = 0;
        // The mask queries stay inside the loop on purpose: they are the code under test.
        for (int i = 0; i + (Speciesint.length()) <= size; i += Speciesint.length()) {
            allResultV[j] = maskAll.allTrue();
            anyResultV[j] = maskAny.anyTrue();
            j++;
        }
    }
}

Regards,
Zhuoren






More information about the panama-dev mailing list