New Methods in Java 9: Math.fma and Arrays.mismatch


There are two noteworthy new methods in Java 9: Arrays.mismatch and Math.fma.

Arrays.mismatch

This method takes two arrays of the same type and returns the index of the first position at which they differ. It returns -1 if the arrays are equal, and the length of the shorter array if one is a proper prefix of the other. This effectively computes the length of the longest common prefix of the two arrays. This is really quite useful, mostly for text processing but also for bioinformatics (protein sequencing and so on, much more interesting than the sort of thing I work on). Having worked extensively with Apache HBase (where the vast majority of the API involves manipulating byte arrays), I can think of lots of less interesting use cases for this method.
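A short usage sketch of the Java 9 API on DNA-like byte sequences, showing the three kinds of return value. MismatchExample and commonPrefixLength are hypothetical names chosen for illustration.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

public class MismatchExample {

    // Hypothetical helper: length of the longest common prefix of two byte arrays
    public static int commonPrefixLength(byte[] a, byte[] b) {
        int i = Arrays.mismatch(a, b);
        // -1 means the arrays are identical, so the whole array is the common prefix
        return i < 0 ? a.length : i;
    }

    public static void main(String[] args) {
        byte[] seq1 = "ACGTACGT".getBytes(StandardCharsets.US_ASCII);
        byte[] seq2 = "ACGTTCGT".getBytes(StandardCharsets.US_ASCII);
        byte[] seq3 = "ACG".getBytes(StandardCharsets.US_ASCII);

        System.out.println(Arrays.mismatch(seq1, seq2));               // 4
        System.out.println(Arrays.mismatch(seq1, seq3));               // 3: seq3 is a prefix of seq1
        System.out.println(Arrays.mismatch(seq1, seq1.clone()));       // -1: equal arrays
        System.out.println(commonPrefixLength(seq1, seq1.clone()));    // 8
    }
}
```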

Looking at the JDK source, you can see that the method delegates to the internal ArraysSupport utility class, which tries to perform a vectorised mismatch: its vectorizedMismatch method is a HotSpot intrinsic candidate. When the JIT compiler replaces it with an intrinsic that uses AVX instructions, it is very fast, much faster than a handwritten loop.

Let’s measure the speedup over a handwritten loop, testing byte[] across a range of common prefix lengths and array sizes.

    @Benchmark
    @CompilerControl(CompilerControl.Mode.DONT_INLINE)
    public void Mismatch_Intrinsic(BytePrefixData data, Blackhole bh) {
        bh.consume(Arrays.mismatch(data.data1, data.data2));
    }


    @Benchmark
    @CompilerControl(CompilerControl.Mode.DONT_INLINE)
    public void Mismatch_Handwritten(BytePrefixData data, Blackhole bh) {
        byte[] data1 = data.data1;
        byte[] data2 = data.data2;
        int length = Math.min(data1.length, data2.length);
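        // -1 means "no mismatch", matching the contract of Arrays.mismatch
        int mismatch = -1;
        for (int i = 0; i < length; ++i) {
            if (data1[i] != data2[i]) {
                mismatch = i;
                break;
            }
        }
        // Arrays.mismatch returns the shorter length when one array is a proper prefix of the other
        if (mismatch == -1 && data1.length != data2.length) {
            mismatch = length;
        }
        bh.consume(mismatch);
    }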

The results speak for themselves. Irritatingly, some rows are redundant, because I haven’t figured out how to make JMH run only a subset of the Cartesian product of its parameters.

Benchmark             (prefix)  (size)   Mode  Cnt   Score   Error  Units
Mismatch_Handwritten        10     100  thrpt   10  22.360 ± 0.938  ops/us
Mismatch_Handwritten        10    1000  thrpt   10   2.459 ± 0.256  ops/us
Mismatch_Handwritten        10   10000  thrpt   10   0.255 ± 0.009  ops/us
Mismatch_Handwritten       100     100  thrpt   10  22.763 ± 0.869  ops/us
Mismatch_Handwritten       100    1000  thrpt   10   2.690 ± 0.044  ops/us
Mismatch_Handwritten       100   10000  thrpt   10   0.273 ± 0.008  ops/us
Mismatch_Handwritten      1000     100  thrpt   10  24.970 ± 0.713  ops/us
Mismatch_Handwritten      1000    1000  thrpt   10   2.791 ± 0.066  ops/us
Mismatch_Handwritten      1000   10000  thrpt   10   0.281 ± 0.007  ops/us
Mismatch_Intrinsic          10     100  thrpt   10  89.169 ± 2.759  ops/us
Mismatch_Intrinsic          10    1000  thrpt   10  26.995 ± 0.501  ops/us
Mismatch_Intrinsic          10   10000  thrpt   10   3.553 ± 0.065  ops/us
Mismatch_Intrinsic         100     100  thrpt   10  83.037 ± 5.590  ops/us
Mismatch_Intrinsic         100    1000  thrpt   10  26.249 ± 0.714  ops/us
Mismatch_Intrinsic         100   10000  thrpt   10   3.523 ± 0.122  ops/us
Mismatch_Intrinsic        1000     100  thrpt   10  87.921 ± 6.566  ops/us
Mismatch_Intrinsic        1000    1000  thrpt   10  25.812 ± 0.442  ops/us
Mismatch_Intrinsic        1000   10000  thrpt   10   4.177 ± 0.059  ops/us

Why is there such a big difference? Look at how the score decreases with array length even when the common prefix, and therefore the number of comparisons required, is small: this suggests that the performance of this algorithm is dominated by the efficiency of memory access. Arrays.mismatch optimises this by reading the arrays in wide chunks: quadwords (eight bytes at a time), or whole SIMD registers when the intrinsic is used. Working one long at a time, a single XOR instruction determines whether any of the eight bytes differ, so individual bytes only need to be examined once a difference has been found.
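The word-at-a-time idea can be sketched in plain Java. This is a simplified illustration, not the JDK’s ArraysSupport code (which reads memory through Unsafe and, when intrinsified, uses wider SIMD loads); WordMismatch is a hypothetical name, and ByteBuffer is used only as a portable way to read eight bytes as a little-endian long.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

public class WordMismatch {

    // Hypothetical helper with the same contract as Arrays.mismatch(byte[], byte[]):
    // index of the first difference, the shorter length if one array is a proper
    // prefix of the other, or -1 if the arrays are equal.
    public static int mismatch(byte[] a, byte[] b) {
        int length = Math.min(a.length, b.length);
        // Little-endian: byte i + k of the array lands in bits [8k, 8k + 8) of the long read at i
        ByteBuffer wa = ByteBuffer.wrap(a).order(ByteOrder.LITTLE_ENDIAN);
        ByteBuffer wb = ByteBuffer.wrap(b).order(ByteOrder.LITTLE_ENDIAN);
        int i = 0;
        // Compare eight bytes per step: the XOR is zero iff all eight bytes are equal
        for (; i + Long.BYTES <= length; i += Long.BYTES) {
            long diff = wa.getLong(i) ^ wb.getLong(i);
            if (diff != 0) {
                // The lowest set bit lies in the first differing byte
                return i + (Long.numberOfTrailingZeros(diff) >>> 3);
            }
        }
        // Fewer than eight bytes left: fall back to a byte-by-byte loop
        for (; i < length; ++i) {
            if (a[i] != b[i]) {
                return i;
            }
        }
        return a.length == b.length ? -1 : length;
    }

    public static void main(String[] args) {
        byte[] x = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        byte[] y = {1, 2, 3, 4, 5, 0, 7, 8, 9, 10};
        System.out.println(mismatch(x, y));        // 5
        System.out.println(Arrays.mismatch(x, y)); // 5
    }
}
```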

java/util/ArraysSupport.vectorizedMismatch(Ljava/lang/Object;JLjava/lang/Object;JII)I  [0x000002bd9215a820, 0x000002bd9215aa78]  600 bytes
Argument 0 is unknown.RIP: 0x2bd9215a820 Code size: 0x00000258
[Entry Point]
[Verified Entry Point]
[Constants]
  # {method} {0x000002bda79cbf68} 'vectorizedMismatch' '(Ljava/lang/Object;JLjava/lang/Object;JII)I' in 'java/util/ArraysSupport'
  # parm0:    rdx:rdx   = 'java/lang/Object'
  # parm1:    r8:r8     = long
  # parm2:    r9:r9     = 'java/lang/Object'
  # parm3:    rdi:rdi   = long
  # parm4:    rsi       = int
  # parm5:    rcx       = int
  #           [sp+0x60]  (sp of caller)
  0x000002bd9215a820: mov     dword ptr [rsp+0ffffffffffff9000h],eax  ; 89 84 24 00 90 ff ff
  0x000002bd9215a827: push    rbp               ; 55
  0x000002bd9215a828: sub     rsp,50h           ; 48 83 ec 50
                                                ;*synchronization entry
                                                ; - java.util.ArraysSupport::vectorizedMismatch@-1 (line 120)
  0x000002bd9215a82c: mov     r10,rdi           ; 4c 8b d7
  0x000002bd9215a82f: vmovq   xmm2,r9           ; c4 c1 f9 6e d1
  0x000002bd9215a834: vmovq   xmm1,rdx          ; c4 e1 f9 6e ca
  0x000002bd9215a839: mov     r14d,ecx          ; 44 8b f1
  0x000002bd9215a83c: vmovd   xmm0,esi          ; c5 f9 6e c6
  0x000002bd9215a840: mov     r9d,3h            ; 41 b9 03 00 00 00
  0x000002bd9215a846: sub     r9d,ecx           ; 44 2b c9
                                                ;*isub {reexecute=0 rethrow=0 return_oop=0}
                                                ; - java.util.ArraysSupport::vectorizedMismatch@5 (line 120)
  0x000002bd9215a849: mov     edx,esi           ; 8b d6
  0x000002bd9215a84b: mov     ecx,r9d           ; 41 8b c9
  0x000002bd9215a84e: sar     edx,cl            ; d3 fa
                                                ;*ishr {reexecute=0 rethrow=0 return_oop=0}
                                                ; - java.util.ArraysSupport::vectorizedMismatch@17 (line 122)
  0x000002bd9215a850: mov     eax,1h            ; b8 01 00 00 00
  0x000002bd9215a855: xor     edi,edi           ; 33 ff
  0x000002bd9215a857: test    edx,edx           ; 85 d2
  0x000002bd9215a859: jle     2bd9215a97ah      ; 0f 8e 1b 01 00 00

The code for this benchmark is on GitHub.

Math.fma

Compared with users of some other languages, Java programmers are lackadaisical about floating-point errors. It’s a good job that, historically, Java hasn’t been considered suitable for implementing numerical algorithms. But all of a sudden there is a data science revolution on the JVM, albeit mostly driven by the Scala community, with JVM implementations of structures like recurrent neural networks abounding. Accuracy matters less for machine learning than for root finding, but how accurate can these implementations be without JVM-level support for minimising the propagation of floating-point errors? Math.fma improves this by allowing a multiplication and an addition to be performed with a single rounding at the end.

Math.fma fuses a multiplication and an addition into a single floating point operation to compute expressions like ab + c. This has two key benefits:

  1. There’s only one operation, and only one rounding error
  2. It maps directly onto hardware: on x86, the FMA3 extension (introduced alongside AVX2) provides the VFMADD* instructions
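A minimal sketch of the single-rounding benefit. The values are chosen for illustration so that rounding the intermediate product is visible: 0.1 is not exactly representable, so 0.1 * 10 is exactly 1 + 2^-54, which rounds to 1.0 if the product is rounded on its own. FmaRoundingDemo is a hypothetical class name.

```java
public class FmaRoundingDemo {

    // Two roundings: a * b is rounded to a double first, then the sum is rounded
    public static double unfused(double a, double b, double c) {
        return a * b + c;
    }

    // One rounding: a * b + c is computed exactly and rounded once (Java 9+)
    public static double fused(double a, double b, double c) {
        return Math.fma(a, b, c);
    }

    public static void main(String[] args) {
        double a = 0.1;  // stored as 0.1000000000000000055511151231257827...
        double b = 10.0;
        double c = -1.0;
        System.out.println(unfused(a, b, c));                       // 0.0
        System.out.println(fused(a, b, c) == Math.scalb(1.0, -54)); // true: exactly 2^-54
    }
}
```

The unfused version loses the tiny excess entirely, while Math.fma keeps it; in a long chain of operations, differences like this are what change the final rounding.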

Newton’s Method

To see whether FMA suppresses floating-point errors any better, I use a toy implementation of Newton’s method to find the root of a quadratic equation. Any teenager could solve it analytically, which makes the error easy to quantify.

I compare the two implementations below on 4x^2 - 12x + 9, which has a repeated root at 1.5, to get an idea of the error (defined as |1.5 - x_n|) after a large number of iterations.
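For this particular quadratic, the Newton step can be worked out in closed form, which shows how the method behaves at a repeated root:

```latex
\begin{aligned}
f(x) &= 4x^2 - 12x + 9 = (2x - 3)^2, \qquad f'(x) = 8x - 12 = 4(2x - 3),\\
x_{n+1} &= x_n - \frac{f(x_n)}{f'(x_n)} = x_n - \frac{2x_n - 3}{4} = \frac{x_n}{2} + \frac{3}{4},\\
x_{n+1} - 1.5 &= \tfrac{1}{2}\,(x_n - 1.5).
\end{aligned}
```

So in exact arithmetic the error only halves on each iteration (Newton’s method converges linearly, not quadratically, at a repeated root), and 1000 iterations are far more than enough to reach the limits of double precision. As a rough, back-of-the-envelope estimate rather than a measurement: near x = 1.5 the terms of f are of size 9 to 18, so evaluating f in double precision carries an absolute rounding error of roughly 10^-15, which swamps the true value 4(x_n - 1.5)^2 once |x_n - 1.5| falls to around 10^-8. That is consistent with the size of the errors reported below.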

I implemented this using FMA:

public class NewtonsMethodFMA {

    // Coefficients in descending order of power, e.g. {4, -12, 9} for 4x^2 - 12x + 9
    private final double[] coefficients;

    public NewtonsMethodFMA(double[] coefficients) {
        this.coefficients = coefficients;
    }


    public double evaluateF(double x) {
        double f = 0D;
        int power = coefficients.length - 1;
        for (int i = 0; i < coefficients.length; ++i) {
            f = Math.fma(coefficients[i], Math.pow(x, power--), f);
        }
        return f;
    }

    public double evaluateDF(double x) {
        double df = 0D;
        int power = coefficients.length - 2;
        for (int i = 0; i < coefficients.length - 1; ++i) {
            df = Math.fma((power + 1) * coefficients[i],  Math.pow(x, power--), df);
        }
        return df;
    }

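    public double solve(double initialEstimate, int maxIterations) {
        double result = initialEstimate;
        for (int i = 0; i < maxIterations; ++i) {
            double f = evaluateF(result);
            if (f == 0D) {
                // Exact root: at a repeated root f'(x) is 0 as well, and 0/0 would give NaN
                break;
            }
            result -= f / evaluateDF(result);
        }
        return result;
    }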
}

And an implementation with normal operations:


public class NewtonsMethod {

    // Coefficients in descending order of power, e.g. {4, -12, 9} for 4x^2 - 12x + 9
    private final double[] coefficients;

    public NewtonsMethod(double[] coefficients) {
        this.coefficients = coefficients;
    }


    public double evaluateF(double x) {
        double f = 0D;
        int power = coefficients.length - 1;
        for (int i = 0; i < coefficients.length; ++i) {
            f += coefficients[i] * Math.pow(x, power--);
        }
        return f;
    }

    public double evaluateDF(double x) {
        double df = 0D;
        int power = coefficients.length - 2;
        for (int i = 0; i < coefficients.length - 1; ++i) {
            df += (power + 1) * coefficients[i] * Math.pow(x, power--);
        }
        return df;
    }

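    public double solve(double initialEstimate, int maxIterations) {
        double result = initialEstimate;
        for (int i = 0; i < maxIterations; ++i) {
            double f = evaluateF(result);
            if (f == 0D) {
                // Exact root: at a repeated root f'(x) is 0 as well, and 0/0 would give NaN
                break;
            }
            result -= f / evaluateDF(result);
        }
        return result;
    }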
}

When I run this code for 1000 iterations, the FMA version gives 1.5000000083575202, whereas the vanilla version gives 1.500000017233207. That single data point is completely unscientific, but it seems plausible and confirms my prejudice, so… In fact, it’s not that simple: over a range of initial values, there is only a very small difference in FMA’s favour. There’s not even a performance improvement, as the benchmark results below show, so clearly this method wasn’t added so you could start implementing numerical root-finding algorithms. The key takeaway is that the results differ slightly because a different rounding strategy is used: each fused multiply-add rounds once instead of twice.

Benchmark       (maxIterations)   Mode  Cnt   Score   Error  Units
NM_FMA                      100  thrpt   10  93.805 ± 5.174  ops/ms
NM_FMA                     1000  thrpt   10   9.420 ± 1.169  ops/ms
NM_FMA                    10000  thrpt   10   0.962 ± 0.044  ops/ms
NM_HandWritten              100  thrpt   10  93.457 ± 5.048  ops/ms
NM_HandWritten             1000  thrpt   10   9.274 ± 0.483  ops/ms
NM_HandWritten            10000  thrpt   10   0.928 ± 0.041  ops/ms
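For polynomial evaluation, FMA is more commonly paired with Horner’s rule, where every step is exactly one multiply-add, so there is no separate Math.pow call and no extra rounding of the power. The following is a sketch of that variant, which was not benchmarked or measured above; NewtonsMethodHorner is a hypothetical name. It evaluates f and f' in the same pass, and it stops once the residual is exactly zero, because at a repeated root f'(x) vanishes too and 0/0 would produce NaN.

```java
public class NewtonsMethodHorner {

    // Coefficients in descending order of power, e.g. {4, -12, 9} for 4x^2 - 12x + 9
    private final double[] coefficients;

    public NewtonsMethodHorner(double[] coefficients) {
        this.coefficients = coefficients;
    }

    public double solve(double initialEstimate, int maxIterations) {
        double x = initialEstimate;
        for (int iter = 0; iter < maxIterations; ++iter) {
            // Horner's rule: p accumulates f(x) and dp accumulates f'(x), one FMA per step
            double p = coefficients[0];
            double dp = 0D;
            for (int i = 1; i < coefficients.length; ++i) {
                dp = Math.fma(dp, x, p); // uses p before it is updated
                p = Math.fma(p, x, coefficients[i]);
            }
            if (p == 0D) {
                // Exact root found; at a repeated root dp is 0 as well
                break;
            }
            x -= p / dp;
        }
        return x;
    }

    public static void main(String[] args) {
        NewtonsMethodHorner solver = new NewtonsMethodHorner(new double[] {4, -12, 9});
        System.out.println(solver.solve(3.0, 1000));
    }
}
```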

2 thoughts on “New Methods in Java 9: Math.fma and Arrays.mismatch”

    • Newton’s method is a silly example but a matrix multiplication with FMA might be interesting to see.

      Incidentally, do you know what state Long8 in project Panama is in? Is it currently useable? I would like to evaluate using it in my fork of RoaringBitmap at some point next month.
