toString/equals/hashCode implemented using method handles

Remi Forax forax at univ-mlv.fr
Wed Jun 20 21:00:23 UTC 2018


As discussed during today's meeting,
here is an implementation of toString, equals and hashCode that can be used as the default implementation of these methods for value types.

toString uses the StringConcatFactory, so the code is simple;
for equals(), it first sorts the fields so that the primitive-typed fields are at the end of the array and the reference-typed fields are at the beginning;
given that the tree is constructed by adding each equals test in front of the previous equals tests, the constructed method handle tree will test the fields that store primitive types first;
for hashCode(), it's just a reduction using foldArguments; it creates big method handle trees, but the code should be fast (no loop combinator).

The code tries to cache every primitive method handle and doesn't cache method handles that come from methods in user-defined classes.

Currently the code retrieves the fields using reflection, which makes the code hard to test because the order of the fields returned by getDeclaredFields() is not specified;
perhaps the list of fields should be stored in a constant dynamic as an array of MethodHandle (we can use MethodHandleInfo if we want the name of a method) and passed as a bootstrap argument of the indy call inside toString, equals and hashCode.

cheers,
Rémi

---
package fr.umlv.valuetype;

import static java.lang.invoke.MethodHandles.constant;
import static java.lang.invoke.MethodHandles.dropArguments;
import static java.lang.invoke.MethodHandles.filterArguments;
import static java.lang.invoke.MethodHandles.filterReturnValue;
import static java.lang.invoke.MethodHandles.foldArguments;
import static java.lang.invoke.MethodHandles.guardWithTest;
import static java.lang.invoke.MethodHandles.insertArguments;
import static java.lang.invoke.MethodHandles.lookup;
import static java.lang.invoke.MethodHandles.permuteArguments;
import static java.lang.invoke.MethodHandles.publicLookup;
import static java.lang.invoke.MethodHandles.zero;
import static java.lang.invoke.MethodType.methodType;
import static java.lang.invoke.StringConcatFactory.makeConcatWithConstants;

import java.lang.invoke.CallSite;
import java.lang.invoke.ConstantCallSite;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles.Lookup;
import java.lang.invoke.MethodType;
import java.lang.invoke.StringConcatException;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Objects;
import java.util.stream.IntStream;

/**
 * Bootstrap methods that provide default implementations of {@code toString}, {@code equals} and
 * {@code hashCode} for value types, assembled entirely from method handle combinators.
 *
 * <p>The non-static fields of the lookup class are discovered by reflection
 * ({@link Class#getDeclaredFields()}, whose order is unspecified). Method handles that depend only
 * on a primitive type are cached in small racy arrays; method handles that come from methods of
 * user-defined classes are recreated at each bootstrap.
 */
public class ValueTypeBSM {
  /**
   * Bootstrap method for an {@code invokedynamic} call site named {@code toString},
   * {@code equals} or {@code hashCode}.
   *
   * <p>The returned target has type {@code (C)String} for toString, {@code (C, Object)boolean} for
   * equals and {@code (C)int} for hashCode, where {@code C} is {@code lookup.lookupClass()}.
   * The {@code type} argument itself is not inspected.
   *
   * @param lookup a lookup on the class whose non-static fields are used; it must be able to read
   *     those fields (toString additionally goes through {@code StringConcatFactory})
   * @param name the name of the method to implement
   * @param type the call site type
   * @return a constant call site whose target implements {@code name}
   * @throws LinkageError if {@code name} is unknown or a required method handle cannot be created
   */
  public static CallSite makeBootstrapMethod(Lookup lookup, String name, MethodType type) {
    Field[] fields = Arrays.stream(lookup.lookupClass().getDeclaredFields())
        .filter(field -> !Modifier.isStatic(field.getModifiers()))
        .toArray(Field[]::new);
    return createCallSite(lookup, name, fields);
  }

  /**
   * Dispatches on the method name to the matching call site factory.
   *
   * @throws LinkageError if {@code name} is not toString, equals or hashCode
   */
  private static CallSite createCallSite(Lookup lookup, String name, Field[] fields) {
    switch(name) {
    case "toString":
      return createToStringCallSite(lookup, fields);
    case "equals":
      return createEqualsCallSite(lookup, fields);
    case "hashCode":
      return createHashCodeCallSite(lookup, fields);
    default:
      throw new LinkageError("unknown method " + name);
    }
  }

  /**
   * Returns one getter per field, in the same order as {@code fields}; each getter has the type
   * {@code (declaringClass)fieldType}.
   *
   * @throws LinkageError if {@code lookup} cannot access one of the fields
   */
  private static MethodHandle[] findGetters(Lookup lookup, Field[] fields) {
    MethodHandle[] mhs = new MethodHandle[fields.length];
    for(int i = 0; i < mhs.length; i++) {
      try {
        mhs[i] = lookup.unreflectGetter(fields[i]);
      } catch (IllegalAccessException e) {
        throw newLinkageError(e);
      }
    }
    return mhs;
  }

  /** Number of primitive types (void excluded) that have an index in {@link #projection}. */
  private static final int PRIMITIVE_COUNT = 8;

  /**
   * Maps a primitive type to an index in {@code [0, PRIMITIVE_COUNT)} used by the primitive caches.
   * The order (boolean, byte, char, short, int, long, float, double) matches
   * {@code HashCode.WRAPPERS}.
   *
   * @throws AssertionError if {@code type} is not one of those eight primitive types
   */
  static int projection(Class<?> type) {
    switch(type.getName()) {
    case "boolean":
      return 0;
    case "byte":
      return 1;
    case "char":
      return 2;
    case "short":
      return 3;
    case "int":
      return 4;
    case "long":
      return 5;
    case "float":
      return 6;
    case "double":
      return 7;
    default:
      throw new AssertionError();
    }
  }

  /** Wraps {@code e} into a LinkageError, keeping its message and recording it as the cause. */
  static LinkageError newLinkageError(Throwable e) {
    return new LinkageError(e.getMessage(), e);
  }

  /** Adapters that turn an array into a String using the matching {@code Arrays.toString}. */
  static class ArrayToString {
    // the code that accesses this array is racy, but it's only a cache: at worst two threads
    // compute equivalent method handles and one of them is kept
    private static final MethodHandle[] ADAPTER_CACHE = new MethodHandle[PRIMITIVE_COUNT + 1];
    // slot shared by every array whose component type is a reference type
    private static final int OBJECT_INDEX = PRIMITIVE_COUNT;

    /**
     * Returns {@code Arrays.toString(X[])} for a primitive component type {@code X}, or
     * {@code Arrays.toString(Object[])} for any reference component type. The caller adapts the
     * handle to the exact array type with {@code asType}.
     */
    static MethodHandle adapter(Class<?> arrayType) {
      Class<?> componentType = arrayType.getComponentType();
      int index = componentType.isPrimitive()? projection(componentType): OBJECT_INDEX;
      MethodHandle mh = ADAPTER_CACHE[index];
      if (mh != null) {
        return mh;
      }
      Class<?> erasedType = componentType.isPrimitive()? arrayType: Object[].class;
      try {
        mh = publicLookup().findStatic(Arrays.class, "toString", methodType(String.class, erasedType));
      } catch (NoSuchMethodException | IllegalAccessException e) {
        throw newLinkageError(e);
      }

      // another thread may have filled the slot in the meantime; keep its handle
      MethodHandle concurrentMh = ADAPTER_CACHE[index];
      if (concurrentMh != null) {
        return concurrentMh;
      }
      ADAPTER_CACHE[index] = mh;
      return mh;
    }
  }

  /**
   * Builds a {@code (C)String} target that formats the fields as {@code name1=value1 name2=value2}
   * using {@code StringConcatFactory}; array fields are formatted with {@code Arrays.toString}.
   */
  private static CallSite createToStringCallSite(Lookup lookup, Field[] fields) {
    int length = fields.length;
    MethodHandle[] getters = findGetters(lookup, fields);
    Class<?>[] parameterTypes = new Class<?>[length];
    StringBuilder recipe = new StringBuilder();
    for(int i = 0; i < length; i++) {
      Field field = fields[i];
      // \1 is the StringConcatFactory recipe tag for "insert the next argument here"
      recipe.append(i == 0? "": " ").append(field.getName()).append("=\1");

      Class<?> type = field.getType();
      if (type.isArray()) {
        // pre-convert arrays to String so that their content is shown, not their identity
        MethodHandle adapter = ArrayToString.adapter(type).asType(methodType(String.class, type));
        getters[i] = filterReturnValue(getters[i], adapter);
        type = String.class;
      }
      parameterTypes[i] = type;
    }

    // ask for a MethodHandle that will do the concatenation
    MethodHandle target;
    try {
      target = makeConcatWithConstants(lookup, "toString", methodType(String.class, parameterTypes), recipe.toString()).dynamicInvoker();
    } catch (StringConcatException e) {
      throw newLinkageError(e);
    }

    // apply all getters: (C, C, ..., C)String
    target = filterArguments(target, 0, getters);

    // feed the single argument (this) to every getter: (C)String
    target = permuteArguments(target, methodType(String.class, lookup.lookupClass()), new int[length]);

    return new ConstantCallSite(target);
  }

  /** Method handle building blocks for the equals implementation. */
  static class Equality {
    @SuppressWarnings("unused")
    private static boolean same(boolean b1, boolean b2) {
      return b1 == b2;
    }
    @SuppressWarnings("unused")
    private static boolean same(byte b1, byte b2) {
      return b1 == b2;
    }
    @SuppressWarnings("unused")
    private static boolean same(short s1, short s2) {
      return s1 == s2;
    }
    @SuppressWarnings("unused")
    private static boolean same(char c1, char c2) {
      return c1 == c2;
    }
    @SuppressWarnings("unused")
    private static boolean same(int i1, int i2) {
      return i1 == i2;
    }
    @SuppressWarnings("unused")
    private static boolean same(long l1, long l2) {
      return l1 == l2;
    }
    // float/double must not use ==: NaN == NaN is false (equals would not be reflexive) and
    // 0.0 == -0.0 is true while Float.hashCode/Double.hashCode differ for them. Float.compare and
    // Double.compare compare the canonical bit patterns, which agrees with those hashCodes.
    @SuppressWarnings("unused")
    private static boolean same(float f1, float f2) {
      return Float.compare(f1, f2) == 0;
    }
    @SuppressWarnings("unused")
    private static boolean same(double d1, double d2) {
      return Double.compare(d1, d2) == 0;
    }
    @SuppressWarnings("unused")
    private static boolean same(Object o1, Object o2) {
      return o1 == o2;
    }

    // the code that accesses this array is racy, but it's only a cache: at worst two threads
    // compute equivalent method handles and one of them is kept
    private static final MethodHandle[] SAME_CACHE = new MethodHandle[PRIMITIVE_COUNT];

    /** Returns the cached {@code same(T, T)} handle for the primitive type {@code T}. */
    private static MethodHandle primitiveEquals(Class<?> primitiveType) {
      int index = projection(primitiveType);
      MethodHandle mh = SAME_CACHE[index];
      if (mh != null) {
        return mh;
      }
      mh = findSameMH(primitiveType);
      MethodHandle concurrentMh = SAME_CACHE[index];
      if (concurrentMh != null) {
        return concurrentMh;
      }
      SAME_CACHE[index] = mh;
      return mh;
    }

    /** Looks up the {@code same} overload taking two arguments of {@code type}. */
    private static MethodHandle findSameMH(Class<?> type) {
      try {
        return lookup().findStatic(Equality.class, "same", methodType(boolean.class, type, type));
      } catch (NoSuchMethodException | IllegalAccessException e) {
        throw newLinkageError(e);
      }
    }

    // SAME_OBJECT: (Object, Object)boolean, reference comparison
    // NULL_CHECK: (Object, Object)boolean, true if the first argument is null
    // TRUE/FALSE: (Object, Object)boolean constants
    // IS_INSTANCE: Class.isInstance(Object)
    private static final MethodHandle SAME_OBJECT, NULL_CHECK, TRUE, FALSE, IS_INSTANCE;
    static {
      MethodHandle mh = findSameMH(Object.class);
      SAME_OBJECT = mh;
      NULL_CHECK = dropArguments(insertArguments(mh, 1, (Object)null), 1, Object.class);
      TRUE = dropArguments(constant(boolean.class, true), 0, Object.class, Object.class);
      FALSE = dropArguments(constant(boolean.class, false), 0, Object.class, Object.class);

      try {
        IS_INSTANCE = publicLookup().findVirtual(Class.class, "isInstance", methodType(boolean.class, Object.class));
      } catch (NoSuchMethodException | IllegalAccessException e) {
        throw newLinkageError(e);
      }
    }

    /** Returns a null-safe {@code (T, T)boolean} handle calling {@code T.equals(Object)}. */
    private static MethodHandle objectEquals(Lookup lookup, Class<?> type)  {
      MethodHandle equals;
      try {
        equals = lookup.findVirtual(type, "equals", methodType(boolean.class, Object.class));
      } catch (NoSuchMethodException | IllegalAccessException e) {
        throw newLinkageError(e);
      }

      // equivalent to (a == b)? true: (a == null)? false: a.equals(b)
      MethodType pairType = methodType(boolean.class, type, type);
      return guardWithTest(SAME_OBJECT.asType(pairType),
          TRUE.asType(pairType),
          guardWithTest(NULL_CHECK.asType(pairType),
              FALSE.asType(pairType),
              equals.asType(pairType)));
    }

    /** Returns a {@code (T, T)boolean} equality test for values of {@code type}. */
    private static MethodHandle equals(Lookup lookup, Class<?> type) {
      return type.isPrimitive()? primitiveEquals(type): objectEquals(lookup, type);
    }

    /**
     * Returns a {@code (D, D)boolean} handle that is true iff every getter yields equal values.
     * Each new test wraps the previous tree, so the LAST getter is tested first.
     */
    private static MethodHandle equalsAll(Lookup lookup, Class<?> declaredType, MethodHandle[] getters) {
      MethodHandle _false = FALSE.asType(methodType(boolean.class, declaredType, declaredType));

      MethodHandle target = TRUE.asType(methodType(boolean.class, declaredType, declaredType));
      for(MethodHandle getter: getters) {
        MethodHandle test = filterArguments(equals(lookup, getter.type().returnType()), 0, getter, getter);
        target = guardWithTest(test, target, _false);
      }
      return target;
    }

    /**
     * Returns a {@code (D, Object)boolean} equals: false unless the second argument is an instance
     * of the lookup class {@code D}, otherwise the field-by-field comparison of {@link #equalsAll}.
     */
    static MethodHandle createEquals(Lookup lookup, MethodHandle[] getters) {
      Class<?> declaredType = lookup.lookupClass();
      MethodHandle test = dropArguments(IS_INSTANCE.bindTo(declaredType), 0, declaredType);
      return guardWithTest(test,
          equalsAll(lookup, declaredType, getters).asType(methodType(boolean.class, declaredType, Object.class)),
          FALSE.asType(methodType(boolean.class, declaredType, Object.class)));
    }
  }

  /**
   * Builds the equals call site. The fields are reordered to
   * [references in reverse order, primitives in reverse order]; since equalsAll tests the last
   * getter first, primitives are compared first (in field order), then references (in field order).
   */
  private static CallSite createEqualsCallSite(Lookup lookup, Field[] fields) {
    Integer[] orders = IntStream.range(0, fields.length).boxed().toArray(Integer[]::new);
    Arrays.sort(orders, (index1, index2) -> {
      boolean primitive1 = fields[index1].getType().isPrimitive();
      boolean primitive2 = fields[index2].getType().isPrimitive();
      if (primitive1 != primitive2) {
        // primitives go after references
        return primitive1? 1: -1;
      }
      // within each group, the first in fields becomes the last in sortedFields
      return Integer.compare(index2, index1);
    });
    Field[] sortedFields = new Field[fields.length];
    Arrays.setAll(sortedFields, i -> fields[orders[i]]);

    return new ConstantCallSite(Equality.createEquals(lookup, findGetters(lookup, sortedFields)));
  }

  /** Method handle building blocks for the hashCode implementation. */
  static class HashCode {
    // the code that accesses this array is racy, but it's only a cache: at worst two threads
    // compute equivalent method handles and one of them is kept
    private static final MethodHandle[] HASH_CODE_CACHE = new MethodHandle[PRIMITIVE_COUNT];
    // indexed by projection()
    private static final Class<?>[] WRAPPERS = new Class<?>[] {
      Boolean.class, Byte.class, Character.class, Short.class, Integer.class, Long.class, Float.class, Double.class
    };

    private static Class<?> wrapper(Class<?> primitiveType) {
      return WRAPPERS[projection(primitiveType)];
    }

    /** Returns the cached static {@code Wrapper.hashCode(T)} handle for the primitive type T. */
    private static MethodHandle primitiveHashCode(Class<?> primitiveType) {
      int index = projection(primitiveType);
      MethodHandle mh = HASH_CODE_CACHE[index];
      if (mh != null) {
        return mh;
      }
      try {
        mh = publicLookup().findStatic(wrapper(primitiveType), "hashCode", methodType(int.class, primitiveType));
      } catch (NoSuchMethodException | IllegalAccessException e) {
        throw newLinkageError(e);
      }
      MethodHandle concurrentMh = HASH_CODE_CACHE[index];
      if (concurrentMh != null) {
        return concurrentMh;
      }
      HASH_CODE_CACHE[index] = mh;
      return mh;
    }

    // NULL_CHECK: Objects.isNull(Object); ZERO: (Object)int returning 0; REDUCE: reduce(int, int)
    private static final MethodHandle NULL_CHECK, ZERO, REDUCE;
    static {
      Lookup lookup = lookup();
      try {
        NULL_CHECK = lookup.findStatic(Objects.class, "isNull", methodType(boolean.class, Object.class));
        REDUCE = lookup.findStatic(HashCode.class, "reduce", methodType(int.class, int.class, int.class));
      } catch (NoSuchMethodException | IllegalAccessException e) {
        throw newLinkageError(e);
      }
      ZERO = dropArguments(zero(int.class), 0, Object.class);
    }

    /** Combines the hash of one field with the hash accumulated so far. */
    @SuppressWarnings("unused")
    private static int reduce(int value, int accumulator) {
      return value + accumulator * 31;
    }

    /** Returns a {@code (T)int} handle calling {@code T.hashCode()}, yielding 0 for null. */
    private static MethodHandle objectHashCode(Lookup lookup, Class<?> type) {
      MethodHandle hashCode;
      try {
        hashCode = lookup.findVirtual(type, "hashCode", methodType(int.class));
      } catch (NoSuchMethodException | IllegalAccessException e) {
        throw newLinkageError(e);
      }
      return guardWithTest(NULL_CHECK.asType(methodType(boolean.class, type)),
          ZERO.asType(methodType(int.class, type)),
          hashCode);
    }

    private static MethodHandle hashCode(Lookup lookup, Class<?> type) {
      return type.isPrimitive()? primitiveHashCode(type): objectHashCode(lookup, type);
    }

    /**
     * Returns a {@code (D)int} handle computing {@code h = 1; for each getter g: h = 31 * h + hash(g(this))},
     * i.e. the same combination as {@code Arrays.hashCode}, built with nested foldArguments.
     */
    static MethodHandle hashCodeAll(Lookup lookup, Class<?> declaredType, MethodHandle[] getters) {
      MethodHandle target = dropArguments(constant(int.class, 1), 0, declaredType);
      for(MethodHandle getter: getters) {
        MethodHandle hashField = filterReturnValue(getter, hashCode(lookup, getter.type().returnType()));
        // inner: (int accumulator, D)int = reduce(hashField(D), accumulator)
        // outer fold: accumulator = previous target(D)
        target = foldArguments(
            foldArguments(dropArguments(REDUCE, 2, declaredType), dropArguments(hashField, 0, int.class)),
            target);
      }
      return target;
    }
  }

  /** Builds the {@code (C)int} hashCode call site over the fields in their given order. */
  private static CallSite createHashCodeCallSite(Lookup lookup, Field[] fields) {
    return new ConstantCallSite(HashCode.hashCodeAll(lookup, lookup.lookupClass(), findGetters(lookup, fields)));
  }
}


More information about the valhalla-spec-observers mailing list