RFR: 8263006: Add optimization for Max(*)Node and Min(*)Node [v2]

Wang Huang whuang at openjdk.java.net
Sat Apr 17 06:42:40 UTC 2021


On Fri, 16 Apr 2021 01:40:01 GMT, Wang Huang <whuang at openjdk.org> wrote:

>> * I optimize `max` and `min` using the following identity:
>>     - `op(max(a, b), min(a, b))` ≡ `op(a, b)`, provided that `op` is commutative
>>     - examples:
>>       - `max(a, b) + min(a, b)` ≡ `a + b` // op = add
>>       - `max(a, b) * min(a, b)` ≡ `a * b` // op = mul
>>       - `max(max(a, b), min(a, b))` ≡ `max(a, b)` // op = max
>>       - `min(max(a, b), min(a, b))` ≡ `min(a, b)` // op = min
>> * Test case 
>>   ```java
>>   /*
>>    * Copyright (c) 2021, Huawei Technologies Co. Ltd. All rights reserved.
>>    * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
>>    *
>>    * This code is free software; you can redistribute it and/or modify it
>>    * under the terms of the GNU General Public License version 2 only, as
>>    * published by the Free Software Foundation.
>>    *
>>    * This code is distributed in the hope that it will be useful, but WITHOUT
>>    * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
>>    * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
>>    * version 2 for more details (a copy is included in the LICENSE file that
>>    * accompanied this code).
>>    *
>>    * You should have received a copy of the GNU General Public License version
>>    * 2 along with this work; if not, write to the Free Software Foundation,
>>    * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
>>    *
>>    * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
>>    * or visit www.oracle.com if you need additional information or have any
>>    * questions.
>>    */
>>   package org.sample;
>>   
>>   import org.openjdk.jmh.annotations.Benchmark;
>>   import org.openjdk.jmh.annotations.*;
>>   
>>   import java.util.Random;
>>   import java.util.concurrent.TimeUnit;
>>   import org.openjdk.jmh.infra.Blackhole;
>>   
>>   @BenchmarkMode({Mode.AverageTime})
>>   @OutputTimeUnit(TimeUnit.MICROSECONDS)
>>   public class MyBenchmark {
>>   
>>       static int length = 100000;
>>       static double[] data1 = new double[length];
>>       static double[] data2 = new double[length];
>>       static Random random = new Random();
>>   
>>       static {
>>           for (int i = 0; i < length; ++i) {
>>               data1[i] = random.nextDouble();
>>               data2[i] = random.nextDouble();
>>           }
>>       }
>>   
>>       @Benchmark
>>       public void testAdd(Blackhole bh) {
>>           double sum = 0;
>>           for (int i = 0; i < length; i++) {
>>               sum += Math.max(data1[i], data2[i]) + Math.min(data1[i], data2[i]);
>>           }
>>           bh.consume(sum);
>>       }
>>   
>>       @Benchmark
>>       public void testMax(Blackhole bh) {
>>           double sum = 0;
>>           for (int i = 0; i < length; i++) {
>>               sum += Math.max(Math.max(data1[i], data2[i]), Math.min(data1[i], data2[i]));
>>           }
>>           bh.consume(sum);
>>       }
>>   
>>       @Benchmark
>>       public void testMin(Blackhole bh) {
>>           double sum = 0;
>>           for (int i = 0; i < length; i++) {
>>               sum += Math.min(Math.max(data1[i], data2[i]), Math.min(data1[i], data2[i]));
>>           }
>>           bh.consume(sum);
>>       }
>>   
>>       @Benchmark
>>       public void testMul(Blackhole bh) {
>>           double sum = 0;
>>           for (int i = 0; i < length; i++) {
>>               sum += (Math.max(data1[i], data2[i]) * Math.min(data1[i], data2[i]));
>>           }
>>           bh.consume(sum);
>>       }
>>   }
>>   ```
>> 
>> * The results on aarch64 are listed here (average time per operation, lower is better):
>> 
>>   before:
>> 
>>   | Benchmark               | Mode | Samples | Score   | Score error | Units |
>>   |-------------------------|------|---------|---------|-------------|-------|
>>   | o.s.MyBenchmark.testAdd | avgt | 10      | 556.048 | 32.368      | us/op |
>>   | o.s.MyBenchmark.testMax | avgt | 10      | 543.065 | 54.221      | us/op |
>>   | o.s.MyBenchmark.testMin | avgt | 10      | 570.731 | 37.630      | us/op |
>>   | o.s.MyBenchmark.testMul | avgt | 10      | 531.906 | 20.518      | us/op |
>> 
>>   after:
>> 
>>   | Benchmark               | Mode | Samples | Score   | Score error | Units |
>>   |-------------------------|------|---------|---------|-------------|-------|
>>   | o.s.MyBenchmark.testAdd | avgt | 10      | 319.350 | 9.248       | us/op |
>>   | o.s.MyBenchmark.testMax | avgt | 10      | 356.138 | 10.736      | us/op |
>>   | o.s.MyBenchmark.testMin | avgt | 10      | 323.731 | 16.621      | us/op |
>>   | o.s.MyBenchmark.testMul | avgt | 10      | 338.458 | 23.755      | us/op |
>> 
>> * I have tested `NaN`, `INFINITY`, and `-INFINITY` and got the same results before and after the change.
>
> Wang Huang has updated the pull request incrementally with one additional commit since the last revision:
> 
>   adjust code style
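The identities in the quoted description can be modeled as a rewrite on a toy expression tree. The sketch below is not HotSpot code: C2 works on its own `Node` graph and would presumably compare input nodes by identity (GVN makes equal subexpressions share a node) rather than structurally. The class `MinMaxRewrite`, the records `Var`/`Bin`, and the string opcodes are all hypothetical.

```java
import java.util.List;
import java.util.Set;

// Toy model of the rewrite op(max(a, b), min(a, b)) ==> op(a, b).
// Hypothetical names; this is NOT HotSpot's C2 implementation.
public class MinMaxRewrite {

    interface Expr {}

    record Var(String name) implements Expr {
        @Override public String toString() { return name; }
    }

    record Bin(String op, Expr left, Expr right) implements Expr {
        @Override public String toString() { return op + "(" + left + ", " + right + ")"; }
    }

    // Commutative operators the rule applies to; e.g. "sub" is deliberately absent.
    static final Set<String> COMMUTATIVE = Set.of("add", "mul", "max", "min");

    static boolean isOp(Expr e, String op) {
        return e instanceof Bin b && b.op().equals(op);
    }

    static Expr simplify(Expr e) {
        if (!(e instanceof Bin b) || !COMMUTATIVE.contains(b.op())) {
            return e;
        }
        // op is commutative, so the max/min operands may appear in either order.
        Bin mx, mn;
        if (isOp(b.left(), "max") && isOp(b.right(), "min")) {
            mx = (Bin) b.left();
            mn = (Bin) b.right();
        } else if (isOp(b.left(), "min") && isOp(b.right(), "max")) {
            mx = (Bin) b.right();
            mn = (Bin) b.left();
        } else {
            return e;
        }
        // max and min must share the same two inputs (in either order).
        boolean sameInputs =
                (mx.left().equals(mn.left()) && mx.right().equals(mn.right()))
             || (mx.left().equals(mn.right()) && mx.right().equals(mn.left()));
        return sameInputs ? new Bin(b.op(), mx.left(), mx.right()) : e;
    }

    public static void main(String[] args) {
        Expr a = new Var("a"), b = new Var("b");
        Expr max = new Bin("max", a, b), min = new Bin("min", a, b);
        for (String op : List.of("add", "mul", "max", "min")) {
            Expr before = new Bin(op, max, min);
            System.out.println(before + " ==> " + simplify(before));
        }
    }
}
```

Running `main` prints one line per operator, for example `min(max(a, b), min(a, b)) ==> min(a, b)`. Because the rule is generic over `op`, the `max` and `min` cases fall out of the same code path as `add` and `mul`.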

Thank you for your review.
> Do you have a real example in Java applications which benefit from this optimization?
> We should not add and **support** code which would never be used in the real world.
> 
Yes. We derived this optimization from experience with our internal software. For instance, the pattern `min(max(a, b), min(a, b))` appears in many places in the source code of some AI projects.
> The optimization will not work for Integer because the `_min` and `_max` intrinsics generate `cmove`:
> https://github.com/openjdk/jdk/blob/master/src/hotspot/share/opto/library_call.cpp#L1806
> 
Yes. `MaxINode`'s `max_opcode` is added only because `max_opcode` is an abstract method. Our test cases cover floating-point types only.
> I am not sure if this optimization will always work for float/double because of NaN values.
> 
> You need to verify results for all edge cases.

I have tested that and reported it in my earlier comment. The test cases for `NaN` and other special values are listed here:

```java
public class Test {

    public static void main(String[] args) {
        new Test().test();
    }

    public void test() {
        double[] num = {
            1,
            0,
            -0.0, // must be a double literal: the int literal -0 is just +0.0
            Double.POSITIVE_INFINITY,
            Double.NEGATIVE_INFINITY,
            Double.NaN,
            Double.MAX_VALUE,
            Double.MIN_VALUE,
            Double.MIN_NORMAL
        };

        // Compare each optimized pattern against its reference for all input pairs.
        for (double a : num) {
            for (double b : num) {
                check("add", a, b, add_opt(a, b), a + b);
                check("mul", a, b, mul_opt(a, b), a * b);
                check("max", a, b, max_opt(a, b), Math.max(a, b));
                check("min", a, b, min_opt(a, b), Math.min(a, b));
            }
        }
    }

    // Double.compare treats NaN as equal to NaN and distinguishes -0.0 from 0.0,
    // whereas a != b is always true when either side is NaN.
    public void check(String op, double a, double b, double actual, double expected) {
        if (Double.compare(actual, expected) != 0) {
            System.out.println("false: " + op + "(" + a + ", " + b + "): "
                    + actual + " != " + expected);
        }
    }

    public double add_opt(double a, double b) {
        return Math.max(a, b) + Math.min(a, b);
    }

    public double mul_opt(double a, double b) {
        return Math.max(a, b) * Math.min(a, b);
    }

    public double max_opt(double a, double b) {
        return Math.max(Math.max(a, b), Math.min(a, b));
    }

    public double min_opt(double a, double b) {
        return Math.min(Math.max(a, b), Math.min(a, b));
    }
}
```
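With only 81 calls per helper method, the checks above most likely run in the interpreter or in C1-compiled code, so they may not exercise the new C2 transformation at all: HotSpot typically compiles a method with C2 only after it has been invoked or looped thousands of times. A minimal sketch that repeats the same checks often enough to make C2 compilation likely; the class name `WarmTest` and the round count are assumptions. Running with `-Xcomp`, or adding `-XX:+PrintCompilation` to confirm which methods were compiled, are alternatives.

```java
// Hypothetical warm-up harness: repeats the checks so that the JIT is likely
// to compile the helpers with C2 before most of the comparisons run.
public class WarmTest {

    static final double[] NUM = {
        1, 0, -0.0, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY,
        Double.NaN, Double.MAX_VALUE, Double.MIN_VALUE, Double.MIN_NORMAL
    };

    static double addOpt(double a, double b) { return Math.max(a, b) + Math.min(a, b); }
    static double mulOpt(double a, double b) { return Math.max(a, b) * Math.min(a, b); }
    static double maxOpt(double a, double b) { return Math.max(Math.max(a, b), Math.min(a, b)); }
    static double minOpt(double a, double b) { return Math.min(Math.max(a, b), Math.min(a, b)); }

    // NaN equals NaN, and 0.0 differs from -0.0 (unlike ==).
    static boolean same(double x, double y) {
        return Double.compare(x, y) == 0;
    }

    // Runs all 81 input pairs `rounds` times and returns the number of mismatches.
    static int run(int rounds) {
        int failures = 0;
        for (int r = 0; r < rounds; r++) {
            for (double a : NUM) {
                for (double b : NUM) {
                    if (!same(addOpt(a, b), a + b)) failures++;
                    if (!same(mulOpt(a, b), a * b)) failures++;
                    if (!same(maxOpt(a, b), Math.max(a, b))) failures++;
                    if (!same(minOpt(a, b), Math.min(a, b))) failures++;
                }
            }
        }
        return failures;
    }

    public static void main(String[] args) {
        // 1_000 rounds x 81 pairs = 81_000 calls per helper; the count is an assumption.
        int failures = run(1_000);
        System.out.println(failures == 0 ? "OK" : "FAILED: " + failures + " mismatches");
    }
}
```

Counting mismatches instead of printing each one keeps the hot loop free of I/O, so the JIT sees the same simple arithmetic that the optimization targets.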

`NaN` is a special case: because `NaN == NaN` is false in Java, `NaN` results cannot be verified with `==` or `!=`, so I ran the cases and checked those results separately.
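Two details of Java's `double` comparison matter for this kind of test. `==` is false whenever an operand is `NaN`, and it also ignores the sign of zero, so `0.0 == -0.0` is true even though `Math.max` and `Math.min` treat `-0.0` as strictly smaller than `0.0`. In addition, an integer literal such as `-0` assigned to a `double` becomes `+0.0`; only the literal `-0.0` gives a negative zero. A short demonstration using `Double.compare` as the equality check (the class and method names are illustrative):

```java
// Illustrates why == / != are unreliable when checking double results.
public class DoubleEquality {

    // Double.compare-based equality: NaN equals NaN, and 0.0 differs from -0.0.
    static boolean sameDouble(double x, double y) {
        return Double.compare(x, y) == 0;
    }

    public static void main(String[] args) {
        double nan = Double.NaN;
        System.out.println(nan == nan);             // false
        System.out.println(sameDouble(nan, nan));   // true

        double intZero = -0;    // int literal -0 is 0, widened to +0.0
        double negZero = -0.0;  // a genuine negative zero
        System.out.println(intZero == negZero);              // true: == ignores the sign of zero
        System.out.println(sameDouble(intZero, negZero));    // false
        System.out.println(1 / intZero + " " + 1 / negZero); // Infinity -Infinity
    }
}
```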

Should I add further test cases for `NaN`?

-------------

PR: https://git.openjdk.java.net/jdk/pull/3513


More information about the hotspot-compiler-dev mailing list