GCM performance and Unsafe byte array accesses

Andrew Haley aph at redhat.com
Thu Aug 27 15:51:45 UTC 2015


I've been looking at the performance of AES/GCM.  The profile is quite
surprising:

samples  cum. samples  %        cum. %  symbol name
476009   476009        36.7033  36.7033 aescrypt_encryptBlock
297239   773248        22.9190  59.6224 ghash_processBlocks
195334   968582        15.0615  74.6839 int com.sun.crypto.provider.GCTR.doFinal(byte[], int, int, byte[], int)

I expected aescrypt_encryptBlock and ghash_processBlocks to be near
the top, but it is surprising that GCTR.doFinal ranks so high: all it
has to do is increment a counter, call aescrypt_encryptBlock, and XOR
the result with the plaintext.
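To make that per-block work concrete, here is a minimal plain-Java sketch of the GCTR loop over complete blocks. It is an illustration, not the SunJCE code: it assumes the public Cipher API with "AES/ECB/NoPadding" as a stand-in for the internal aes.encryptBlock, it ignores any trailing partial block, and the names GctrSketch and gctr are hypothetical.

```java
import java.util.Arrays;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

public class GctrSketch {
    static final int BLOCK = 16;

    // Increment the rightmost 32 bits of the counter block as a big-endian
    // integer; the carry stops at the fourth byte from the end (mod 2^32).
    static void increment32(byte[] ctr) {
        for (int n = ctr.length - 1; n >= ctr.length - 4; n--) {
            if (++ctr[n] != 0) {
                break;
            }
        }
    }

    // Complete blocks only: out_i = in_i XOR E_K(counter_i), then bump the counter.
    // A trailing partial block is not handled in this sketch.
    static byte[] gctr(Cipher aesEcb, byte[] icb, byte[] in) throws Exception {
        byte[] counter = icb.clone();
        byte[] encryptedCntr = new byte[BLOCK];
        byte[] out = new byte[in.length];
        for (int i = 0; i + BLOCK <= in.length; i += BLOCK) {
            aesEcb.doFinal(counter, 0, BLOCK, encryptedCntr, 0); // E_K(counter)
            for (int n = 0; n < BLOCK; n++) {                    // byte-wise XOR
                out[i + n] = (byte) (in[i + n] ^ encryptedCntr[n]);
            }
            increment32(counter);
        }
        return out;
    }

    public static void main(String[] args) throws Exception {
        SecretKeySpec key = new SecretKeySpec(new byte[16], "AES"); // all-zero demo key
        Cipher ecb = Cipher.getInstance("AES/ECB/NoPadding");
        ecb.init(Cipher.ENCRYPT_MODE, key);
        byte[] icb = new byte[16];
        byte[] pt = new byte[64];
        byte[] ours = gctr(ecb, icb, pt);

        // Cross-check against the provider's CTR mode with the same IV.
        Cipher ctr = Cipher.getInstance("AES/CTR/NoPadding");
        ctr.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(icb));
        System.out.println("matches AES/CTR: " + Arrays.equals(ours, ctr.doFinal(pt)));
    }
}
```

With a zero initial counter and only a few blocks, no carry crosses the 32-bit boundary, so the result should agree with the provider's full 128-bit CTR mode; the inner byte loop is exactly the part the patch replaces with two 64-bit XORs.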

The cost seems to come from byte-at-a-time accesses in GCTR.doFinal() and
GaloisCounterMode.update().  Earlier this year I wrote a set of new
Unsafe.get/put-XX-Unaligned methods, and these are ideal for bulk
accesses to byte arrays.  So, as an experiment, I wrote some helper
methods for byte-array accesses and used them to speed up GCM, with
this result:
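The helpers in the patch are built on sun.misc.Unsafe. For comparison, here is a sketch of the same 16-byte XOR written against the public VarHandle byte-array views that arrived later, in JDK 9. This is a swapped-in technique, not part of the patch; whether it compiles to code as fast as the Unsafe version is an assumption not measured here, and the names LongXorSketch and xorBlock are hypothetical.

```java
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.nio.ByteOrder;

public class LongXorSketch {
    // View any byte[] as longs at arbitrary byte offsets. Plain get/set
    // may be misaligned, and every access is bounds-checked.
    private static final VarHandle LONGS =
        MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.nativeOrder());

    // out[outOfs..outOfs+16) = in[inOfs..inOfs+16) ^ pad[0..16), as two 64-bit XORs.
    // Byte order is irrelevant: XOR works byte by byte, and loads and stores
    // use the same order, so every byte lands back where it came from.
    static void xorBlock(byte[] in, int inOfs, byte[] pad, byte[] out, int outOfs) {
        long lo = (long) LONGS.get(in, inOfs) ^ (long) LONGS.get(pad, 0);
        long hi = (long) LONGS.get(in, inOfs + 8) ^ (long) LONGS.get(pad, 8);
        LONGS.set(out, outOfs, lo);
        LONGS.set(out, outOfs + 8, hi);
    }

    public static void main(String[] args) {
        byte[] in = new byte[40];
        byte[] pad = new byte[16];
        for (int i = 0; i < in.length; i++) in[i] = (byte) (i * 7 + 3);
        for (int i = 0; i < pad.length; i++) pad[i] = (byte) (0xA5 ^ i);

        byte[] out = new byte[40];
        xorBlock(in, 3, pad, out, 5); // deliberately misaligned offsets

        // Compare against the original byte-at-a-time loop.
        boolean ok = true;
        for (int n = 0; n < 16; n++) {
            ok &= out[5 + n] == (byte) (in[3 + n] ^ pad[n]);
        }
        System.out.println("matches byte loop: " + ok);
    }
}
```

The same endianness argument applies to the patch's ByteArrays.getLong/putLongs, which use native order for the XOR in GCTR.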

samples  cum. samples  %        cum. %     image name               symbol name
492274   492274        40.8856  40.8856    13256.jo                 aescrypt_encryptBlock
298185   790459        24.7656  65.6512    13256.jo                 ghash_processBlocks
86325    876784         7.1697  72.8209    13256.jo                 int com.sun.crypto.provider.GCTR.update(byte[], int, int, byte[], int)

GCTR.update() is more than twice as fast as it was (86325 samples,
down from 195334), and overall AES/GCM performance is 10% better.

The changes to the GCM code are quite minor:

diff -r 6940407d544a src/java.base/share/classes/com/sun/crypto/provider/GCTR.java
--- a/src/java.base/share/classes/com/sun/crypto/provider/GCTR.java     Thu Aug 20 07:36:37 2015 -0700
+++ b/src/java.base/share/classes/com/sun/crypto/provider/GCTR.java     Thu Aug 27 16:17:25 2015 +0100
@@ -94,11 +97,12 @@
         int numOfCompleteBlocks = inLen / AES_BLOCK_SIZE;
         for (int i = 0; i < numOfCompleteBlocks; i++) {
             aes.encryptBlock(counter, 0, encryptedCntr, 0);
-            for (int n = 0; n < AES_BLOCK_SIZE; n++) {
-                int index = (i * AES_BLOCK_SIZE + n);
-                out[outOfs + index] =
-                    (byte) ((in[inOfs + index] ^ encryptedCntr[n]));
-            }
+            int index = i * AES_BLOCK_SIZE;
+            ByteArrays.putLongs(out, outOfs + index,
+                                (ByteArrays.getLong(in, inOfs + index + 0) ^
+                                 ByteArrays.getLong(encryptedCntr, 0)),
+                                (ByteArrays.getLong(in, inOfs + index + 8) ^
+                                 ByteArrays.getLong(encryptedCntr, 8)));
             GaloisCounterMode.increment32(counter);
         }
         return inLen;
diff -r 6940407d544a src/java.base/share/classes/com/sun/crypto/provider/GaloisCounterMode.java
--- a/src/java.base/share/classes/com/sun/crypto/provider/GaloisCounterMode.java        Thu Aug 20 07:36:37 2015 -0700
+++ b/src/java.base/share/classes/com/sun/crypto/provider/GaloisCounterMode.java        Thu Aug 27 16:17:25 2015 +0100
@@ -82,11 +82,8 @@
             // should never happen
             throw new ProviderException("Illegal counter block length");
         }
-        // start from last byte and only go over 4 bytes, i.e. total 32 bits
-        int n = value.length - 1;
-        while ((n >= value.length - 4) && (++value[n] == 0)) {
-            n--;
-        }
+        int counter = ByteArrays.getInt(value, value.length - 4, true);
+        ByteArrays.putInt(value, value.length - 4, counter + 1, true);
     }

     // ivLen in bits
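The increment32 change relies on the low four bytes of the counter block being a big-endian 32-bit integer that wraps modulo 2^32. The sketch below checks that the original byte loop and a single int read-modify-write agree, including the wrap case. It uses java.nio.ByteBuffer in place of the patch's ByteArrays.getInt/putInt(..., true) so it runs on a stock JDK; the names Increment32Sketch, incrementBytes and incrementInt are hypothetical.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;

public class Increment32Sketch {
    // The original byte loop: bump the last byte, carrying leftwards, but never
    // past the fourth byte from the end (so the counter wraps mod 2^32).
    static void incrementBytes(byte[] value) {
        int n = value.length - 1;
        while ((n >= value.length - 4) && (++value[n] == 0)) {
            n--;
        }
    }

    // The same operation as one 32-bit big-endian read-modify-write.
    // Java int addition wraps on overflow, which gives the mod 2^32 behaviour.
    static void incrementInt(byte[] value) {
        ByteBuffer bb = ByteBuffer.wrap(value); // big-endian by default
        int pos = value.length - 4;
        bb.putInt(pos, bb.getInt(pos) + 1);
    }

    public static void main(String[] args) {
        byte[][] cases = {
            new byte[16],
            {0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,(byte)0xFF},
            {0,0,0,0, 0,0,0,0, 0,0,0,1, (byte)0xFF,(byte)0xFF,(byte)0xFF,(byte)0xFF},
        };
        for (byte[] c : cases) {
            byte[] a = c.clone();
            byte[] b = c.clone();
            incrementBytes(a);
            incrementInt(b);
            System.out.println("agree: " + Arrays.equals(a, b));
        }
    }
}
```

In the last case the low 32 bits wrap to zero and byte 11 is left untouched by both versions, which is the behaviour GCM's inc32 function requires.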

I've attached the full diff.

So, here's my question: there are many places across the crypto code
base where we could take advantage of this.  Do you think it makes
sense to make changes like this?  I can't see any major disadvantages,
and it's a considerable performance improvement.

Andrew.


-------------- next part --------------
diff -r 6940407d544a src/java.base/share/classes/com/sun/crypto/provider/ByteArrays.java
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/java.base/share/classes/com/sun/crypto/provider/ByteArrays.java	Thu Aug 27 16:47:27 2015 +0100
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.  Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package com.sun.crypto.provider;
+
+import sun.misc.Unsafe;
+
+final class ByteArrays {
+
+    private static final Unsafe UNSAFE = Unsafe.getUnsafe();
+
+    static final private int LONG_SIZE = 8;
+    static final private int INT_SIZE = 4;
+    static final private int SHORT_SIZE = 2;
+
+    static final private long BASE_OFFSET = Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
+    static final private void checkIndex(int i, int nb, int limit) {
+        if ((i < 0) || (nb > limit - i))
+            throw new ArrayIndexOutOfBoundsException(i);
+    }
+
+    static long getLong(byte[] a, int index) {
+        checkIndex(index, LONG_SIZE, a.length);
+        return UNSAFE.getLongUnaligned(a, BASE_OFFSET + index);
+    }
+
+    static int getInt(byte[] a, int index) {
+        checkIndex(index, INT_SIZE, a.length);
+        return UNSAFE.getIntUnaligned(a, BASE_OFFSET + index);
+    }
+
+    static short getShort(byte[] a, int index) {
+        checkIndex(index, SHORT_SIZE, a.length);
+        return UNSAFE.getShortUnaligned(a, BASE_OFFSET + index);
+    }
+
+    static void putLong(byte[] a, int index, long value) {
+        checkIndex(index, LONG_SIZE, a.length);
+        UNSAFE.putLongUnaligned(a, BASE_OFFSET + index, value);
+    }
+
+    static void putLongs(byte[] a, int index, long value0, long value1) {
+        checkIndex(index, LONG_SIZE * 2, a.length);
+        UNSAFE.putLongUnaligned(a, BASE_OFFSET + index, value0);
+        UNSAFE.putLongUnaligned(a, BASE_OFFSET + index + LONG_SIZE, value1);
+    }
+
+    static void putInt(byte[] a, int index, int value) {
+        checkIndex(index, INT_SIZE, a.length);
+        UNSAFE.putIntUnaligned(a, BASE_OFFSET + index, value);
+    }
+
+    static void putShort(byte[] a, int index, short value) {
+        checkIndex(index, SHORT_SIZE, a.length);
+        UNSAFE.putShortUnaligned(a, BASE_OFFSET + index, value);
+    }
+
+    static long getLong(byte[] a, int index, boolean bigEndian) {
+        checkIndex(index, LONG_SIZE, a.length);
+        return UNSAFE.getLongUnaligned(a, BASE_OFFSET + index, bigEndian);
+    }
+
+    static int getInt(byte[] a, int index, boolean bigEndian) {
+        checkIndex(index, INT_SIZE, a.length);
+        return UNSAFE.getIntUnaligned(a, BASE_OFFSET + index, bigEndian);
+    }
+
+    static short getShort(byte[] a, int index, boolean bigEndian) {
+        checkIndex(index, SHORT_SIZE, a.length);
+        return UNSAFE.getShortUnaligned(a, BASE_OFFSET + index, bigEndian);
+    }
+
+    static void putLong(byte[] a, int index, long value, boolean bigEndian) {
+        checkIndex(index, LONG_SIZE, a.length);
+        UNSAFE.putLongUnaligned(a, BASE_OFFSET + index, value, bigEndian);
+    }
+
+    static void putInt(byte[] a, int index, int value, boolean bigEndian) {
+        checkIndex(index, INT_SIZE, a.length);
+        UNSAFE.putIntUnaligned(a, BASE_OFFSET + index, value, bigEndian);
+    }
+
+    static void putShort(byte[] a, int index, short value, boolean bigEndian) {
+        checkIndex(index, SHORT_SIZE, a.length);
+        UNSAFE.putShortUnaligned(a, BASE_OFFSET + index, value, bigEndian);
+    }
+}
\ No newline at end of file
diff -r 6940407d544a src/java.base/share/classes/com/sun/crypto/provider/GCTR.java
--- a/src/java.base/share/classes/com/sun/crypto/provider/GCTR.java	Thu Aug 20 07:36:37 2015 -0700
+++ b/src/java.base/share/classes/com/sun/crypto/provider/GCTR.java	Thu Aug 27 16:47:27 2015 +0100
@@ -32,6 +32,7 @@
 import java.security.*;
 import javax.crypto.*;
 import static com.sun.crypto.provider.AESConstants.AES_BLOCK_SIZE;
+import sun.misc.Unsafe;
 
 /**
  * This class represents the GCTR function defined in NIST 800-38D
@@ -66,6 +67,8 @@
     // needed for save/restore calls
     private byte[] counterSave = null;
 
+    private static final Unsafe UNSAFE = Unsafe.getUnsafe();
+
     // NOTE: cipher should already be initialized
     GCTR(SymmetricCipher cipher, byte[] initialCounterBlk) {
         this.aes = cipher;
@@ -94,11 +97,12 @@
         int numOfCompleteBlocks = inLen / AES_BLOCK_SIZE;
         for (int i = 0; i < numOfCompleteBlocks; i++) {
             aes.encryptBlock(counter, 0, encryptedCntr, 0);
-            for (int n = 0; n < AES_BLOCK_SIZE; n++) {
-                int index = (i * AES_BLOCK_SIZE + n);
-                out[outOfs + index] =
-                    (byte) ((in[inOfs + index] ^ encryptedCntr[n]));
-            }
+            int index = i * AES_BLOCK_SIZE;
+            ByteArrays.putLongs(out, outOfs + index,
+                                (ByteArrays.getLong(in, inOfs + index + 0) ^
+                                 ByteArrays.getLong(encryptedCntr, 0)),
+                                (ByteArrays.getLong(in, inOfs + index + 8) ^
+                                 ByteArrays.getLong(encryptedCntr, 8)));
             GaloisCounterMode.increment32(counter);
         }
         return inLen;
diff -r 6940407d544a src/java.base/share/classes/com/sun/crypto/provider/GaloisCounterMode.java
--- a/src/java.base/share/classes/com/sun/crypto/provider/GaloisCounterMode.java	Thu Aug 20 07:36:37 2015 -0700
+++ b/src/java.base/share/classes/com/sun/crypto/provider/GaloisCounterMode.java	Thu Aug 27 16:47:27 2015 +0100
@@ -82,11 +82,8 @@
             // should never happen
             throw new ProviderException("Illegal counter block length");
         }
-        // start from last byte and only go over 4 bytes, i.e. total 32 bits
-        int n = value.length - 1;
-        while ((n >= value.length - 4) && (++value[n] == 0)) {
-            n--;
-        }
+        int counter = ByteArrays.getInt(value, value.length - 4, true);
+        ByteArrays.putInt(value, value.length - 4, counter + 1, true);
     }
 
     // ivLen in bits
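Because Unsafe accesses are unchecked, the explicit bounds check in the attached ByteArrays.checkIndex is what keeps these helpers safe, and it is written as nb > limit - i rather than i + nb > limit. A small sketch of why that form matters, with hypothetical names (CheckIndexSketch, inBounds, inBoundsNaive): for a large index, i + nb overflows int and the naive form accepts an out-of-range access, while limit - i cannot overflow once i >= 0 and limit is an array length.

```java
public class CheckIndexSketch {
    // Same test as ByteArrays.checkIndex, returning a boolean instead of throwing.
    // i + nb is never computed, so nothing can overflow.
    static boolean inBounds(int i, int nb, int limit) {
        return !((i < 0) || (nb > limit - i));
    }

    // A tempting but wrong alternative: i + nb can overflow to a negative
    // value, so a huge index slips through the check.
    static boolean inBoundsNaive(int i, int nb, int limit) {
        return i >= 0 && i + nb <= limit;
    }

    public static void main(String[] args) {
        int limit = 16;
        int nb = 8; // e.g. LONG_SIZE
        int[] indices = {0, 8, 9, -1, Integer.MAX_VALUE - 2};
        for (int i : indices) {
            System.out.println(i + ": checked=" + inBounds(i, nb, limit)
                               + " naive=" + inBoundsNaive(i, nb, limit));
        }
    }
}
```

Only the last index separates the two forms: Integer.MAX_VALUE - 2 + 8 wraps to a negative number, so the naive check passes it, whereas the checked form rejects it.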


More information about the security-dev mailing list