Comments / metadata in assembly listings don't make sense for code vectorized using Vector API

Piotr Tarsa piotr.tarsa at gmail.com
Wed Dec 30 14:17:50 UTC 2020


Hi all,

Thanks for creating Project Panama! It looks promising. However, I've
tried to vectorize some code and got somewhat disappointing
results, so I wanted to look at the generated machine code to
see if it looks optimal. I've attached hsdis to the JVM and enabled
assembly printing, but the output doesn't make sense to me, i.e. the
instructions and comments / metadata don't seem to match. I may be
wrong, as I've very rarely looked at assembly listings produced by the JVM.

Performance:
As a baseline I took
https://benchmarksgame-team.pages.debian.net/benchmarksgame/program/mandelbrot-java-2.html
which takes about 3.05s to finish on my system. After vectorization
I've managed to get timings of about 1.80s. That's quite disappointing
to me, as I have a Haswell machine with AVX2, fast L1
caches, etc. I've tested on a recent JDK 16 EA build from
http://jdk.java.net/16/
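As a rough yardstick (assuming the baseline is essentially scalar code, and ignoring threading, memory and instruction-latency effects), the observed speedup can be compared with the number of double-precision lanes in a 256-bit AVX2 register:

```latex
\text{observed speedup} = \frac{3.05\,\mathrm{s}}{1.80\,\mathrm{s}} \approx 1.69,
\qquad
\text{double lanes per AVX2 register} = \frac{256\ \text{bits}}{64\ \text{bits}} = 4
```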

Link to the code and assembly listing:
https://gist.github.com/tarsa/7a9c80bb84c2dcd807be9cd16a655ee0
I'll also paste the source code at the end of this mail.

Here is an excerpt of what I see in the assembly listing:

0x00007f0e208b8ab9:   cmp    r13d,0x7fffffc0
0x00007f0e208b8ac0:   jg     0x00007f0e208b932c
0x00007f0e208b8ac6:   vmulpd ymm0,ymm6,ymm4
0x00007f0e208b8aca:   vsubpd ymm1,ymm4,ymm4
0x00007f0e208b8ace:   vmovdqu YMMWORD PTR [rsp+0xc0],ymm1
0x00007f0e208b8ad7:   vmulpd ymm0,ymm0,ymm4               ;*synchronization entry
                                                          ; - jdk.internal.vm.vector.VectorSupport$VectorPayload::getPayload@-1 (line 101)
                                                          ; - jdk.incubator.vector.Double256Vector$Double256Mask::getBits@1 (line 557)
                                                          ; - jdk.incubator.vector.AbstractMask::toLong@24 (line 77)
                                                          ; - mandelbrot_simd_1::computeChunksVector@228 (line 187)
0x00007f0e208b8adb:   vaddpd ymm0,ymm0,ymm2               ;*checkcast {reexecute=0 rethrow=0 return_oop=0}
                                                          ; - jdk.incubator.vector.DoubleVector::fromArray0Template@34 (line 3119)
                                                          ; - jdk.incubator.vector.Double256Vector::fromArray0@3 (line 777)
                                                          ; - jdk.incubator.vector.DoubleVector::fromArray@24 (line 2564)
                                                          ; - mandelbrot_simd_1::computeChunksVector@95 (line 169)
0x00007f0e208b8adf:   vmovdqu YMMWORD PTR [rsp+0xe0],ymm0
0x00007f0e208b8ae8:   vmulpd ymm0,ymm0,ymm0
0x00007f0e208b8aec:   vmovdqu YMMWORD PTR [rsp+0x100],ymm0

How does vmulpd relate to a synchronization entry and
AbstractMask::toLong? It seems way off to me, though there may be
some trick to understanding it. Could you give me some guidelines on how
to interpret this? Do the comments describe the lines below or above
them?

Regards,
Piotr

mandelbrot_simd_1.java source code:
import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorSpecies;

import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

public class mandelbrot_simd_1 {
    private static final VectorSpecies<Double> SPECIES =
            DoubleVector.SPECIES_PREFERRED.length() <= 8 ?
                    DoubleVector.SPECIES_PREFERRED : DoubleVector.SPECIES_512;

    private static final int LANES = SPECIES.length();

    private static final int LANES_LOG = Integer.numberOfTrailingZeros(LANES);

    public static void main(String[] args) throws IOException {
        if ((LANES > 8) || (LANES != (1 << LANES_LOG))) {
            var errorMsg = "LANES must be a power of two and at most 8. " +
                    "Change SPECIES in the source code.";
            throw new RuntimeException(errorMsg);
        }
        var sideLen = Integer.parseInt(args[0]);
        try (var out = new BufferedOutputStream(makeOut1())) {
            out.write(String.format("P4\n%d %d\n", sideLen, sideLen).getBytes());
            computeAndOutputRows(out, sideLen);
        }
    }

    // the version that avoids mixing up output with JVM diagnostic messages
    private static OutputStream makeOut1() throws IOException {
        return Files.newOutputStream(Path.of("mandelbrot_simd_1.pbm"));
    }

    // the version that is compatible with benchmark requirements
    @SuppressWarnings("unused")
    private static OutputStream makeOut2() {
        return System.out;
    }

    private static void computeAndOutputRows(OutputStream out, int sideLen) {
        var poolFactor = 1000000 / sideLen;
        if (poolFactor < 10) {
            throw new RuntimeException("Too small poolFactor");
        }
        var numCpus = Runtime.getRuntime().availableProcessors();
        var rowsPerBatch = numCpus * poolFactor;
        var fac = 2.0 / sideLen;
        var aCr = IntStream.range(0, sideLen).parallel()
                .mapToDouble(x -> x * fac - 1.5).toArray();
        var bitsReversalMapping = computeBitsReversalMapping();
        var rowsPools = new byte[2][rowsPerBatch][(sideLen + 7) / 8];
        var rowsChunksPools = new long[2][rowsPerBatch][sideLen / 64];
        var batchSizes = new int[2];
        var batchCountDowns = new CountDownLatch[2];
        var computeEc = Executors.newWorkStealingPool(numCpus);
        var masterThread = new Thread(() -> {
            var rowsToProcess = sideLen;
            var nextBatchStart = 0;
            batchSizes[0] = 0;
            batchCountDowns[0] = new CountDownLatch(0);
            for (var poolId = 0; rowsToProcess > 0; poolId ^= 1) {
                while (batchCountDowns[poolId].getCount() != 0) {
                    try {
                        batchCountDowns[poolId].await();
                    } catch (InterruptedException ignored) {
                    }
                }
                batchCountDowns[poolId] = null;

                var nextBatchSize =
                        Math.min(sideLen - nextBatchStart, rowsPerBatch);
                var nextPoolId = poolId ^ 1;
                batchSizes[nextPoolId] = nextBatchSize;
                batchCountDowns[nextPoolId] = new CountDownLatch(nextBatchSize);
                sendTasks(fac, aCr, bitsReversalMapping,
                        rowsPools[nextPoolId], rowsChunksPools[nextPoolId],
                        nextBatchStart, nextBatchSize,
                        batchCountDowns[nextPoolId], computeEc);
                nextBatchStart += nextBatchSize;

                var batchSize = batchSizes[poolId];
                try {
                    for (var rowIdx = 0; rowIdx < batchSize; rowIdx++) {
                        out.write(rowsPools[poolId][rowIdx]);
                    }
                    out.flush();
                } catch (IOException e) {
                    e.printStackTrace();
                    System.exit(-1);
                }
                rowsToProcess -= batchSize;
            }

            computeEc.shutdown();
        });
        masterThread.start();
        while (masterThread.isAlive() || !computeEc.isTerminated()) {
            try {
                @SuppressWarnings("unused")
                var ignored = computeEc.awaitTermination(1, TimeUnit.DAYS);
                masterThread.join();
            } catch (InterruptedException ignored) {
            }
        }
    }

    private static void sendTasks(double fac, double[] aCr,
                                  byte[] bitsReversalMapping,
                                  byte[][] rows, long[][] rowsChunks,
                                  int batchStart, int batchSize,
                                  CountDownLatch poolsActiveWorkersCount,
                                  ExecutorService computeEc) {
        for (var i = 0; i < batchSize; i++) {
            var indexInBatch = i;
            var y = batchStart + i;
            var Ci = y * fac - 1.0;
            computeEc.submit(() -> {
                try {
                    computeRow(Ci, aCr, bitsReversalMapping,
                            rows[indexInBatch], rowsChunks[indexInBatch]);
                    poolsActiveWorkersCount.countDown();
                } catch (Exception e) {
                    e.printStackTrace();
                    System.exit(-1);
                }
            });
        }
    }

    private static byte[] computeBitsReversalMapping() {
        var bitsReversalMapping = new byte[256];
        for (var i = 0; i < 256; i++) {
            bitsReversalMapping[i] = (byte) (Integer.reverse(i) >>> 24);
        }
        return bitsReversalMapping;
    }

    private static void computeRow(double Ci, double[] aCr,
                                   byte[] bitsReversalMapping,
                                   byte[] row, long[] rowChunks) {
        computeChunksVector(Ci, aCr, rowChunks);
        transferRowFlags(rowChunks, row, bitsReversalMapping);
        computeRemainderScalar(aCr, row, Ci);
    }

    private static void computeChunksVector(double Ci, double[] aCr,
                                            long[] rowChunks) {
        var sideLen = aCr.length;
        var vCi = DoubleVector.broadcast(SPECIES, Ci);
        var vZeroes = DoubleVector.zero(SPECIES);
        var vTwos = DoubleVector.broadcast(SPECIES, 2.0);
        var vFours = DoubleVector.broadcast(SPECIES, 4.0);
        var zeroMask = VectorMask.fromLong(SPECIES, 0);
        // (1 << 6) = 64 = length of long in bits; round sideLen down to a
        // multiple of 64 by clearing its low 6 bits
        for (var xBase = 0; xBase < (sideLen & ~((1 << 6) - 1)); xBase += (1 << 6)) {
            var cmpFlags = 0L;
            for (var xInc = 0; xInc < (1 << 6); xInc += LANES) {
                var vZr = vZeroes;
                var vZi = vZeroes;
                var vCr = DoubleVector.fromArray(SPECIES, aCr, xBase + xInc);
                var vZrN = vZeroes;
                var vZiN = vZeroes;
                var cmpMask = zeroMask;
                for (var outer = 0; outer < 10; outer++) {
                    for (var inner = 0; inner < 5; inner++) {
                        vZi = vTwos.mul(vZr).mul(vZi).add(vCi);
                        vZr = vZrN.sub(vZiN).add(vCr);
                        vZiN = vZi.mul(vZi);
                        vZrN = vZr.mul(vZr);
                    }
                    cmpMask = cmpMask.or(vFours.lt(vZiN.add(vZrN)));
                    // in Rust version this works fine, so where's the bug then?
                    // cmpMask = vFours.lt(vZiN.add(vZrN));
                    if (cmpMask.allTrue()) {
                        break;
                    }
                }
                cmpFlags |= cmpMask.toLong() << xInc;
            }
            rowChunks[xBase >> 6] = cmpFlags;
        }
    }

    private static void transferRowFlags(long[] rowChunks, byte[] row,
                                         byte[] bitsReversalMapping) {
        for (var i = 0; i < rowChunks.length; i++) {
            var group = ~rowChunks[i];
            row[i * 8 + 7] = bitsReversalMapping[0xff & (byte) (group >>> 56)];
            row[i * 8 + 6] = bitsReversalMapping[0xff & (byte) (group >>> 48)];
            row[i * 8 + 5] = bitsReversalMapping[0xff & (byte) (group >>> 40)];
            row[i * 8 + 4] = bitsReversalMapping[0xff & (byte) (group >>> 32)];
            row[i * 8 + 3] = bitsReversalMapping[0xff & (byte) (group >>> 24)];
            row[i * 8 + 2] = bitsReversalMapping[0xff & (byte) (group >>> 16)];
            row[i * 8 + 1] = bitsReversalMapping[0xff & (byte) (group >>> 8)];
            row[i * 8] = bitsReversalMapping[0xff & (byte) group];
        }
    }

    private static void computeRemainderScalar(double[] aCr, byte[] row,
                                               double Ci) {
        var sideLen = aCr.length;
        var bits = 0;
        // start right after the last full 64-pixel chunk handled by computeChunksVector
        for (var x = sideLen & ~((1 << 6) - 1); x < sideLen; x++) {
            var Zr = 0.0;
            var Zi = 0.0;
            var Cr = aCr[x];
            var i = 50;
            var ZrN = 0.0;
            var ZiN = 0.0;
            do {
                Zi = 2.0 * Zr * Zi + Ci;
                Zr = ZrN - ZiN + Cr;
                ZiN = Zi * Zi;
                ZrN = Zr * Zr;
            } while (ZiN + ZrN <= 4.0 && --i > 0);
            bits <<= 1;
            bits += i == 0 ? 1 : 0;
            if (x % 8 == 7) {
                row[x / 8] = (byte) bits;
                bits = 0;
            }
        }
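        if (sideLen % 8 != 0) {
            // PBM (P4) packs pixels high bit first, so left-align the last partial byte
            row[sideLen / 8] = (byte) (bits << (8 - sideLen % 8));
        }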
    }
}
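Regarding the source comment next to `cmpMask.or(...)`: one plausible explanation for why checking only the latest comparison misbehaves is sketched below in scalar Java, assuming each vector lane follows the same recurrence as this scalar loop. For a point that escapes early, z keeps growing, overflows to Infinity, and `Infinity - Infinity` then yields NaN; every ordered comparison with NaN is false, so `4.0 < |z|^2` evaluated only at the end reports such a point as non-escaping. ORing the flag after each block keeps it "sticky". The class `EscapeMaskDemo` and its methods are hypothetical names used only for this illustration, and the point c = 2 + 0i is chosen because its trajectory is easy to follow by hand.

```java
// Hypothetical standalone demo, not part of mandelbrot_simd_1: a scalar
// re-enactment of one lane of computeChunksVector for c = cr + ci*i.
public class EscapeMaskDemo {

    // Runs outer*inner iterations and returns |z|^2 at the end, i.e. only the
    // final state is inspected (like the commented-out non-accumulating line).
    static double finalMagnitudeSquared(double cr, double ci, int outer, int inner) {
        double zr = 0.0, zi = 0.0, zrN = 0.0, ziN = 0.0;
        for (int o = 0; o < outer; o++) {
            for (int k = 0; k < inner; k++) {
                zi = 2.0 * zr * zi + ci;
                zr = zrN - ziN + cr;
                ziN = zi * zi;
                zrN = zr * zr;
            }
        }
        return ziN + zrN;
    }

    // Same recurrence, but ORs the "escaped" flag after every inner block,
    // mirroring cmpMask = cmpMask.or(vFours.lt(vZiN.add(vZrN))).
    static boolean escapedSticky(double cr, double ci, int outer, int inner) {
        double zr = 0.0, zi = 0.0, zrN = 0.0, ziN = 0.0;
        boolean escaped = false;
        for (int o = 0; o < outer; o++) {
            for (int k = 0; k < inner; k++) {
                zi = 2.0 * zr * zi + ci;
                zr = zrN - ziN + cr;
                ziN = zi * zi;
                zrN = zr * zr;
            }
            escaped |= 4.0 < ziN + zrN;
            if (escaped) {
                break; // single-lane analogue of cmpMask.allTrue()
            }
        }
        return escaped;
    }

    public static void main(String[] args) {
        double last = finalMagnitudeSquared(2.0, 0.0, 10, 5);
        // prints: |z|^2 after 50 iterations = NaN
        System.out.println("|z|^2 after 50 iterations = " + last);
        // prints: last check only = false
        System.out.println("last check only = " + (4.0 < last));
        // prints: sticky OR = true
        System.out.println("sticky OR = " + escapedSticky(2.0, 0.0, 10, 5));
    }
}
```

Note that computeRemainderScalar is not affected in the same way, because its do/while loop stops as soon as |z|^2 exceeds 4.0.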
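The chunk loops are meant to cover sideLen rounded down to a multiple of 64 (rowChunks has sideLen / 64 entries). The small sketch below, using hypothetical helper names, contrasts an expression of the form `sideLen & ~(1 << 6)`, which clears only bit 6 (value 64), with `sideLen & ~((1 << 6) - 1)`, which clears bits 0..5 and therefore rounds a non-negative value down to a multiple of 64. With sideLen = 130, the first form yields 130, so a loop `xBase < 130` stepping by 64 would visit xBase = 128 and go past the end of both aCr and rowChunks.

```java
// Hypothetical helpers for illustration only; not part of mandelbrot_simd_1.
public class RoundDownDemo {

    // Clears only bit 6 (value 64); all other low bits survive.
    static int clearBit6(int n) {
        return n & ~(1 << 6);
    }

    // Clears bits 0..5, i.e. rounds a non-negative n down to a multiple of 64.
    static int roundDownTo64(int n) {
        return n & ~((1 << 6) - 1);
    }

    public static void main(String[] args) {
        // prints:
        // sideLen=130  clearBit6=130  roundDownTo64=128
        // sideLen=200  clearBit6=136  roundDownTo64=192
        // sideLen=16000  clearBit6=16000  roundDownTo64=16000
        for (int sideLen : new int[] {130, 200, 16000}) {
            System.out.printf("sideLen=%d  clearBit6=%d  roundDownTo64=%d%n",
                    sideLen, clearBit6(sideLen), roundDownTo64(sideLen));
        }
    }
}
```

The benchmark size of 16000 happens to be a multiple of 64, which is why both forms agree there.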

