Saturday, October 5, 2024

Optimising GPU code

I complained to Juan Fumero that a benchmark indicated the GPU was not giving much of a performance improvement. JMH reported the GPU version as only moderately faster than the CPU one, taking about 20% less time per operation:

tornado -jar tornado-benchmarks/target/jmhbenchmarks.jar uk.ac.manchester.tornado.benchmarks.sgemv.JMHSgemV
...
Benchmark              Mode  Cnt         Score         Error  Units
JMHSgemV.sgemVJava     avgt    5  72366270.751 ± 5916807.539  ns/op
JMHSgemV.sgemVTornado  avgt    5  57583087.103 ± 2523449.341  ns/op

(SGEMM is single-precision general matrix-matrix multiplication. SGEMV is its matrix-vector counterpart: the V indicates that we're multiplying a matrix by a vector.)
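For reference, here is what the two operations compute. This is a minimal, naive sketch in plain Java, assuming row-major matrices stored in flat float[] arrays. The class and method names are hypothetical. It is not TornadoVM's benchmark code, and it omits the alpha/beta scaling factors that the full BLAS SGEMV/SGEMM routines take.

```java
public class NaiveBlas {
    // y = A * x, where A is an m-by-n matrix stored row-major in a flat float[].
    public static float[] sgemv(float[] a, float[] x, int m, int n) {
        float[] y = new float[m];
        for (int i = 0; i < m; i++) {
            float sum = 0f;
            for (int j = 0; j < n; j++) {
                sum += a[i * n + j] * x[j]; // each element of A is read exactly once
            }
            y[i] = sum;
        }
        return y;
    }

    // C = A * B, where A is m-by-k and B is k-by-n, both row-major.
    public static float[] sgemm(float[] a, float[] b, int m, int k, int n) {
        float[] c = new float[m * n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                float sum = 0f;
                for (int p = 0; p < k; p++) {
                    sum += a[i * k + p] * b[p * n + j]; // elements of A and B are reused many times
                }
                c[i * n + j] = sum;
            }
        }
        return c;
    }

    public static void main(String[] args) {
        float[] a = {1, 2, 3, 4, 5, 6}; // 2x3
        float[] x = {1, 0, -1};
        System.out.println(java.util.Arrays.toString(sgemv(a, x, 2, 3)));    // [-2.0, -2.0]
        float[] b = {1, 0, 0, 1, 1, 1}; // 3x2
        System.out.println(java.util.Arrays.toString(sgemm(a, b, 2, 3, 2))); // [4.0, 5.0, 10.0, 11.0]
    }
}
```

The inner loop of sgemv touches each element of A once, whereas sgemm reuses every element of A across a whole row of C and every element of B across a whole column. That reuse is what makes matrix-matrix multiplication worth shipping to a GPU.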

Juan replied that I should run with TornadoVM's --enableProfiler console option and see where the time was being spent. Sure enough, COPY_IN_TIME (the time spent copying input data to the GPU) was ~28ms, about the same as TOTAL_KERNEL_TIME.

Note that the total kernel time is the time it takes the GPU to perform the computation, whereas the total kernel dispatch time is the time it takes to schedule the kernel (i.e., the function being executed). In this case, dispatch time is ~6µs, more than three orders of magnitude smaller than the execution time.

Juan also said that "Matrix Vector is not as compute intensive as other applications", so instead I tried the matrix/matrix multiplication. Here, the GPU shines:

Benchmark              Mode  Cnt           Score         Error  Units
JMHSgemm.sgemmJava     avgt    5  1773297262.188 ± 4115731.439  ns/op
JMHSgemm.sgemmTornado  avgt    5     8478409.506 ±  246919.368  ns/op

That makes the GPU 200 times faster than the CPU. Now COPY_IN_TIME is about 1ms and TOTAL_KERNEL_TIME is about 5.5ms.
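Juan's point about compute intensity can be made concrete with a back-of-the-envelope calculation of arithmetic intensity (floating-point operations per byte moved). The sketch below assumes square n×n matrices of 4-byte floats, counts a multiply-add as two operations, and assumes each input is copied to the device once and each output copied back once. The class name and the sizes in main are hypothetical, since the benchmarks' actual matrix sizes aren't given here.

```java
public class ArithmeticIntensity {
    // Flops per byte for y = A*x with an n x n matrix A.
    public static double gemvIntensity(long n) {
        double flops = 2.0 * n * n;                            // n*n multiply-adds
        double bytes = Float.BYTES * (double) (n * n + 2 * n); // A and x copied in, y copied out
        return flops / bytes;
    }

    // Flops per byte for C = A*B with n x n matrices.
    public static double gemmIntensity(long n) {
        double flops = 2.0 * n * n * n;                        // n*n*n multiply-adds
        double bytes = Float.BYTES * 3.0 * n * n;              // A and B copied in, C copied out
        return flops / bytes;
    }

    public static void main(String[] args) {
        for (long n : new long[] {256, 1024, 4096}) {        // hypothetical sizes
            System.out.printf("n=%d  GEMV %.2f flop/byte  GEMM %.1f flop/byte%n",
                    n, gemvIntensity(n), gemmIntensity(n));
        }
    }
}
```

GEMV stays just under 0.5 flop/byte whatever n is, so copying the data costs roughly as much as computing with it, which is consistent with COPY_IN_TIME being close to TOTAL_KERNEL_TIME. GEMM's intensity grows as n/6, so the copy is quickly amortised.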

Now we're talking. But continuing this optimization rampage, it's worth noting that "It has become tribal knowledge that the particular shapes chosen for matmuls has a surprisingly large effect on their performance." [Horace He] TL;DR: his article explains how the way small memory tiles fit onto a large matrix can hugely change performance. Roughly speaking, in a row-major MxN matrix, N should be a multiple of the number of elements that fit in a GPU cache line for best results.
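One common way to act on this, sketched here as an assumption rather than anything taken from He's article, is to pad each row so that its stride (the "leading dimension") spans a whole number of cache lines. The 128-byte line size is an assumption, so check it for the target GPU, and the class is hypothetical. Note that padding only fixes the row stride: the base address of a Java float[] on the heap isn't guaranteed to be cache-line aligned.

```java
public class PaddedMatrix {
    static final int CACHE_LINE_BYTES = 128; // assumption: adjust for the target GPU

    // Rounds a row length up so that each row spans a whole number of cache lines.
    public static int paddedLeadingDimension(int n, int elementBytes) {
        int perLine = CACHE_LINE_BYTES / elementBytes; // 32 for 4-byte floats
        return ((n + perLine - 1) / perLine) * perLine;
    }

    // Row-major rows x cols float matrix whose rows are ld >= cols floats apart.
    // The padding columns are never written, so they stay zero.
    final int rows, cols, ld;
    final float[] data;

    PaddedMatrix(int rows, int cols) {
        this.rows = rows;
        this.cols = cols;
        this.ld = paddedLeadingDimension(cols, Float.BYTES);
        this.data = new float[rows * ld];
    }

    float get(int i, int j) { return data[i * ld + j]; }
    void set(int i, int j, float v) { data[i * ld + j] = v; }

    public static void main(String[] args) {
        for (int n : new int[] {1000, 1024, 2047}) {
            System.out.println(n + " -> " + paddedLeadingDimension(n, Float.BYTES));
        }
    }
}
```

For example, a 1000-column float matrix would be stored with 1024 floats per row, the extra 24 being zero padding that a kernel simply never reads.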

Changes in Java's memory access

In the old days, we'd use sun.misc.Unsafe.allocateMemory to get off-heap memory. This goes more or less straight to native code: in HotSpot it calls os::malloc, a thin wrapper around the C library's malloc. But using Unsafe is bad practice. Not only is it specific to a particular flavour of JVM, it also allows access to raw memory. The latter is "fine" if that memory is off-heap, but if you use raw addresses to access a Java object, the garbage collector can move that object without warning.
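For contrast, this is roughly what the old approach looks like. It's a sketch that relies on the unsupported trick of reading Unsafe's private theUnsafe field via reflection, and the class and method names are hypothetical. Recent JDKs warn about these memory-access methods (JEP 471 deprecated them for removal in JDK 23), and nothing stops you from reading past the end of the allocation or using the address after freeMemory.

```java
import java.lang.reflect.Field;
import sun.misc.Unsafe;

public class UnsafeOffHeap {
    // Unsafe has no public constructor; the usual (unsupported) trick is to
    // read its private static "theUnsafe" field reflectively.
    static Unsafe unsafe() throws ReflectiveOperationException {
        Field f = Unsafe.class.getDeclaredField("theUnsafe");
        f.setAccessible(true);
        return (Unsafe) f.get(null);
    }

    // Writes the values to raw off-heap memory, sums them back, then frees the memory.
    public static float sumOffHeap(float[] values) throws ReflectiveOperationException {
        Unsafe u = unsafe();
        long bytes = (long) values.length * Float.BYTES;
        long address = u.allocateMemory(bytes); // raw pointer: no bounds checks, no GC tracking
        try {
            for (int i = 0; i < values.length; i++) {
                u.putFloat(address + (long) i * Float.BYTES, values[i]);
            }
            float sum = 0f;
            for (int i = 0; i < values.length; i++) {
                sum += u.getFloat(address + (long) i * Float.BYTES);
            }
            return sum;
        } finally {
            u.freeMemory(address); // forgetting this leaks; using address afterwards is undefined
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(sumOffHeap(new float[] {0f, 1.5f, 3f, 4.5f})); // 9.0
    }
}
```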

There are several modern alternatives. Since Java 9, java.lang.invoke.VarHandle has been the recommended replacement for much of what Unsafe was used for. It gives low-level access to fields, array elements and (through view handles) byte buffers, but with better safety and explicit control over memory visibility. Its access modes apparently offer finer-grained control than plain volatile fields, e.g. acquire/release or opaque access, which impose weaker ordering constraints than full volatile access.
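A minimal sketch of VarHandle access modes on an ordinary instance field follows. The class and field names are hypothetical; the VarHandle methods used (getAndAdd, setRelease, getAcquire, getOpaque) are part of the standard java.lang.invoke API. Because VarHandle methods are signature-polymorphic, the casts on their return values matter: they tell the compiler which type each call returns.

```java
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;

public class VarHandleDemo {
    private int counter; // plain field; concurrent access goes through COUNTER

    private static final VarHandle COUNTER;
    static {
        try {
            COUNTER = MethodHandles.lookup()
                    .findVarHandle(VarHandleDemo.class, "counter", int.class);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    // Atomic read-modify-write with volatile semantics.
    public int incrementAndGet() { return (int) COUNTER.getAndAdd(this, 1) + 1; }

    // Release store: earlier writes become visible to a thread that reads with getAcquire.
    public void publish(int v) { COUNTER.setRelease(this, v); }

    // Acquire load: pairs with setRelease, weaker than a full volatile read.
    public int read() { return (int) COUNTER.getAcquire(this); }

    // Opaque load: no tearing and eventual visibility, but no ordering with other variables.
    public int peekOpaque() { return (int) COUNTER.getOpaque(this); }

    public static void main(String[] args) {
        VarHandleDemo d = new VarHandleDemo();
        d.incrementAndGet();
        d.incrementAndGet();
        System.out.println(d.read());       // 2
        d.publish(42);
        System.out.println(d.peekOpaque()); // 42
    }
}
```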

It's interesting to note that the high-performance interoperability framework Apache Arrow does not use VarHandle. It still uses Unsafe, because VarHandle's bounds checking and other safety checks make it slower than raw access.

Project Panama's Foreign Function & Memory API first appeared as a preview in Java 19 (JEP 424) and was finalised in Java 22 (JEP 454). It appears Apache Arrow doesn't use it because it's too new. If we run this code:

MemorySegment memorySegment = Arena.global().allocate(1024 * 1024 * 128, 8);
System.out.println(memorySegment.address());

then, while the program is still running, look for that address in /proc/PID/maps (where PID is the ID of the Java process), we can see that Linux now manages a new area of memory. For instance, when I ran it, the printed address (println shows it in decimal) was 0x7fbaccdbe010 in hex, and the maps pseudo-file contained:

7fbaccdbe000-7fbad4dbf000 rw-p 00000000 00:00 0 

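This region is 0x8001000 bytes long, i.e. 128 MiB plus 4096 bytes. The extra page is presumably for allocator metadata: the returned address is 16 bytes (0x10) past the start of the mapping, which leaves room for a small header, and mappings are rounded up to whole pages.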
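The /proc lookup can also be done from inside the JVM. This sketch (Linux only, Java 22 or later for the finalised java.lang.foreign API; class and method names are hypothetical) allocates the same 128 MiB segment, scans /proc/self/maps for the region containing its address, and reports how much bigger the mapping is than the request and how far into it the segment starts. The exact numbers depend on the native allocator (the explanation above assumes glibc's malloc serving large requests with mmap), so treat them as illustrative.

```java
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.file.Files;
import java.nio.file.Path;

public class SegmentMapping {
    // Size in bytes of a maps line such as "7fbaccdbe000-7fbad4dbf000 rw-p ...".
    public static long mappingSize(String mapsLine) {
        String[] bounds = mapsLine.substring(0, mapsLine.indexOf(' ')).split("-");
        return Long.parseUnsignedLong(bounds[1], 16) - Long.parseUnsignedLong(bounds[0], 16);
    }

    // Returns the /proc/self/maps line whose address range contains addr, or null.
    public static String mappingFor(long addr) throws IOException {
        for (String line : Files.readAllLines(Path.of("/proc/self/maps"))) {
            String[] bounds = line.substring(0, line.indexOf(' ')).split("-");
            long start = Long.parseUnsignedLong(bounds[0], 16);
            long end = Long.parseUnsignedLong(bounds[1], 16);
            if (Long.compareUnsigned(addr, start) >= 0 && Long.compareUnsigned(addr, end) < 0) {
                return line;
            }
        }
        return null;
    }

    public static void main(String[] args) throws IOException {
        long requested = 1024L * 1024 * 128;
        MemorySegment segment = Arena.global().allocate(requested, 8);
        String line = mappingFor(segment.address());
        System.out.printf("address=0x%x%n%s%n", segment.address(), line);
        if (line != null) {
            long start = Long.parseUnsignedLong(line.substring(0, line.indexOf('-')), 16);
            System.out.println("extra bytes: " + (mappingSize(line) - requested));
            System.out.println("offset into mapping: " + (segment.address() - start));
        }
    }
}
```

Fed the maps line shown above, mappingSize("7fbaccdbe000-7fbad4dbf000 rw-p 00000000 00:00 0") returns 134221824, which is 128 MiB plus 4096 bytes.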

Now, since Java and C/C++ (on all mainstream platforms) both use IEEE 754 floating point, and they can now share native memory, you can pass floating-point numbers transparently between the two code bases and call C/C++ code from within the JVM's process. No more need for JNI! (Interestingly, Python is usually IEEE 754 compliant, but this is not guaranteed.)
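To make the JNI-free claim concrete, here is a minimal downcall sketch using the finalised java.lang.foreign API (Java 22 or later). It copies a few floats through the C library's memcpy and reads them back in Java. Assumptions: a 64-bit Linux or macOS system where the linker's default lookup can find memcpy, with size_t modelled as JAVA_LONG. The class and method names are hypothetical.

```java
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.util.Arrays;

public class FfmMemcpy {
    // Copies the floats into native memory, has C's memcpy duplicate them,
    // and returns what Java reads back from the destination buffer.
    public static float[] roundTrip(float[] input) throws Throwable {
        Linker linker = Linker.nativeLinker();
        // void *memcpy(void *dst, const void *src, size_t n)
        MethodHandle memcpy = linker.downcallHandle(
                linker.defaultLookup().find("memcpy").orElseThrow(),
                FunctionDescriptor.of(ValueLayout.ADDRESS,
                        ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.JAVA_LONG));
        try (Arena arena = Arena.ofConfined()) { // native memory is freed when the arena closes
            MemorySegment src = arena.allocateFrom(ValueLayout.JAVA_FLOAT, input);
            MemorySegment dst = arena.allocate(ValueLayout.JAVA_FLOAT, input.length);
            MemorySegment ignored = (MemorySegment) memcpy.invokeExact(dst, src, src.byteSize());
            return dst.toArray(ValueLayout.JAVA_FLOAT);
        }
    }

    public static void main(String[] args) throws Throwable {
        float[] out = roundTrip(new float[] {1.5f, -0.0f, Float.NaN, 3.25f});
        System.out.println(Arrays.toString(out)); // [1.5, -0.0, NaN, 3.25]
    }
}
```

Note that -0.0 and NaN survive the round trip unchanged: both sides agree on the IEEE 754 bit layout, and memcpy just moves bytes.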

It's interesting to note that the GPU-enabled TornadoVM uses the java.lang.foreign package to move data to and from the GPU.