[vector] Perf difference between vector-unstable and vectorInstrinsics
Paul Sandoz
paul.sandoz at oracle.com
Tue Jul 7 00:07:50 UTC 2020
I copied the source and transformed it into a JMH benchmark, see below. I don’t observe any difference between the two branches (JMH results in comments at end of source).
For the SIMD implementation I did remove what appeared to be redundant mul/add operations in addition to the fma operations.
And, I could not resist the temptation to use text blocks for the B64 data :-)
I recommend using JMH rather than rolling your own quick measurements. It will provide a more reliable base from which to investigate issues with inlining and code generation (using perfasm on linux).
Paul.
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Base64;
import java.util.concurrent.TimeUnit;
import java.util.function.IntFunction;
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 10, time = 1)
@Measurement(iterations = 10, time = 1)
@Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public class VectorSimilarity {
// 256-bit float vector species.
static final VectorSpecies<Float> SPECIES256 = FloatVector.SPECIES_256;
// 512-bit float vector species. NOTE(review): not referenced by any method visible in this listing.
static final VectorSpecies<Float> SPECIES512 = FloatVector.SPECIES_512;
private static String X1 = """
L5GSwXhHpEH05mNBHnmcQMTw3EBnagFCW1DGQHe/nUFO1B1BlJOpwCBJ9j+RkY1B
zqKeQSglN0Gy7krB5CSfQFzxB8Djn5nB2KNFwKcSRMGYzRQ7qMGWQZ0FF0FTceDA
IKjxv/zhdkHFZMHB6hU4QZbo2cCAryRB+7OOQCxbfEHRtBlBxPG6P0BYSD+Pgz9B
qzOLv/nVO8C9x5/BQOY/wTTIx0GfW1BBGv2lQQwdDcGCqBfB12t/QKUBoEEejIXB
PN9kQWsFbEGsGcnBkqJkwKhLgr/IQZxAelAWQfcYpcFQv0HBeiGCQWExhEDrKAnB
pAwBQV4bVcFpGNjAyDsNQVOc+0CSc4nBgG/ZQQGRccEXts9BKhYzQNK5+MAlU0DB
zPGWwPGRCcEZC5/ADxOcv7lUkEBomM5BuqKiwV2MU8HNGHDBSB84QZRSyMB8RZlB
VFdZQXSVgcBTQQBCdWa/QBQ0qkGILUW/6NA9QQnkmsG+5PPBj0UowT6nYD9cwpjA
S/w5wTbX2UH8Gb5AR/HUQMTNAMJ9MN9AgHoqPbbUyUFbe47BBHANQWZJBsGBuPlB
y94EQADeXsG5eOtBnA+yQCRka8EMcGLBjuoRwb4k7sAasB5Bmk/UwaI1akErp6xB
q5G5wNo1E8KHa7tB3IiKQTCffcHphK1BTgJzwVY3JEEip/VAlmgXQSeKCsLEABs/
n1/xwL5u58CgQY49ahUWQoAJjj1hhqBASXrrQb6nM0H2fY+/thtbQAQobMAohvXA
xM3xv7xyqD+MvpDBrlDiQfBvPcGA8X5AQE4SwXhGx7+uLA1AxY8xu2mVjEE7KlFB
ArveQFNMtUD3N7DB12BbQcyH4cFhSw3Bu5VWQeTW0z9o03TBxtMlQctp/8E/lLVA
GUtTwZsGJMKv/R5A1HKVQV6RhsC1Ji5AcXLFQJd6f0HbB+e+ZDi8wV9tQ0FwCN/B
+A89v2DrU0Bcpc5BglTeQH5dT0HePS9Al4XPwdA6YEFlueXAbWKSQSBWzkBy2RnC
t9Yawl9b77+xgxBC9eCqQd8f0kFoBG9BVxrkQZh2QkHNW/zBEQiawLJEocDhutTA
8zEYwbIvEUIO1T9BmlOTwIhbNEDhrtlAVk9BQARQaj89NQNC6usGwDfQrkBSJrlA
ON7FQQ8FqsEEc/TAY3zeQYsqUEHV8QPBHJoYQQdn5kGyCiJBlDMYQBBNoUFrxbw/
NlmPP3B24j6ChIdBXk2bwdxdDMFQw1rA4hybQXTchr8d9wvBuCbLQSMKmMBH4RpB
QIXePa5DT8IjgvtBgAetQZgGgMEprc1BAOeSPJ5XpEEMa0NBgX4uwX7XIsG2Ie06
88iqQSpJPsCAy9LBAGHkPw==""";
private static String X2 = """
5R3ZwGPrxEFMKyNBLFSeQdYav0BQtDFCur7WQAgRYEGHFYC/MKZtvkiFUT+RNXfB
VsGBP2KWSUCmAUTBIf+EQG57kMCtXo7BV1DuwLd98r+YzRQ7qKXNwBMSPUFNQffB
PrxeQYw1t7/7JjFAKNaXP+cMSEG6GI5BuEx0wUANDMEvDqdAT9YEworQTEEiVBZB
iMejQP7t67+iRwzB3HadQB1be0Ei5g5BMt+cQXvYTUHwZsLAuoy3QfrR6EFrIiHB
5X8Dwc8XbUH8Yr8/AvGEwa5GkUH3F5tAP8YJQTiDyz+gKsRAFl/rwDxJuUAPyyxB
vg2gQU6bjMEPEa7Bz6wYQpQy7MDF5LvB8HP+QCJdicHQDjpC6RpWQcGeY8FMK6vB
oeUjQcPYmUG2QmRBBI0nwScESsGMAcxBvRmawRL2A8IByKNAgTQBQuxdDEGq8JBB
HJWmQSBDfz8sLe9BE3gFwTdCPkHEaMxBhX8Xwe7BCcE/783Bt6EHwdpbpkHc5L/B
CPzRwUdIQUEd/k3AoGNcQQwNmMEyuKRBtnWlwdCBAUI5Y5DBwOZYvdI+MsEu/ixB
npMrwRtYt8FECytC6JjEQW3RHcBtfn3B+sgQQcyQKcEI5ytByvw2wPZdaUH+aqLA
QFQ+QPi4REBF/9lBCvJNQTdlEcIAMbzBtD+hwZWufsAEjus/YRyjwR1YuMHj0ZhB
a4w+QORAhMEq9qdB/L8JQrjhyUAJBeBAKqoIQUnAq0GsLFdBkfrvQHc1zMHH6THB
eggSwaJIOsAawwBBDDWqPwrAlkBYDqe/maUcQabhwsFF2VBBxY8xu5aMQUFDkHVB
KhRRwHhgWsEA5jXBlh9NQVMaT0CWlhTAroaFQRyciUHQlp7BF4trQa8unsE4TfI+
9XLJQDNpLcIXLZdAuX2MwShiTsFcQh5BrHMqQVI1+UBWe4fBAzi0wfe11UFAIjq9
Y1iAQDxrTsEY6plB/JiXQfjFwkHkYGRBVNOhwCMxtEFbqZTA378WQeA/Sb+FrSXC
qlYywtb5SsDcqlZBk1EtQZ/RREHZIxG/kcv8QekDIkHPsDXCBL4VQHN8CMGtNvvA
C3YwweUuAkKkJCnANEtVQG9z/0DrwyTBQ9hnwWX3kMEdLB1CvIlKwQ0IO0HK1ErB
vdRQQVpjMMCJDI/Bb4X8QYVipEGpG2nBeLGUvmBlBT7ISgRB4iGAQUunkkFDFLm/
HNaqPzKTVkCITJG/XzlYwbj0XcGD60PBbpLwQbvrs8Az8RXB4ubxQXh/HEDtXLU/
kONrwVBs4MGc2X1BJaHkQd0ByEAKXLJBTq7JwPPkJUGJIIRBlh57wX3FjcC2Ie06
0Qc6Qal5xcCfqQrCl7edQQ==""";
// Benchmark input vectors, decoded once during class initialization from the
// Base64 text blocks X1 and X2 via parseBase64ToVector.
static float[] V1 = parseBase64ToVector(X1);
static float[] V2 = parseBase64ToVector(X2);
/**
 * Reinterprets a byte array as a sequence of little-endian 32-bit floats.
 *
 * <p>Trailing bytes that do not make up a complete 4-byte float are ignored.
 *
 * @param input raw bytes, four per float, least-significant byte first; may be {@code null}
 * @return a newly allocated array of {@code input.length / 4} floats, or {@code null} when
 *     {@code input} is {@code null}
 */
public static float[] parseArray(byte[] input) {
    if (input == null) {
        return null;
    }
    FloatBuffer littleEndianFloats =
            ByteBuffer.wrap(input).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
    // The view starts at position 0, so remaining() is the number of whole floats available.
    float[] decoded = new float[littleEndianFloats.remaining()];
    littleEndianFloats.get(decoded);
    return decoded;
}
/**
 * Decodes a Base64 string into a float vector.
 *
 * <p>The MIME decoder is used, so the input may be split across several lines.
 *
 * @param vectorBase64 Base64 text encoding little-endian 32-bit floats
 * @return the decoded floats, as produced by {@link #parseArray(byte[])}
 */
public static float[] parseBase64ToVector(String vectorBase64) {
    byte[] rawBytes = Base64.getMimeDecoder().decode(vectorBase64);
    return parseArray(rawBytes);
}
/**
 * Computes the cosine similarity of two float vectors using 256-bit SIMD lanes and
 * fused multiply-add accumulation.
 *
 * <p>The main loop processes {@code SPECIES256.length()} elements per iteration; any
 * remaining elements are handled by a scalar tail loop. The input arrays are only read,
 * never written.
 *
 * @param queryVector first vector; its length determines how many elements are processed
 * @param vector second vector; must be at least as long as {@code queryVector}
 * @return the cosine of the angle between the vectors, or {@code 0} when either vector has
 *     (near-)zero norm
 */
public static float getCosineSimilaritySIMDFMA(float[] queryVector, float[] vector) {
    FloatVector dotAcc = FloatVector.zero(SPECIES256);
    FloatVector xSquareAcc = FloatVector.zero(SPECIES256);
    FloatVector ySquareAcc = FloatVector.zero(SPECIES256);
    int i = 0;
    final int upperBound = SPECIES256.loopBound(queryVector.length);
    for (; i < upperBound; i += SPECIES256.length()) {
        FloatVector vecX = FloatVector.fromArray(SPECIES256, queryVector, i);
        FloatVector vecY = FloatVector.fromArray(SPECIES256, vector, i);
        dotAcc = vecX.fma(vecY, dotAcc);
        xSquareAcc = vecX.fma(vecX, xSquareAcc);
        ySquareAcc = vecY.fma(vecY, ySquareAcc);
        // Deliberately no intoArray() stores here: writing vecX into `vector` and vecY into
        // `queryVector` would swap the callers' data on every call (and only partially when
        // the length is not a multiple of the species length).
    }
    float sum = dotAcc.reduceLanes(VectorOperators.ADD);
    float xSquare = xSquareAcc.reduceLanes(VectorOperators.ADD);
    float ySquare = ySquareAcc.reduceLanes(VectorOperators.ADD);
    // Scalar tail for the elements that do not fill a whole vector.
    for (; i < queryVector.length; i++) {
        sum += queryVector[i] * vector[i];
        xSquare += queryVector[i] * queryVector[i];
        ySquare += vector[i] * vector[i];
    }
    if (ySquare < 1e-8) {
        return 0;
    }
    // A zero-norm query would otherwise yield 0/0 = NaN.
    if (xSquare == 0) {
        return 0;
    }
    // Multiply in double so the product of two large squared norms cannot overflow float.
    return (float) (sum / Math.sqrt((double) xSquare * ySquare));
}
/**
 * Computes the cosine similarity of two float vectors using 256-bit SIMD lanes with
 * separate multiply and add operations (no fused multiply-add).
 *
 * <p>The main loop processes {@code SPECIES256.length()} elements per iteration; any
 * remaining elements are handled by a scalar tail loop. The input arrays are only read,
 * never written.
 *
 * @param queryVector first vector; its length determines how many elements are processed
 * @param vector second vector; must be at least as long as {@code queryVector}
 * @return the cosine of the angle between the vectors, or {@code 0} when either vector has
 *     (near-)zero norm
 */
public static float getCosineSimilaritySIMD(float[] queryVector, float[] vector) {
    FloatVector dotAcc = FloatVector.zero(SPECIES256);
    FloatVector xSquareAcc = FloatVector.zero(SPECIES256);
    FloatVector ySquareAcc = FloatVector.zero(SPECIES256);
    int i = 0;
    final int upperBound = SPECIES256.loopBound(queryVector.length);
    for (; i < upperBound; i += SPECIES256.length()) {
        FloatVector vecX = FloatVector.fromArray(SPECIES256, queryVector, i);
        FloatVector vecY = FloatVector.fromArray(SPECIES256, vector, i);
        dotAcc = vecX.mul(vecY).add(dotAcc);
        xSquareAcc = vecX.mul(vecX).add(xSquareAcc);
        ySquareAcc = vecY.mul(vecY).add(ySquareAcc);
        // Deliberately no intoArray() stores here: writing vecX into `vector` and vecY into
        // `queryVector` would swap the callers' data on every call (and only partially when
        // the length is not a multiple of the species length).
    }
    float sum = dotAcc.reduceLanes(VectorOperators.ADD);
    float xSquare = xSquareAcc.reduceLanes(VectorOperators.ADD);
    float ySquare = ySquareAcc.reduceLanes(VectorOperators.ADD);
    // Scalar tail for the elements that do not fill a whole vector.
    for (; i < queryVector.length; i++) {
        sum += queryVector[i] * vector[i];
        xSquare += queryVector[i] * queryVector[i];
        ySquare += vector[i] * vector[i];
    }
    if (ySquare < 1e-8) {
        return 0;
    }
    // A zero-norm query would otherwise yield 0/0 = NaN.
    if (xSquare == 0) {
        return 0;
    }
    // Multiply in double so the product of two large squared norms cannot overflow float.
    return (float) (sum / Math.sqrt((double) xSquare * ySquare));
}
/**
 * Computes the cosine similarity of two float vectors with a scalar loop that accumulates
 * via {@link Math#fma(float, float, float)}.
 *
 * @param queryVector first vector; its length determines how many elements are processed
 * @param vector second vector; must be at least as long as {@code queryVector}
 * @return the cosine of the angle between the vectors, or {@code 0} when either vector has
 *     (near-)zero norm
 */
public static float getCosineSimilarityScalarFMA(float[] queryVector, float[] vector) {
    float sum = 0;
    float xSquare = 0;
    float ySquare = 0;
    for (int i = 0; i < queryVector.length; i++) {
        float qv = queryVector[i];
        float v = vector[i];
        sum = Math.fma(qv, v, sum);
        xSquare = Math.fma(qv, qv, xSquare);
        ySquare = Math.fma(v, v, ySquare);
    }
    if (ySquare < 1e-8) {
        return 0;
    }
    // A zero-norm query would otherwise yield 0/0 = NaN.
    if (xSquare == 0) {
        return 0;
    }
    // Multiply in double so the product of two large squared norms cannot overflow float.
    return (float) (sum / Math.sqrt((double) xSquare * ySquare));
}
/**
 * Computes the cosine similarity of two float vectors with a plain scalar
 * multiply-and-add loop.
 *
 * @param queryVector first vector; its length determines how many elements are processed
 * @param vector second vector; must be at least as long as {@code queryVector}
 * @return the cosine of the angle between the vectors, or {@code 0} when either vector has
 *     (near-)zero norm
 */
public static float getCosineSimilarityScalar(float[] queryVector, float[] vector) {
    float sum = 0;
    float xSquare = 0;
    float ySquare = 0;
    for (int i = 0; i < queryVector.length; i++) {
        float qv = queryVector[i];
        float v = vector[i];
        sum += qv * v;
        xSquare += qv * qv;
        ySquare += v * v;
    }
    if (ySquare < 1e-8) {
        return 0;
    }
    // A zero-norm query would otherwise yield 0/0 = NaN.
    if (xSquare == 0) {
        return 0;
    }
    // Multiply in double so the product of two large squared norms cannot overflow float.
    return (float) (sum / Math.sqrt((double) xSquare * ySquare));
}
/** JMH benchmark: SIMD cosine similarity with fused multiply-add, applied to V1 and V2. */
@Benchmark
public float vector_fma() {
    final float similarity = getCosineSimilaritySIMDFMA(V1, V2);
    return similarity;
}
/** JMH benchmark: SIMD cosine similarity with separate multiply and add, applied to V1 and V2. */
@Benchmark
public float vector() {
    final float similarity = getCosineSimilaritySIMD(V1, V2);
    return similarity;
}
/** JMH benchmark: scalar cosine similarity using Math.fma, applied to V1 and V2. */
@Benchmark
public float scalar_fma() {
    final float similarity = getCosineSimilarityScalarFMA(V1, V2);
    return similarity;
}
/** JMH benchmark: plain scalar cosine similarity, applied to V1 and V2. */
@Benchmark
public float scalar() {
    final float similarity = getCosineSimilarityScalar(V1, V2);
    return similarity;
}
}
/*
vectorIntrinsics
--
Benchmark Mode Cnt Score Error Units
VectorSimilarity.scalar thrpt 10 3876.208 ± 43.784 ops/ms
VectorSimilarity.scalar_fma thrpt 10 4511.374 ± 15.802 ops/ms
VectorSimilarity.vector thrpt 10 21819.561 ± 229.776 ops/ms
VectorSimilarity.vector_fma thrpt 10 22866.804 ± 263.219 ops/ms
vector-unstable
--
Benchmark Mode Cnt Score Error Units
VectorSimilarity.scalar thrpt 10 3924.974 ± 17.718 ops/ms
VectorSimilarity.scalar_fma thrpt 10 4473.268 ± 54.080 ops/ms
VectorSimilarity.vector thrpt 10 21511.554 ± 218.139 ops/ms
VectorSimilarity.vector_fma thrpt 10 22756.076 ± 157.729 ops/ms
*/
> On Jul 6, 2020, at 2:17 PM, Vladimir Ivanov <vladimir.x.ivanov at oracle.com> wrote:
>
> Hi Zhuoren,
>
> I haven't investigated what actually causes the difference, but seeing reduceLanes() calls [1] after the loop I suspect it is caused by inlining issues. (You can verify that by looking at -XX:+PrintInlining output.)
>
> In vectorIntrinsics branch there's a fix integrated recently which makes inlining of vector operations more robust:
> https://hg.openjdk.java.net/panama/dev/rev/5b601a43ac88
>
> Best regards,
> Vladimir Ivanov
>
> [1]
> > float sum = vecSum.reduceLanes(VectorOperators.ADD);
> > float xSquare = xSquareV.reduceLanes(VectorOperators.ADD);
> > float ySquare = ySquareV.reduceLanes(VectorOperators.ADD);
>
>
> On 06.07.2020 14:42, Wang Zhuo(Zhuoren) wrote:
>> Hi, I am implementing Cosine Distance using Vector API, while I found that performance of my algorithm on vector-unstable is much better than vectorInstrinsics
>> On vectorInstrinsics:
>> normal time used:965
>> vector time used:2529
>> On vector-unstable:
>> normal time used:968
>> vector time used:226
>> The numbers are time (in ms), the smaller the better.
>> I wonder if there are some differences between the two branches that cause this perf difference?
>> The test code is below, please check.
>> Command I used to run:
>> java --add-modules=jdk.incubator.vector -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 VectorSimilarity
>> import jdk.incubator.vector.*;
>> import java.util.Base64;
>> import java.util.concurrent.TimeUnit;
>> public class VectorSimilarity {
>> static final VectorSpecies<Float> SPECIES256 = FloatVector.SPECIES_256;
>> static final VectorSpecies<Float> SPECIES512 = FloatVector.SPECIES_512;
>> private static String x1 = "L5GSwXhHpEH05mNBHnmcQMTw3EBnagFCW1DGQHe/nUFO1B1BlJOpwCBJ9j" +
>> "+RkY1BzqKeQSglN0Gy7krB5CSfQFzxB8Djn5nB2KNFwKcSRMGYzRQ7qMGWQZ0FF0FTceDAIKjxv/zhdkHFZMHB6hU4QZbo2cCAryRB+7OOQCxbfEHRtBlBxPG6P0BYSD+Pgz9BqzOLv/nVO8C9x5/BQOY/wTTIx0GfW1BBGv2lQQwdDcGCqBfB12t/QKUBoEEejIXBPN9kQWsFbEGsGcnBkqJkwKhLgr/IQZxAelAWQfcYpcFQv0HBeiGCQWExhEDrKAnBpAwBQV4bVcFpGNjAyDsNQVOc+0CSc4nBgG/ZQQGRccEXts9BKhYzQNK5+MAlU0DBzPGWwPGRCcEZC5/ADxOcv7lUkEBomM5BuqKiwV2MU8HNGHDBSB84QZRSyMB8RZlBVFdZQXSVgcBTQQBCdWa/QBQ0qkGILUW/6NA9QQnkmsG+5PPBj0UowT6nYD9cwpjAS/w5wTbX2UH8Gb5AR/HUQMTNAMJ9MN9AgHoqPbbUyUFbe47BBHANQWZJBsGBuPlBy94EQADeXsG5eOtBnA+yQCRka8EMcGLBjuoRwb4k7sAasB5Bmk/UwaI1akErp6xBq5G5wNo1E8KHa7tB3IiKQTCffcHphK1BTgJzwVY3JEEip/VAlmgXQSeKCsLEABs/n1/xwL5u58CgQY49ahUWQoAJjj1hhqBASXrrQb6nM0H2fY+/thtbQAQobMAohvXAxM3xv7xyqD+MvpDBrlDiQfBvPcGA8X5AQE4SwXhGx7+uLA1AxY8xu2mVjEE7KlFBArveQFNMtUD3N7DB12BbQcyH4cFhSw3Bu5VWQeTW0z9o03TBxtMlQctp/8E/lLVAGUtTwZsGJMKv/R5A1HKVQV6RhsC1Ji5AcXLFQJd6f0HbB+e+ZDi8wV9tQ0FwCN/B+A89v2DrU0Bcpc5BglTeQH5dT0HePS9Al4XPwdA6YEFlueXAbWKSQSBWzkBy2RnCt9Yawl9b77+xgxBC9eCqQd8f0kFoBG9BVxrkQZh2QkHNW/zBEQiawLJEocDhutTA8zEYwbIvEUIO1T9BmlOTwIhbNEDhrtlAVk9BQARQaj89NQNC6usGwDfQrkBSJrlAON7FQQ8FqsEEc/TAY3zeQYsqUEHV8QPBHJoYQQdn5kGyCiJBlDMYQBBNoUFrxbw/NlmPP3B24j6ChIdBXk2bwdxdDMFQw1rA4hybQXTchr8d9wvBuCbLQSMKmMBH4RpBQIXePa5DT8IjgvtBgAetQZgGgMEprc1BAOeSPJ5XpEEMa0NBgX4uwX7XIsG2Ie0688iqQSpJPsCAy9LBAGHkPw=="; private static String x2 = 
"5R3ZwGPrxEFMKyNBLFSeQdYav0BQtDFCur7WQAgRYEGHFYC/MKZtvkiFUT+RNXfBVsGBP2KWSUCmAUTBIf+EQG57kMCtXo7BV1DuwLd98r+YzRQ7qKXNwBMSPUFNQffBPrxeQYw1t7/7JjFAKNaXP+cMSEG6GI5BuEx0wUANDMEvDqdAT9YEworQTEEiVBZBiMejQP7t67+iRwzB3HadQB1be0Ei5g5BMt+cQXvYTUHwZsLAuoy3QfrR6EFrIiHB5X8Dwc8XbUH8Yr8/AvGEwa5GkUH3F5tAP8YJQTiDyz+gKsRAFl/rwDxJuUAPyyxBvg2gQU6bjMEPEa7Bz6wYQpQy7MDF5LvB8HP+QCJdicHQDjpC6RpWQcGeY8FMK6vBoeUjQcPYmUG2QmRBBI0nwScESsGMAcxBvRmawRL2A8IByKNAgTQBQuxdDEGq8JBBHJWmQSBDfz8sLe9BE3gFwTdCPkHEaMxBhX8Xwe7BCcE/783Bt6EHwdpbpkHc5L/BCPzRwUdIQUEd/k3AoGNcQQwNmMEyuKRBtnWlwdCBAUI5Y5DBwOZYvdI+MsEu/ixBnpMrwRtYt8FECytC6JjEQW3RHcBtfn3B+sgQQcyQKcEI5ytByvw2wPZdaUH+aqLAQFQ+QPi4REBF/9lBCvJNQTdlEcIAMbzBtD+hwZWufsAEjus/YRyjwR1YuMHj0ZhBa4w+QORAhMEq9qdB/L8JQrjhyUAJBeBAKqoIQUnAq0GsLFdBkfrvQHc1zMHH6THBeggSwaJIOsAawwBBDDWqPwrAlkBYDqe/maUcQabhwsFF2VBBxY8xu5aMQUFDkHVBKhRRwHhgWsEA5jXBlh9NQVMaT0CWlhTAroaFQRyciUHQlp7BF4trQa8unsE4TfI+9XLJQDNpLcIXLZdAuX2MwShiTsFcQh5BrHMqQVI1+UBWe4fBAzi0wfe11UFAIjq9Y1iAQDxrTsEY6plB/JiXQfjFwkHkYGRBVNOhwCMxtEFbqZTA378WQeA/Sb+FrSXCqlYywtb5SsDcqlZBk1EtQZ/RREHZIxG/kcv8QekDIkHPsDXCBL4VQHN8CMGtNvvAC3YwweUuAkKkJCnANEtVQG9z/0DrwyTBQ9hnwWX3kMEdLB1CvIlKwQ0IO0HK1ErBvdRQQVpjMMCJDI/Bb4X8QYVipEGpG2nBeLGUvmBlBT7ISgRB4iGAQUunkkFDFLm/HNaqPzKTVkCITJG/XzlYwbj0XcGD60PBbpLwQbvrs8Az8RXB4ubxQXh/HEDtXLU/kONrwVBs4MGc2X1BJaHkQd0ByEAKXLJBTq7JwPPkJUGJIIRBlh57wX3FjcC2Ie060Qc6Qal5xcCfqQrCl7edQQ==";
>> static float[] v1 = parseBase64ToVector(x1);
>> static float[] v2 = parseBase64ToVector(x2);
>> public static float[] parseArray(byte[] input) {
>> if (input == null) {
>> return null;
>> }
>> float[] floatArr = new float[input.length / 4];
>> for (int i = 0; i < floatArr.length; i++) {
>> int l;
>> l = input[i << 2];
>> l &= 0xff;
>> l |= ((long) input[(i << 2) + 1] << 8);
>> l &= 0xffff;
>> l |= ((long) input[(i << 2) + 2] << 16);
>> l &= 0xffffff;
>> l |= ((long) input[(i << 2) + 3] << 24);
>> floatArr[i] = Float.intBitsToFloat(l);
>> }
>> return floatArr;
>> }
>> public static float[] parseBase64ToVector(String vectorBase64) {
>> return parseArray(Base64.getDecoder().decode(vectorBase64));
>> } public static float getCosineSimilaritySIMD(float[] queryVector, float[] vector) {
>> FloatVector vecX, vecY, vecSum, xSquareV, ySquareV;
>> vecSum = FloatVector.zero(SPECIES256);
>> xSquareV = FloatVector.zero(SPECIES256);
>> ySquareV = FloatVector.zero(SPECIES256);;
>> int i= 0;
>> for (i = 0; i + (SPECIES256.length()) <= queryVector.length; i += SPECIES256.length()) {
>> vecX = FloatVector.fromArray(SPECIES256, queryVector, i);
>> vecY = FloatVector.fromArray(SPECIES256, vector, i);
>> vecSum = vecX.mul(vecY).add(vecSum);
>> vecSum = vecX.fma(vecY, vecSum);
>> xSquareV = vecX.fma(vecX, xSquareV);
>> ySquareV = vecY.fma(vecY, ySquareV);
>> xSquareV = vecX.mul(vecX).add(xSquareV);
>> ySquareV = vecY.mul(vecY).add(ySquareV);
>> vecX.intoArray(vector, i);
>> vecY.intoArray(queryVector, i);
>> }
>> float sum = vecSum.reduceLanes(VectorOperators.ADD);
>> float xSquare = xSquareV.reduceLanes(VectorOperators.ADD);
>> float ySquare = ySquareV.reduceLanes(VectorOperators.ADD);
>> for (; i < queryVector.length; i++) {
>> sum += queryVector[i] * vector[i];
>> xSquare += queryVector[i] * queryVector[i];
>> ySquare += vector[i] * vector[i];
>> }
>> if (ySquare < 1e-8) {
>> return 0;
>> }
>> return (float)(sum / Math.sqrt(xSquare * ySquare));
>> } public static float getCosineSimilarityScalar(float[] queryVector, float[] vector) {
>> float sum = 0;
>> float xSquare = 0;
>> float ySquare = 0;
>> for (int i = 0; i < queryVector.length; i++) {
>> //queryVector[i] = vector[i];
>> sum += (float)(queryVector[i] * vector[i]);
>> xSquare += (float)(queryVector[i] * queryVector[i]);
>> ySquare += (float)(vector[i] * vector[i]);
>> }
>> if (ySquare < 1e-8) {
>> return 0;
>> }
>> return (float)(sum / Math.sqrt(xSquare * ySquare));
>> }
>> public static void main(String[] args) {
>> long t1, t2;
>> for (int i = 0; i < 100000; i++) {
>> getCosineSimilaritySIMD(v1, v2);
>> getCosineSimilarityScalar(v1, v2);
>> }
>> System.out.println("normal result " + getCosineSimilarityScalar(v1, v2) + " vec result " + getCosineSimilaritySIMD(v1, v2));
>> t1 = System.currentTimeMillis();
>> for (int i = 0; i < 2000000; i++) {
>> getCosineSimilarityScalar(v1, v2);
>> }
>> System.out.println("normal time used:" + (System.currentTimeMillis() - t1));
>> t2 = System.currentTimeMillis();
>> for (int i = 0; i < 2000000; i++) {
>> getCosineSimilaritySIMD(v1, v2);
>> }
>> System.out.println("vector time used:" + (System.currentTimeMillis() - t2));
>> }
>> }
>> Regards,
>> Zhuoren
More information about the panama-dev
mailing list