Vector performance issue.

Jake Luciani jake at apache.org
Mon Sep 18 15:45:28 UTC 2023


Looking at the code I wonder if it's this extra branch?

@ForceInline
final
float reduceLanesTemplate(VectorOperators.Associative op,
                           Class<? extends VectorMask<Float>> maskClass,
                           VectorMask<Float> m) {
    m.check(maskClass, this);
    if (op == FIRST_NONZERO) {
        // FIXME:  The JIT should handle this.
        FloatVector v = broadcast((float) 0).blend(this, m);
        return v.reduceLanesTemplate(op);
    }
    int opc = opCode(op);
    return fromBits(VectorSupport.reductionCoerced(
        opc, getClass(), maskClass, float.class, length(),
        this, m,
        REDUCE_IMPL.find(op, opc, FloatVector::reductionOperations)));
}
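
As a sanity check that is independent of the Vector API, the arithmetic in the
quoted benchmark can be reproduced in scalar code. This is only a sketch: the
class name ScalarBf16Dot and the helper bFloat16ToFloat32 are hypothetical and
appear in neither the benchmark nor the JDK; float32ToBFloat16 is copied from
the benchmark. It assumes the widening step is equivalent to the benchmark's
zero-extend, shift left by 16, reinterpret-as-float sequence.

```java
final class ScalarBf16Dot {
    // Same truncating conversion as float32ToBFloat16 in the benchmark:
    // keep the upper 16 bits of the IEEE-754 float bit pattern.
    static short float32ToBFloat16(float f) {
        return (short) (Float.floatToIntBits(f) >> 16);
    }

    // Hypothetical helper: zero-extend the 16 bits, shift them back into the
    // upper half of an int, and reinterpret that int as a float.
    static float bFloat16ToFloat32(short s) {
        return Float.intBitsToFloat((s & 0xFFFF) << 16);
    }

    // Scalar dot product over two bfloat16 arrays of equal length.
    static float dot(short[] a, short[] b) {
        float acc = 0f;
        for (int i = 0; i < a.length; i++) {
            acc += bFloat16ToFloat32(a[i]) * bFloat16ToFloat32(b[i]);
        }
        return acc;
    }

    public static void main(String[] args) {
        // All of these values are exactly representable in bfloat16.
        short[] a = { float32ToBFloat16(1.0f), float32ToBFloat16(2.0f), float32ToBFloat16(0.5f) };
        short[] b = { float32ToBFloat16(4.0f), float32ToBFloat16(0.25f), float32ToBFloat16(8.0f) };
        System.out.println(dot(a, b)); // prints 8.5
    }
}
```

The scalar loop adds the products in a different order than a lane-wise
accumulator followed by a reduction, so on random inputs the two results may
differ in the last few bits. Compare them with a tolerance, not with ==.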

On Mon, Sep 18, 2023 at 11:11 AM Andrii Lomakin
<lomakin.andrey at gmail.com> wrote:
>
> Hi,
> I have the same problem when calculating Euclidean distance in my project.
> I am writing just to confirm that this is not an isolated case: I got the same result when profiling.
>
> On Sat, Sep 16, 2023 at 9:50 PM Jake Luciani <jake at apache.org> wrote:
>>
>> Hi,
>>
>> I've been struggling with a problem recently using the vector api.
>> It appears that reduceLanes is not using the intrinsic.
>>
>>           ns  percent  samples  top
>>   ----------  -------  -------  ---
>>  13240151836   88.21%     1324  jdk.incubator.vector.FloatVector.reduceLanesTemplate
>>   1349991099    8.99%      135  jdk.incubator.vector.FloatVector.lanewiseTemplate
>>
>> I've tested OpenJDK 20 and 21, and my machine has AVX512.
>>
>> When I enable PrintIntrinsics I see the following (among others):
>>
>>   ** missing constant: opr=RShiftI vclass=ConP etype=ConP vlen=ConI
>>
>> I've included a JMH benchmark that reproduces the issue.
>>
>> -Jake
>>
>> import jdk.incubator.vector.FloatVector;
>> import jdk.incubator.vector.IntVector;
>> import jdk.incubator.vector.ShortVector;
>> import jdk.incubator.vector.VectorOperators;
>> import org.openjdk.jmh.annotations.*;
>> import org.openjdk.jmh.infra.Blackhole;
>>
>> import java.util.concurrent.ThreadLocalRandom;
>> import java.util.concurrent.TimeUnit;
>>
>>
>> @Warmup(iterations = 1, time = 5)
>> @Measurement(iterations = 3, time = 5)
>> @Fork(warmups = 1, value = 1, jvmArgsPrepend = {
>>         "--add-modules=jdk.incubator.vector",
>>         "--enable-preview"})
>> public class VectorPerfBench
>> {
>>     private static final int SIZE = 8192;
>>     private static final IntVector BF16_BYTE_SHIFT =
>>             IntVector.broadcast(IntVector.SPECIES_512, 16);
>>
>>     public static short float32ToBFloat16(float f) {
>>         return (short) (Float.floatToIntBits(f) >> 16);
>>     }
>>     @State(Scope.Benchmark)
>>     public static class Parameters {
>>         final short[] s1 = new short[SIZE];
>>         final short[] s2 = new short[SIZE];
>>
>>         public Parameters() {
>>             for (int i = 0; i < SIZE; i++) {
>>                 s1[i] = float32ToBFloat16(ThreadLocalRandom.current().nextFloat());
>>                 s2[i] = float32ToBFloat16(ThreadLocalRandom.current().nextFloat());
>>             }
>>         }
>>     }
>>
>>     @Benchmark
>>     @OutputTimeUnit(TimeUnit.MILLISECONDS)
>>     @BenchmarkMode(Mode.Throughput)
>>     public void bfloatDot(Parameters p, Blackhole bh) {
>>         FloatVector acc = FloatVector.zero(FloatVector.SPECIES_512);
>>         for (int i = 0; i < SIZE; i += FloatVector.SPECIES_512.length()) {
>>
>>             var f1 = ShortVector.fromArray(ShortVector.SPECIES_256, p.s1, i)
>>                     .convertShape(VectorOperators.ZERO_EXTEND_S2I, IntVector.SPECIES_512, 0)
>>                     .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
>>                     .reinterpretAsFloats();
>>
>>             var f2 = ShortVector.fromArray(ShortVector.SPECIES_256, p.s2, i)
>>                     .convertShape(VectorOperators.ZERO_EXTEND_S2I, IntVector.SPECIES_512, 0)
>>                     .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
>>                     .reinterpretAsFloats();
>>
>>             acc = acc.add(f1.mul(f2));
>>         }
>>
>>         bh.consume(acc.reduceLanes(VectorOperators.ADD));
>>     }
>>
>>     public static void main(String[] args) throws Exception {
>>         org.openjdk.jmh.Main.main(args);
>>     }
>> }
>
>
>
> --
> Best regards,
> Andrii Lomakin.
>


More information about the panama-dev mailing list