[vector] Perf difference between vector-unstable and vectorIntrinsics
Paul Sandoz
paul.sandoz at oracle.com
Tue Jul 7 17:08:04 UTC 2020
Maybe :-) It’s hard to answer the question with certainty. Feedback will help, especially while the API is incubating.
However, the Vector API is designed to be a WYSIWYG API (once C2 gets at it), assuming you are on the appropriate hardware.
We still have some work to do, as is evident from Vladimir's inlining patch; further work is required to elide bounds checks, optimize mask operations on AVX-512 hardware, and vectorize the transcendental functions (in the vector-unstable branch).
At the moment, keeping the vector algorithm within one method and avoiding profile pollution should yield more stable results.
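To illustrate what "profile pollution" means here, below is a minimal plain-Java sketch (the names Op, sharedSum, and sumSquares are hypothetical, not from this thread; the same reasoning applies to a vector kernel whose shared helper is called with varying species or lambdas):

```java
// Hypothetical sketch of profile pollution. sharedSum's op.apply call
// site sees several distinct lambda classes, so its type profile turns
// polymorphic and the JIT may stop inlining through it; sumSquares
// keeps the whole loop in one method with no virtual call inside, so
// its profile stays clean and optimization is more predictable.
public class ProfilePollution {
    interface Op { float apply(float x); }

    // Shared helper: every distinct Op implementation passed here
    // pollutes the profile of the op.apply call site.
    static float sharedSum(float[] a, Op op) {
        float s = 0;
        for (float x : a) s += op.apply(x);
        return s;
    }

    // Self-contained kernel: nothing to pollute.
    static float sumSquares(float[] a) {
        float s = 0;
        for (float x : a) s += x * x;
        return s;
    }

    public static void main(String[] args) {
        float[] a = {1f, 2f, 3f};
        float viaShared = sharedSum(a, x -> x * x);
        // Two more receivers at the same call site, polluting its profile.
        sharedSum(a, x -> x * x * x);
        sharedSum(a, x -> -x);
        System.out.println(viaShared == sumSquares(a)); // prints true
    }
}
```

The results are identical; only the JIT's ability to specialize differs, which is why keeping the whole vector loop in one method tends to give more stable numbers.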
—
Since you presented the example as a benchmark, I thought it more reliable to convert it to JMH. However, I did not copy the code exactly (I removed duplication), which might yield different results.
Paul.
> On Jul 7, 2020, at 5:38 AM, Wang Zhuo(Zhuoren) <zhuoren.wz at alibaba-inc.com> wrote:
>
> Thank you, Paul and Vladimir Ivanov.
> There is another question. When I worked with my colleagues to optimize their applications with the Vector API, they were usually concerned about performance regressions.
> Is it possible that we see a good perf improvement in JMH but a perf regression in real applications?
>
>
> Regards,
> Zhuoren
>
> ------------------------------------------------------------------
> From:Paul Sandoz <paul.sandoz at oracle.com>
> Sent At:2020 Jul. 7 (Tue.) 08:10
> To:Vladimir Ivanov <vladimir.x.ivanov at oracle.com>
> Cc:Sandler <zhuoren.wz at alibaba-inc.com>; panama-dev <panama-dev at openjdk.java.net>
> Subject:Re: [vector] Perf difference between vector-unstable and vectorIntrinsics
>
> I copied the source and transformed it into a JMH benchmark, see below. I don’t observe any difference between the two branches (JMH results in comments at the end of the source).
>
> For the SIMD implementation, I removed what appeared to be redundant mul/add operations performed in addition to the fma operations.
>
> And, I could not resist the temptation to use text blocks for the B64 data :-)
>
> I recommend using JMH rather than rolling your own quick measurements. It will provide a more reliable base from which to investigate issues with inlining and code generation (using perfasm on Linux).
>
> Paul.
>
>
> import jdk.incubator.vector.FloatVector;
> import jdk.incubator.vector.IntVector;
> import jdk.incubator.vector.VectorOperators;
> import jdk.incubator.vector.VectorSpecies;
> import org.openjdk.jmh.annotations.Benchmark;
> import org.openjdk.jmh.annotations.BenchmarkMode;
> import org.openjdk.jmh.annotations.Fork;
> import org.openjdk.jmh.annotations.Measurement;
> import org.openjdk.jmh.annotations.Mode;
> import org.openjdk.jmh.annotations.OutputTimeUnit;
> import org.openjdk.jmh.annotations.Param;
> import org.openjdk.jmh.annotations.Scope;
> import org.openjdk.jmh.annotations.Setup;
> import org.openjdk.jmh.annotations.State;
> import org.openjdk.jmh.annotations.Warmup;
>
> import java.nio.ByteBuffer;
> import java.nio.ByteOrder;
> import java.nio.FloatBuffer;
> import java.util.Base64;
> import java.util.concurrent.TimeUnit;
> import java.util.function.IntFunction;
>
> @BenchmarkMode(Mode.Throughput)
> @OutputTimeUnit(TimeUnit.MILLISECONDS)
> @State(Scope.Benchmark)
> @Warmup(iterations = 10, time = 1)
> @Measurement(iterations = 10, time = 1)
> @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
> public class VectorSimilarity {
>     static final VectorSpecies<Float> SPECIES256 = FloatVector.SPECIES_256;
>     static final VectorSpecies<Float> SPECIES512 = FloatVector.SPECIES_512;
>
>     private static String X1 = """
> L5GSwXhHpEH05mNBHnmcQMTw3EBnagFCW1DGQHe/nUFO1B1BlJOpwCBJ9j+RkY1B
> zqKeQSglN0Gy7krB5CSfQFzxB8Djn5nB2KNFwKcSRMGYzRQ7qMGWQZ0FF0FTceDA
> IKjxv/zhdkHFZMHB6hU4QZbo2cCAryRB+7OOQCxbfEHRtBlBxPG6P0BYSD+Pgz9B
> qzOLv/nVO8C9x5/BQOY/wTTIx0GfW1BBGv2lQQwdDcGCqBfB12t/QKUBoEEejIXB
> PN9kQWsFbEGsGcnBkqJkwKhLgr/IQZxAelAWQfcYpcFQv0HBeiGCQWExhEDrKAnB
> pAwBQV4bVcFpGNjAyDsNQVOc+0CSc4nBgG/ZQQGRccEXts9BKhYzQNK5+MAlU0DB
> zPGWwPGRCcEZC5/ADxOcv7lUkEBomM5BuqKiwV2MU8HNGHDBSB84QZRSyMB8RZlB
> VFdZQXSVgcBTQQBCdWa/QBQ0qkGILUW/6NA9QQnkmsG+5PPBj0UowT6nYD9cwpjA
> S/w5wTbX2UH8Gb5AR/HUQMTNAMJ9MN9AgHoqPbbUyUFbe47BBHANQWZJBsGBuPlB
> y94EQADeXsG5eOtBnA+yQCRka8EMcGLBjuoRwb4k7sAasB5Bmk/UwaI1akErp6xB
> q5G5wNo1E8KHa7tB3IiKQTCffcHphK1BTgJzwVY3JEEip/VAlmgXQSeKCsLEABs/
> n1/xwL5u58CgQY49ahUWQoAJjj1hhqBASXrrQb6nM0H2fY+/thtbQAQobMAohvXA
> xM3xv7xyqD+MvpDBrlDiQfBvPcGA8X5AQE4SwXhGx7+uLA1AxY8xu2mVjEE7KlFB
> ArveQFNMtUD3N7DB12BbQcyH4cFhSw3Bu5VWQeTW0z9o03TBxtMlQctp/8E/lLVA
> GUtTwZsGJMKv/R5A1HKVQV6RhsC1Ji5AcXLFQJd6f0HbB+e+ZDi8wV9tQ0FwCN/B
> +A89v2DrU0Bcpc5BglTeQH5dT0HePS9Al4XPwdA6YEFlueXAbWKSQSBWzkBy2RnC
> t9Yawl9b77+xgxBC9eCqQd8f0kFoBG9BVxrkQZh2QkHNW/zBEQiawLJEocDhutTA
> 8zEYwbIvEUIO1T9BmlOTwIhbNEDhrtlAVk9BQARQaj89NQNC6usGwDfQrkBSJrlA
> ON7FQQ8FqsEEc/TAY3zeQYsqUEHV8QPBHJoYQQdn5kGyCiJBlDMYQBBNoUFrxbw/
> NlmPP3B24j6ChIdBXk2bwdxdDMFQw1rA4hybQXTchr8d9wvBuCbLQSMKmMBH4RpB
> QIXePa5DT8IjgvtBgAetQZgGgMEprc1BAOeSPJ5XpEEMa0NBgX4uwX7XIsG2Ie06
> 88iqQSpJPsCAy9LBAGHkPw==""";
>
>     private static String X2 = """
> 5R3ZwGPrxEFMKyNBLFSeQdYav0BQtDFCur7WQAgRYEGHFYC/MKZtvkiFUT+RNXfB
> VsGBP2KWSUCmAUTBIf+EQG57kMCtXo7BV1DuwLd98r+YzRQ7qKXNwBMSPUFNQffB
> PrxeQYw1t7/7JjFAKNaXP+cMSEG6GI5BuEx0wUANDMEvDqdAT9YEworQTEEiVBZB
> iMejQP7t67+iRwzB3HadQB1be0Ei5g5BMt+cQXvYTUHwZsLAuoy3QfrR6EFrIiHB
> 5X8Dwc8XbUH8Yr8/AvGEwa5GkUH3F5tAP8YJQTiDyz+gKsRAFl/rwDxJuUAPyyxB
> vg2gQU6bjMEPEa7Bz6wYQpQy7MDF5LvB8HP+QCJdicHQDjpC6RpWQcGeY8FMK6vB
> oeUjQcPYmUG2QmRBBI0nwScESsGMAcxBvRmawRL2A8IByKNAgTQBQuxdDEGq8JBB
> HJWmQSBDfz8sLe9BE3gFwTdCPkHEaMxBhX8Xwe7BCcE/783Bt6EHwdpbpkHc5L/B
> CPzRwUdIQUEd/k3AoGNcQQwNmMEyuKRBtnWlwdCBAUI5Y5DBwOZYvdI+MsEu/ixB
> npMrwRtYt8FECytC6JjEQW3RHcBtfn3B+sgQQcyQKcEI5ytByvw2wPZdaUH+aqLA
> QFQ+QPi4REBF/9lBCvJNQTdlEcIAMbzBtD+hwZWufsAEjus/YRyjwR1YuMHj0ZhB
> a4w+QORAhMEq9qdB/L8JQrjhyUAJBeBAKqoIQUnAq0GsLFdBkfrvQHc1zMHH6THB
> eggSwaJIOsAawwBBDDWqPwrAlkBYDqe/maUcQabhwsFF2VBBxY8xu5aMQUFDkHVB
> KhRRwHhgWsEA5jXBlh9NQVMaT0CWlhTAroaFQRyciUHQlp7BF4trQa8unsE4TfI+
> 9XLJQDNpLcIXLZdAuX2MwShiTsFcQh5BrHMqQVI1+UBWe4fBAzi0wfe11UFAIjq9
> Y1iAQDxrTsEY6plB/JiXQfjFwkHkYGRBVNOhwCMxtEFbqZTA378WQeA/Sb+FrSXC
> qlYywtb5SsDcqlZBk1EtQZ/RREHZIxG/kcv8QekDIkHPsDXCBL4VQHN8CMGtNvvA
> C3YwweUuAkKkJCnANEtVQG9z/0DrwyTBQ9hnwWX3kMEdLB1CvIlKwQ0IO0HK1ErB
> vdRQQVpjMMCJDI/Bb4X8QYVipEGpG2nBeLGUvmBlBT7ISgRB4iGAQUunkkFDFLm/
> HNaqPzKTVkCITJG/XzlYwbj0XcGD60PBbpLwQbvrs8Az8RXB4ubxQXh/HEDtXLU/
> kONrwVBs4MGc2X1BJaHkQd0ByEAKXLJBTq7JwPPkJUGJIIRBlh57wX3FjcC2Ie06
> 0Qc6Qal5xcCfqQrCl7edQQ==""";
>
>     static float[] V1 = parseBase64ToVector(X1);
>     static float[] V2 = parseBase64ToVector(X2);
>
>     public static float[] parseArray(byte[] input) {
>         if (input == null) {
>             return null;
>         }
>         FloatBuffer src = ByteBuffer.wrap(input).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
>         FloatBuffer dst = FloatBuffer.allocate(src.capacity());
>         dst.put(src);
>         return dst.array();
>     }
>
>     public static float[] parseBase64ToVector(String vectorBase64) {
>         return parseArray(Base64.getMimeDecoder().decode(vectorBase64));
>     }
>
>     public static float getCosineSimilaritySIMDFMA(float[] queryVector, float[] vector) {
>         FloatVector vecX, vecY, vecSum, xSquareV, ySquareV;
>         vecSum = FloatVector.zero(SPECIES256);
>         xSquareV = FloatVector.zero(SPECIES256);
>         ySquareV = FloatVector.zero(SPECIES256);
>
>         int i = 0;
>         int upperBound = SPECIES256.loopBound(queryVector.length);
>         for (; i < upperBound; i += SPECIES256.length()) {
>             vecX = FloatVector.fromArray(SPECIES256, queryVector, i);
>             vecY = FloatVector.fromArray(SPECIES256, vector, i);
>
>             vecSum = vecX.fma(vecY, vecSum);
>
>             xSquareV = vecX.fma(vecX, xSquareV);
>
>             ySquareV = vecY.fma(vecY, ySquareV);
>
>             vecX.intoArray(vector, i);
>             vecY.intoArray(queryVector, i);
>         }
>         float sum = vecSum.reduceLanes(VectorOperators.ADD);
>         float xSquare = xSquareV.reduceLanes(VectorOperators.ADD);
>         float ySquare = ySquareV.reduceLanes(VectorOperators.ADD);
>         for (; i < queryVector.length; i++) {
>             sum += queryVector[i] * vector[i];
>             xSquare += queryVector[i] * queryVector[i];
>             ySquare += vector[i] * vector[i];
>         }
>         if (ySquare < 1e-8) {
>             return 0;
>         }
>         return (float) (sum / Math.sqrt(xSquare * ySquare));
>     }
>
>     public static float getCosineSimilaritySIMD(float[] queryVector, float[] vector) {
>         FloatVector vecX, vecY, vecSum, xSquareV, ySquareV;
>         vecSum = FloatVector.zero(SPECIES256);
>         xSquareV = FloatVector.zero(SPECIES256);
>         ySquareV = FloatVector.zero(SPECIES256);
>
>         int i = 0;
>         int upperBound = SPECIES256.loopBound(queryVector.length);
>         for (; i < upperBound; i += SPECIES256.length()) {
>             vecX = FloatVector.fromArray(SPECIES256, queryVector, i);
>             vecY = FloatVector.fromArray(SPECIES256, vector, i);
>
>             vecSum = vecX.mul(vecY).add(vecSum);
>
>             xSquareV = vecX.mul(vecX).add(xSquareV);
>
>             ySquareV = vecY.mul(vecY).add(ySquareV);
>
>             vecX.intoArray(vector, i);
>             vecY.intoArray(queryVector, i);
>         }
>         float sum = vecSum.reduceLanes(VectorOperators.ADD);
>         float xSquare = xSquareV.reduceLanes(VectorOperators.ADD);
>         float ySquare = ySquareV.reduceLanes(VectorOperators.ADD);
>         for (; i < queryVector.length; i++) {
>             sum += queryVector[i] * vector[i];
>             xSquare += queryVector[i] * queryVector[i];
>             ySquare += vector[i] * vector[i];
>         }
>         if (ySquare < 1e-8) {
>             return 0;
>         }
>         return (float) (sum / Math.sqrt(xSquare * ySquare));
>     }
>
>     public static float getCosineSimilarityScalarFMA(float[] queryVector, float[] vector) {
>         float sum = 0;
>         float xSquare = 0;
>         float ySquare = 0;
>         for (int i = 0; i < queryVector.length; i++) {
>             float qv = queryVector[i];
>             float v = vector[i];
>
>             sum = Math.fma(qv, v, sum);
>
>             xSquare = Math.fma(qv, qv, xSquare);
>
>             ySquare = Math.fma(v, v, ySquare);
>         }
>         if (ySquare < 1e-8) {
>             return 0;
>         }
>         return (float) (sum / Math.sqrt(xSquare * ySquare));
>     }
>
>     public static float getCosineSimilarityScalar(float[] queryVector, float[] vector) {
>         float sum = 0;
>         float xSquare = 0;
>         float ySquare = 0;
>         for (int i = 0; i < queryVector.length; i++) {
>             float qv = queryVector[i];
>             float v = vector[i];
>
>             sum += qv * v;
>
>             xSquare += qv * qv;
>
>             ySquare += v * v;
>         }
>         if (ySquare < 1e-8) {
>             return 0;
>         }
>         return (float) (sum / Math.sqrt(xSquare * ySquare));
>     }
>
>     @Benchmark
>     public float vector_fma() {
>         return getCosineSimilaritySIMDFMA(V1, V2);
>     }
>
>     @Benchmark
>     public float vector() {
>         return getCosineSimilaritySIMD(V1, V2);
>     }
>
>     @Benchmark
>     public float scalar_fma() {
>         return getCosineSimilarityScalarFMA(V1, V2);
>     }
>
>     @Benchmark
>     public float scalar() {
>         return getCosineSimilarityScalar(V1, V2);
>     }
> }
>
> /*
> vectorIntrinsics
> --
> Benchmark Mode Cnt Score Error Units
> VectorSimilarity.scalar thrpt 10 3876.208 ± 43.784 ops/ms
> VectorSimilarity.scalar_fma thrpt 10 4511.374 ± 15.802 ops/ms
> VectorSimilarity.vector thrpt 10 21819.561 ± 229.776 ops/ms
> VectorSimilarity.vector_fma thrpt 10 22866.804 ± 263.219 ops/ms
>
> vector-unstable
> --
> Benchmark Mode Cnt Score Error Units
> VectorSimilarity.scalar thrpt 10 3924.974 ± 17.718 ops/ms
> VectorSimilarity.scalar_fma thrpt 10 4473.268 ± 54.080 ops/ms
> VectorSimilarity.vector thrpt 10 21511.554 ± 218.139 ops/ms
> VectorSimilarity.vector_fma thrpt 10 22756.076 ± 157.729 ops/ms
>
> */
>
> > On Jul 6, 2020, at 2:17 PM, Vladimir Ivanov <vladimir.x.ivanov at oracle.com> wrote:
> >
> > Hi Zhuoren,
> >
> > I haven't investigated what actually causes the difference, but seeing the reduceLanes() calls [1] after the loop I suspect it is caused by inlining issues. (You can verify that by looking at -XX:+PrintInlining output.)
> >
> > In the vectorIntrinsics branch there's a recently integrated fix which makes inlining of vector operations more robust:
> > https://hg.openjdk.java.net/panama/dev/rev/5b601a43ac88
> >
> > Best regards,
> > Vladimir Ivanov
> >
> > [1]
> > > float sum = vecSum.reduceLanes(VectorOperators.ADD);
> > > float xSquare = xSquareV.reduceLanes(VectorOperators.ADD);
> > > float ySquare = ySquareV.reduceLanes(VectorOperators.ADD);
> >
> >
> > On 06.07.2020 14:42, Wang Zhuo(Zhuoren) wrote:
> >> Hi, I am implementing Cosine Distance using the Vector API, and I found that the performance of my algorithm on vector-unstable is much better than on vectorIntrinsics.
> >> On vectorIntrinsics:
> >> normal time used:965
> >> vector time used:2529
> >> On vector-unstable:
> >> normal time used:968
> >> vector time used:226
> >> The numbers are time (in ms), the smaller the better.
> >> I wonder if there are some differences between the two branches that cause this perf difference?
> >> The test code is below, please check.
> >> Command I used to run:
> >> java --add-modules=jdk.incubator.vector -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 VectorSimilarity
> >> import jdk.incubator.vector.*;
> >> import java.util.Base64;
> >> import java.util.concurrent.TimeUnit;
> >> public class VectorSimilarity {
> >> static final VectorSpecies<Float> SPECIES256 = FloatVector.SPECIES_256;
> >> static final VectorSpecies<Float> SPECIES512 = FloatVector.SPECIES_512;
> >> private static String x1 = "L5GSwXhHpEH05mNBHnmcQMTw3EBnagFCW1DGQHe/nUFO1B1BlJOpwCBJ9j" +
> >> "+RkY1BzqKeQSglN0Gy7krB5CSfQFzxB8Djn5nB2KNFwKcSRMGYzRQ7qMGWQZ0FF0FTceDAIKjxv/zhdkHFZMHB6hU4QZbo2cCAryRB+7OOQCxbfEHRtBlBxPG6P0BYSD+Pgz9BqzOLv/nVO8C9x5/BQOY/wTTIx0GfW1BBGv2lQQwdDcGCqBfB12t/QKUBoEEejIXBPN9kQWsFbEGsGcnBkqJkwKhLgr/IQZxAelAWQfcYpcFQv0HBeiGCQWExhEDrKAnBpAwBQV4bVcFpGNjAyDsNQVOc+0CSc4nBgG/ZQQGRccEXts9BKhYzQNK5+MAlU0DBzPGWwPGRCcEZC5/ADxOcv7lUkEBomM5BuqKiwV2MU8HNGHDBSB84QZRSyMB8RZlBVFdZQXSVgcBTQQBCdWa/QBQ0qkGILUW/6NA9QQnkmsG+5PPBj0UowT6nYD9cwpjAS/w5wTbX2UH8Gb5AR/HUQMTNAMJ9MN9AgHoqPbbUyUFbe47BBHANQWZJBsGBuPlBy94EQADeXsG5eOtBnA+yQCRka8EMcGLBjuoRwb4k7sAasB5Bmk/UwaI1akErp6xBq5G5wNo1E8KHa7tB3IiKQTCffcHphK1BTgJzwVY3JEEip/VAlmgXQSeKCsLEABs/n1/xwL5u58CgQY49ahUWQoAJjj1hhqBASXrrQb6nM0H2fY+/thtbQAQobMAohvXAxM3xv7xyqD+MvpDBrlDiQfBvPcGA8X5AQE4SwXhGx7+uLA1AxY8xu2mVjEE7KlFBArveQFNMtUD3N7DB12BbQcyH4cFhSw3Bu5VWQeTW0z9o03TBxtMlQctp/8E/lLVAGUtTwZsGJMKv/R5A1HKVQV6RhsC1Ji5AcXLFQJd6f0HbB+e+ZDi8wV9tQ0FwCN/B+A89v2DrU0Bcpc5BglTeQH5dT0HePS9Al4XPwdA6YEFlueXAbWKSQSBWzkBy2RnCt9Yawl9b77+xgxBC9eCqQd8f0kFoBG9BVxrkQZh2QkHNW/zBEQiawLJEocDhutTA8zEYwbIvEUIO1T9BmlOTwIhbNEDhrtlAVk9BQARQaj89NQNC6usGwDfQrkBSJrlAON7FQQ8FqsEEc/TAY3zeQYsqUEHV8QPBHJoYQQdn5kGyCiJBlDMYQBBNoUFrxbw/NlmPP3B24j6ChIdBXk2bwdxdDMFQw1rA4hybQXTchr8d9wvBuCbLQSMKmMBH4RpBQIXePa5DT8IjgvtBgAetQZgGgMEprc1BAOeSPJ5XpEEMa0NBgX4uwX7XIsG2Ie0688iqQSpJPsCAy9LBAGHkPw=="; private static String x2 = 
> >> "5R3ZwGPrxEFMKyNBLFSeQdYav0BQtDFCur7WQAgRYEGHFYC/MKZtvkiFUT+RNXfBVsGBP2KWSUCmAUTBIf+EQG57kMCtXo7BV1DuwLd98r+YzRQ7qKXNwBMSPUFNQffBPrxeQYw1t7/7JjFAKNaXP+cMSEG6GI5BuEx0wUANDMEvDqdAT9YEworQTEEiVBZBiMejQP7t67+iRwzB3HadQB1be0Ei5g5BMt+cQXvYTUHwZsLAuoy3QfrR6EFrIiHB5X8Dwc8XbUH8Yr8/AvGEwa5GkUH3F5tAP8YJQTiDyz+gKsRAFl/rwDxJuUAPyyxBvg2gQU6bjMEPEa7Bz6wYQpQy7MDF5LvB8HP+QCJdicHQDjpC6RpWQcGeY8FMK6vBoeUjQcPYmUG2QmRBBI0nwScESsGMAcxBvRmawRL2A8IByKNAgTQBQuxdDEGq8JBBHJWmQSBDfz8sLe9BE3gFwTdCPkHEaMxBhX8Xwe7BCcE/783Bt6EHwdpbpkHc5L/BCPzRwUdIQUEd/k3AoGNcQQwNmMEyuKRBtnWlwdCBAUI5Y5DBwOZYvdI+MsEu/ixBnpMrwRtYt8FECytC6JjEQW3RHcBtfn3B+sgQQcyQKcEI5ytByvw2wPZdaUH+aqLAQFQ+QPi4REBF/9lBCvJNQTdlEcIAMbzBtD+hwZWufsAEjus/YRyjwR1YuMHj0ZhBa4w+QORAhMEq9qdB/L8JQrjhyUAJBeBAKqoIQUnAq0GsLFdBkfrvQHc1zMHH6THBeggSwaJIOsAawwBBDDWqPwrAlkBYDqe/maUcQabhwsFF2VBBxY8xu5aMQUFDkHVBKhRRwHhgWsEA5jXBlh9NQVMaT0CWlhTAroaFQRyciUHQlp7BF4trQa8unsE4TfI+9XLJQDNpLcIXLZdAuX2MwShiTsFcQh5BrHMqQVI1+UBWe4fBAzi0wfe11UFAIjq9Y1iAQDxrTsEY6plB/JiXQfjFwkHkYGRBVNOhwCMxtEFbqZTA378WQeA/Sb+FrSXCqlYywtb5SsDcqlZBk1EtQZ/RREHZIxG/kcv8QekDIkHPsDXCBL4VQHN8CMGtNvvAC3YwweUuAkKkJCnANEtVQG9z/0DrwyTBQ9hnwWX3kMEdLB1CvIlKwQ0IO0HK1ErBvdRQQVpjMMCJDI/Bb4X8QYVipEGpG2nBeLGUvmBlBT7ISgRB4iGAQUunkkFDFLm/HNaqPzKTVkCITJG/XzlYwbj0XcGD60PBbpLwQbvrs8Az8RXB4ubxQXh/HEDtXLU/kONrwVBs4MGc2X1BJaHkQd0ByEAKXLJBTq7JwPPkJUGJIIRBlh57wX3FjcC2Ie060Qc6Qal5xcCfqQrCl7edQQ==";
> >> static float[] v1 = parseBase64ToVector(x1);
> >> static float[] v2 = parseBase64ToVector(x2);
> >> public static float[] parseArray(byte[] input) {
> >> if (input == null) {
> >> return null;
> >> }
> >> float[] floatArr = new float[input.length / 4];
> >> for (int i = 0; i < floatArr.length; i++) {
> >> int l;
> >> l = input[i << 2];
> >> l &= 0xff;
> >> l |= ((long) input[(i << 2) + 1] << 8);
> >> l &= 0xffff;
> >> l |= ((long) input[(i << 2) + 2] << 16);
> >> l &= 0xffffff;
> >> l |= ((long) input[(i << 2) + 3] << 24);
> >> floatArr[i] = Float.intBitsToFloat(l);
> >> }
> >> return floatArr;
> >> }
> >> public static float[] parseBase64ToVector(String vectorBase64) {
> >> return parseArray(Base64.getDecoder().decode(vectorBase64));
> >> }
> >> public static float getCosineSimilaritySIMD(float[] queryVector, float[] vector) {
> >> FloatVector vecX, vecY, vecSum, xSquareV, ySquareV;
> >> vecSum = FloatVector.zero(SPECIES256);
> >> xSquareV = FloatVector.zero(SPECIES256);
> >> ySquareV = FloatVector.zero(SPECIES256);;
> >> int i= 0;
> >> for (i = 0; i + (SPECIES256.length()) <= queryVector.length; i += SPECIES256.length()) {
> >> vecX = FloatVector.fromArray(SPECIES256, queryVector, i);
> >> vecY = FloatVector.fromArray(SPECIES256, vector, i);
> >> vecSum = vecX.mul(vecY).add(vecSum);
> >> vecSum = vecX.fma(vecY, vecSum);
> >> xSquareV = vecX.fma(vecX, xSquareV);
> >> ySquareV = vecY.fma(vecY, ySquareV);
> >> xSquareV = vecX.mul(vecX).add(xSquareV);
> >> ySquareV = vecY.mul(vecY).add(ySquareV);
> >> vecX.intoArray(vector, i);
> >> vecY.intoArray(queryVector, i);
> >> }
> >> float sum = vecSum.reduceLanes(VectorOperators.ADD);
> >> float xSquare = xSquareV.reduceLanes(VectorOperators.ADD);
> >> float ySquare = ySquareV.reduceLanes(VectorOperators.ADD);
> >> for (; i < queryVector.length; i++) {
> >> sum += queryVector[i] * vector[i];
> >> xSquare += queryVector[i] * queryVector[i];
> >> ySquare += vector[i] * vector[i];
> >> }
> >> if (ySquare < 1e-8) {
> >> return 0;
> >> }
> >> return (float)(sum / Math.sqrt(xSquare * ySquare));
> >> }
> >> public static float getCosineSimilarityScalar(float[] queryVector, float[] vector) {
> >> float sum = 0;
> >> float xSquare = 0;
> >> float ySquare = 0;
> >> for (int i = 0; i < queryVector.length; i++) {
> >> //queryVector[i] = vector[i];
> >> sum += (float)(queryVector[i] * vector[i]);
> >> xSquare += (float)(queryVector[i] * queryVector[i]);
> >> ySquare += (float)(vector[i] * vector[i]);
> >> }
> >> if (ySquare < 1e-8) {
> >> return 0;
> >> }
> >> return (float)(sum / Math.sqrt(xSquare * ySquare));
> >> }
> >> public static void main(String[] args) {
> >> long t1, t2;
> >> for (int i = 0; i < 100000; i++) {
> >> getCosineSimilaritySIMD(v1, v2);
> >> getCosineSimilarityScalar(v1, v2);
> >> }
> >> System.out.println("normal result " + getCosineSimilarityScalar(v1, v2) + " vec result " + getCosineSimilaritySIMD(v1, v2));
> >> t1 = System.currentTimeMillis();
> >> for (int i = 0; i < 2000000; i++) {
> >> getCosineSimilarityScalar(v1, v2);
> >> }
> >> System.out.println("normal time used:" + (System.currentTimeMillis() - t1));
> >> t2 = System.currentTimeMillis();
> >> for (int i = 0; i < 2000000; i++) {
> >> getCosineSimilaritySIMD(v1, v2);
> >> }
> >> System.out.println("vector time used:" + (System.currentTimeMillis() - t2));
> >> }
> >> }
> >> Regards,
> >> Zhuoren