BLAS and Vector API
Paul Sandoz
paul.sandoz at oracle.com
Mon Jan 4 23:23:48 UTC 2021
It looks like the attachment was removed by the mailing list bot, so I am adding it inline for the benefit of those not explicitly on the To/CC list.
Paul.
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorSpecies;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.CompilerControl;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
/**
 * JMH benchmark for a hand-unrolled 6x16 single-precision matrix-multiply
 * micro-kernel written with the incubating Vector API ({@code jdk.incubator.vector}).
 *
 * <p>Reports the average time, in nanoseconds, of one call to {@link #kernel_6x16}.
 * The forked JVM is started with {@code --add-modules=jdk.incubator.vector}
 * (incubator modules are not resolved by default) and {@code --enable-preview}.
 */
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread) // Per-thread state: setup and the timed calls all run on that one thread
@Warmup(iterations = 10, time = 1)
@Measurement(iterations = 10, time = 1)
@Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector", "--enable-preview"})
public class TestKernel {
// Shared (inner) dimension of the product: number of rank-1 updates per kernel call.
int k_iter;
// a: packed A panel, k_iter groups of 6 floats (one 6-element column of A per step).
// b: packed B panel, k_iter groups of 16 floats (one 16-element row of B per step).
// c: output buffer of 6 rows x 2048 floats; the kernel writes only the first 16 floats of each row.
float[] a, b, c;
/**
 * Allocates the operand panels for {@code k_iter = 100}, fills A and B with
 * pseudo-random values in [0, 1), and allocates a zero-initialized 6 x 2048
 * output buffer.
 */
@Setup
public void setup() {
k_iter = 100;
a = new float[k_iter * 6];
fill(a);
b = new float[k_iter * 16];
fill(b);
c = new float[2048 * 6];
}
/**
 * Benchmark body: one kernel call that computes the 6x16 tile of C at offset 0,
 * with a row stride of 2048 and contiguous columns.
 *
 * <p>{@code DONT_INLINE} keeps this method from being inlined into the JMH
 * harness. The output array is returned, presumably so that JMH consumes it
 * and the kernel's work cannot be eliminated as dead code.
 *
 * @return the output array {@code c}
 */
@Benchmark
@CompilerControl(value = CompilerControl.Mode.DONT_INLINE)
public float[] vector() {
kernel_6x16(k_iter, a, 0, b, 0, c, 0, 2048, 1);
return c;
}
// 256-bit float species: 8 lanes per vector. The kernel hard-codes "+ 8" offsets
// for the right half of each 16-wide row, so it relies on this lane count.
static final VectorSpecies<Float> SPECIES_256 = FloatVector.SPECIES_256;
/**
 * Computes one 6x16 tile of the product C = A * B using fused multiply-adds.
 *
 * <p>For every row {@code r} in [0, 6) and column {@code j} in [0, 16):
 * <pre>
 * c[c_offset + r * rs_c + j] =
 *     sum over k in [0, k_iter) of a[a_offset + 6 * k + r] * b[b_offset + 16 * k + j]
 * </pre>
 * A is therefore packed as {@code k_iter} consecutive 6-float columns, and B as
 * {@code k_iter} consecutive 16-float rows. Each step {@code k} is a rank-1
 * update: column k of A times row k of B, added into twelve 8-lane accumulators
 * (the left and right halves of the six output rows).
 *
 * <p>The tile is overwritten. The accumulators start at zero and the previous
 * contents of C are never read. NOTE(review): if C += A * B semantics are
 * intended, the existing tile would need to be loaded and added. TODO confirm.
 * If {@code k_iter <= 0}, the loop body never runs and the tile is set to zero.
 *
 * @param k_iter   number of rank-1 updates (shared dimension of A and B)
 * @param a        packed A panel
 * @param a_offset index of the first A element used
 * @param b        packed B panel
 * @param b_offset index of the first B element used
 * @param c        output matrix storage
 * @param c_offset index of element (0, 0) of the output tile
 * @param rs_c     row stride of C (distance between the starts of consecutive output rows)
 * @param cs_c     column stride of C; must be 1 (row-major, contiguous within a row)
 * @throws IllegalArgumentException  if {@code cs_c != 1}
 * @throws IndexOutOfBoundsException if an index accessed in {@code a}, {@code b} or {@code c} is out of range
 */
static void kernel_6x16(
int k_iter,
float[] a,
int a_offset,
float[] b,
int b_offset,
float[] c,
int c_offset,
int rs_c,
int cs_c) {
// Only row-major layout with contiguous columns (cs_c == 1) is supported:
// the vector loads and stores below read and write 8 adjacent elements of a row.
if (cs_c != 1) {
throw new IllegalArgumentException("Row major storage only");
}
// Twelve accumulators, all starting at zero: output row r is held as
// AxB_rL (columns 0-7) and AxB_rR (columns 8-15).
FloatVector AxB_0L, AxB_0R;
AxB_0L = FloatVector.zero(SPECIES_256);
AxB_0R = FloatVector.zero(SPECIES_256);
FloatVector AxB_1L, AxB_1R;
AxB_1L = FloatVector.zero(SPECIES_256);
AxB_1R = FloatVector.zero(SPECIES_256);
FloatVector AxB_2L, AxB_2R;
AxB_2L = FloatVector.zero(SPECIES_256);
AxB_2R = FloatVector.zero(SPECIES_256);
FloatVector AxB_3L, AxB_3R;
AxB_3L = FloatVector.zero(SPECIES_256);
AxB_3R = FloatVector.zero(SPECIES_256);
FloatVector AxB_4L, AxB_4R;
AxB_4L = FloatVector.zero(SPECIES_256);
AxB_4R = FloatVector.zero(SPECIES_256);
FloatVector AxB_5L, AxB_5R;
AxB_5L = FloatVector.zero(SPECIES_256);
AxB_5R = FloatVector.zero(SPECIES_256);
// Each iteration consumes one 16-float row of B (loaded as two 8-lane halves)
// and one 6-float column of A; b_iter and a_iter advance in lockstep.
int b_upper = b_offset + k_iter * 16;
for (int b_iter = b_offset, a_iter = a_offset; b_iter < b_upper; b_iter += 16, a_iter += 6) {
// Equivalent indexed form: for k in [0, k_iter),
//   b_iter = b_offset + k * 16 and a_iter = a_offset + k * 6.
FloatVector B_0L = FloatVector.fromArray(SPECIES_256, b, b_iter);
FloatVector B_0R = FloatVector.fromArray(SPECIES_256, b, b_iter + 8);
// Only two broadcast vectors (vbr_a0, vbr_a1) are live at once, reused for the
// six A values. Together with 12 accumulators and 2 B halves that makes 16 live
// vectors, presumably sized to the 16 YMM registers of AVX2.
FloatVector vbr_a0 = FloatVector.broadcast(SPECIES_256,
a[a_iter + 0]);
FloatVector vbr_a1 = FloatVector.broadcast(SPECIES_256,
a[a_iter + 1]);
// Row 0: acc += B_row * a[r] (fma computes this * v1 + v2 per lane)
AxB_0L = B_0L.fma(vbr_a0, AxB_0L);
AxB_0R = B_0R.fma(vbr_a0, AxB_0R);
// Row 1:
AxB_1L = B_0L.fma(vbr_a1, AxB_1L);
AxB_1R = B_0R.fma(vbr_a1, AxB_1R);
vbr_a0 = FloatVector.broadcast(SPECIES_256,
a[a_iter + 2]);
vbr_a1 = FloatVector.broadcast(SPECIES_256,
a[a_iter + 3]);
// Row 2:
AxB_2L = B_0L.fma(vbr_a0, AxB_2L);
AxB_2R = B_0R.fma(vbr_a0, AxB_2R);
// Row 3:
AxB_3L = B_0L.fma(vbr_a1, AxB_3L);
AxB_3R = B_0R.fma(vbr_a1, AxB_3R);
vbr_a0 = FloatVector.broadcast(SPECIES_256,
a[a_iter + 4]);
vbr_a1 = FloatVector.broadcast(SPECIES_256,
a[a_iter + 5]);
// Row 4:
AxB_4L = B_0L.fma(vbr_a0, AxB_4L);
AxB_4R = B_0R.fma(vbr_a0, AxB_4R);
// Row 5:
AxB_5L = B_0L.fma(vbr_a1, AxB_5L);
AxB_5R = B_0R.fma(vbr_a1, AxB_5R);
}
// Store the tile, overwriting C: for each row, the left half goes to columns 0-7
// and the right half to columns 8-15; consecutive rows are rs_c elements apart.
int c_row_iter = c_offset;
AxB_0L.intoArray(c, c_row_iter);
AxB_0R.intoArray(c, c_row_iter + 8);
c_row_iter += rs_c;
AxB_1L.intoArray(c, c_row_iter);
AxB_1R.intoArray(c, c_row_iter + 8);
c_row_iter += rs_c;
AxB_2L.intoArray(c, c_row_iter);
AxB_2R.intoArray(c, c_row_iter + 8);
c_row_iter += rs_c;
AxB_3L.intoArray(c, c_row_iter);
AxB_3R.intoArray(c, c_row_iter + 8);
c_row_iter += rs_c;
AxB_4L.intoArray(c, c_row_iter);
AxB_4R.intoArray(c, c_row_iter + 8);
c_row_iter += rs_c;
AxB_5L.intoArray(c, c_row_iter);
AxB_5R.intoArray(c, c_row_iter + 8);
}
/**
 * Fills {@code a} in place with pseudo-random floats in [0, 1) drawn from
 * {@link ThreadLocalRandom}.
 *
 * @param a the array to overwrite
 */
static void fill(float[] a) {
for (int i = 0; i < a.length; i++) {
a[i] = ThreadLocalRandom.current().nextFloat();
}
}
}
> On Dec 23, 2020, at 9:33 AM, Paul Sandoz <paul.sandoz at oracle.com> wrote:
>
> Hi Ludovic,
>
> Thanks for sharing, I shall look at this in more detail.
>
> In the interim, you may be interested in the following paper
>
> https://arxiv.org/pdf/1904.05717v1.pdf
>
> Attaining peak performance is as much about data movement as it is about the kernel.
>
> From this paper and BLIS (FLAME), I wrote a kernel using the Vector API, with some help from a colleague, supporting C += A * B, where C is updated in place (counterintuitively, to me at least initially, the columns of A are multiplied by the rows of B).
>
> Attached is a kernel optimized for AVX2. The code gen is not bad. Register allocation is good. Bounds checking could be improved. Feel free to use it as you see fit.
>
> A useful experiment would be to wrap this kernel around the higher loops (with data movement) and see how the whole implementation behaves.
>
> —
>
> On alignment: the problem is that higher alignment of Java arrays is not stable, because the array could be moved by the GC. Further, it gets more difficult when there are two or more arrays at different alignments. The best way to fix this, IMHO, is to support the Panama Memory API using native memory that is explicitly aligned.
>
> Paul.
More information about the panama-dev mailing list