How to Speed Up Your Python Code with JAX


You can find the code used in this article on Colab.

Ottawa road in the evening
Photographer: Marc-Olivier Jodoin

Recently, it seems that everyone is talking about this new library: JAX.

Sample tweet from the creator of Keras

But what is all the fuss about? Is it a new state-of-the-art machine learning interface? A new way to scale up your applications? From the official documentation:

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

Maybe the best way to understand what it does is to list what one looks for when choosing a new tool for prototyping new ideas:

  • Easy to use: hop right in and write some code in a matter of seconds;
  • Fast: code is optimized and runs quickly, so ideas are easy to test;
  • Powerful: provides awesome features that are not found elsewhere;
  • Understandable: the syntax is immediately clear and we know what our code is doing.

JAX meets all of these criteria. It’s a scientific Python library that works more like a toolbox for prototyping ideas. It:

  • Implements a NumPy API, so it’s easy to use if you already know NumPy;
  • Supports JIT compilation of your code, plus automatic vectorization and parallelization of functions for batch computing, and runs transparently on CPUs, GPUs, and TPUs, so it’s fast;
  • Provides automatic differentiation (autograd) to compute the gradient of pretty much any function.

JAX as a NumPy API

Let’s say you have been working on some code that uses NumPy as its scientific backend, and now you want to differentiate this code or run it on a GPU.

import jax as jx
import numpy as np # Don't use this
import jax.numpy as jnp # Cool kids do this!

It’s as simple as that!
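As a minimal sketch of that scenario, here is a NumPy-style function written against jax.numpy and differentiated with jax.grad (covered in more detail below). The function name sum_of_squares is purely illustrative. JAX places arrays on its default backend, a GPU or TPU if one is available and the CPU otherwise; jax.devices() lists what it found.

```python
import jax as jx
import jax.numpy as jnp

# A NumPy-style function (hypothetical example), written with jnp instead of np.
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# Differentiate it: the gradient of sum(x_i ** 2) is 2 * x.
grad_sum_of_squares = jx.grad(sum_of_squares)

x = jnp.array([1.0, 2.0, 3.0])
g = grad_sum_of_squares(x)  # elementwise 2 * x
print(g)

# JAX runs on its default device: a GPU/TPU if available, otherwise the CPU.
print(jx.devices())
```

The same function body would work with numpy; only the import changed, and it gained a gradient for free.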

To be fair, your code has to respect a few constraints for it to work.

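  • You can only use pure functions: calling a function twice with the same inputs must return the same result, and you can’t update arrays in place. In current JAX versions, use the .at[] syntax instead of the older jax.ops.index_update.
array = jnp.zeros((2, 2))
array[:, 0] = 1  # Won't work: JAX arrays are immutable (raises a TypeError)
array = array.at[:, 0].set(1)  # Better: returns an updated copy of the array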
  • Random number generation is explicit: every random function takes a PRNG key as an argument.
key = jx.random.PRNGKey(1)
x = jx.random.normal(key=key, shape=(1000,))
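Both constraints are easier to see in action. The sketch below (variable names are illustrative) shows that .at[...].set(...) leaves the original array untouched, that reusing a key reproduces the same numbers, and that jax.random.split is the usual way to get fresh randomness from a key.

```python
import jax as jx
import jax.numpy as jnp

# 1) Immutable arrays: .at[...].set(...) returns a new, updated array.
array = jnp.zeros((2, 2))
updated = array.at[:, 0].set(1.0)  # first column set to 1 in the copy only
# `array` itself is still all zeros.

# 2) Explicit randomness: the same key always produces the same numbers.
key = jx.random.PRNGKey(1)
a = jx.random.normal(key, shape=(3,))
b = jx.random.normal(key, shape=(3,))  # identical to `a`

# For new random numbers, split the key and use the fresh subkey.
key, subkey = jx.random.split(key)
c = jx.random.normal(subkey, shape=(3,))  # almost surely different from `a`
```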

But apart from that, you can still use it to do the same operations as in NumPy.

def numpy_softmax(x: np.ndarray) -> np.ndarray:
    exp_x = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return exp_x / exp_x.sum(axis=0)

def jax_softmax(x: jnp.ndarray) -> jnp.ndarray:
    exp_x = jnp.exp(x - jnp.max(x))  # same code, with jnp instead of np
    return exp_x / exp_x.sum(axis=0)
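To check that the two versions really compute the same thing, we can compare them on the same input. One detail worth knowing: JAX uses 32-bit floats by default while NumPy uses 64-bit, so the sketch compares them with a float32-level tolerance rather than exact equality. The input values are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

def numpy_softmax(x: np.ndarray) -> np.ndarray:
    exp_x = np.exp(x - np.max(x))
    return exp_x / exp_x.sum(axis=0)

def jax_softmax(x: jnp.ndarray) -> jnp.ndarray:
    exp_x = jnp.exp(x - jnp.max(x))
    return exp_x / exp_x.sum(axis=0)

x_np = np.linspace(-3.0, 3.0, 7)  # arbitrary test input
out_np = numpy_softmax(x_np)                           # float64
out_jax = np.asarray(jax_softmax(jnp.asarray(x_np)))   # float32 by default

# The results agree up to float32 round-off.
max_diff = np.abs(out_np - out_jax).max()
print(max_diff)
```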

Make it go faster

Example of Python code using JAX

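One of the coolest features of JAX is its ability to compile your Python code with XLA. This is called Just-In-Time (JIT) compilation: the first time you call a jitted function, JAX traces it and compiles it into optimized XLA code, then caches that compiled version and reuses it on later calls with the same input shapes and types, so it runs faster.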
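A minimal sketch of this compile-once behaviour, assuming jax.jit's default caching on input shapes and dtypes. The trace_log list is a hypothetical helper used only to observe when tracing happens; a side effect like this inside a jitted function is exactly what the pure-function rule warns against, which is why it runs only during tracing.

```python
import jax as jx
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records each time JAX traces the function

@jx.jit
def scaled_sum(x):
    # This Python line runs only while JAX traces the function,
    # not on every call of the compiled version.
    trace_log.append(x.shape)
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3))  # first call: trace + compile
scaled_sum(jnp.ones(3))  # same shape and dtype: cached compiled code is reused
scaled_sum(jnp.ones(4))  # new shape: traced and compiled again
print(trace_log)
```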

It’s easy to use: decorate your function with jax.jit, or call jax.jit on it directly. We use IPython’s %timeit to time the execution.

%timeit jax_softmax(x)
%timeit jx.jit(jax_softmax)(x)
%timeit numpy_softmax(x)
1000 loops, best of 5: 1.15 ms per loop
1000 loops, best of 5: 296 µs per loop
1000 loops, best of 5: 646 µs per loop

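And voilà! The jitted version runs roughly twice as fast as the NumPy one, and about four times as fast as the same JAX function without jit.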
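One caveat when reading these timings: JAX dispatches work asynchronously, so a call can return before the computation has finished, and the first call to a jitted function also pays for compilation. A fairer benchmark, sketched here with Python’s standard timeit module rather than the notebook’s %timeit, wraps the function once, compiles it with a warm-up call, and calls block_until_ready() so each run waits for its result. No numbers are given, since they depend on your hardware.

```python
import timeit

import jax as jx
import jax.numpy as jnp

def jax_softmax(x):
    exp_x = jnp.exp(x - jnp.max(x))
    return exp_x / exp_x.sum(axis=0)

jit_softmax = jx.jit(jax_softmax)  # wrap once and reuse the compiled function

key = jx.random.PRNGKey(1)
x = jx.random.normal(key, shape=(1000,))

# Warm-up call: triggers tracing and XLA compilation outside the timed region.
jit_softmax(x).block_until_ready()

n = 100
# block_until_ready() makes each timed call wait for the actual computation.
eager_s = timeit.timeit(lambda: jax_softmax(x).block_until_ready(), number=n) / n
jit_s = timeit.timeit(lambda: jit_softmax(x).block_until_ready(), number=n) / n
print(f"eager: {eager_s * 1e6:.1f} µs per call, jit: {jit_s * 1e6:.1f} µs per call")
```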


Automatic differentiation

With machine learning or optimization in mind, you might want to quickly compute the gradient of pretty much any function. Lucky you, that’s exactly what JAX is capable of! Use the jax.grad transformation; you can even apply it repeatedly to the same function to get higher-order derivatives.

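import matplotlib.pyplot as plt

cos_grad = jx.grad(jnp.cos)                  # first derivative: -sin(x)
cos_grad_grad = jx.grad(cos_grad)            # second derivative: -cos(x)
cos_grad_grad_grad = jx.grad(cos_grad_grad)  # third derivative: sin(x)

x = jnp.linspace(-jnp.pi, jnp.pi, 1000)
plt.plot(x, jnp.cos(x), label="cos")
# jax.grad expects a scalar output, so we evaluate one point at a time (slow)
plt.plot(x, [cos_grad(i) for i in x], label="1st derivative")
plt.plot(x, [cos_grad_grad(i) for i in x], label="2nd derivative")
plt.plot(x, [cos_grad_grad_grad(i) for i in x], label="3rd derivative")
plt.legend()
plt.show()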
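A quick sanity check on these derivatives: since the derivative of cos(x) is -sin(x), the second derivative is -cos(x) and the third is sin(x). The sketch below compares the jax.grad results with those closed forms at a few arbitrarily chosen points. Because jax.grad works on scalar-output functions, it feeds one scalar at a time, just like the plotting code.

```python
import jax as jx
import jax.numpy as jnp

cos_grad = jx.grad(jnp.cos)                  # expected: -sin(x)
cos_grad_grad = jx.grad(cos_grad)            # expected: -cos(x)
cos_grad_grad_grad = jx.grad(cos_grad_grad)  # expected: sin(x)

max_error = 0.0
for t in [-1.0, 0.0, 0.5, 2.0]:  # arbitrary sample points (floats, not ints)
    errors = [
        abs(float(cos_grad(t)) + float(jnp.sin(t))),
        abs(float(cos_grad_grad(t)) + float(jnp.cos(t))),
        abs(float(cos_grad_grad_grad(t)) - float(jnp.sin(t))),
    ]
    max_error = max(max_error, *errors)

print(max_error)  # should be at float32 round-off level
```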

Vectorization

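If you ran the previous code in the Colab notebook, you may have noticed that it took an awfully long time to execute: each list comprehension calls a gradient function separately for every one of the 1000 points. So why not use one more feature of JAX to make it faster? jax.vmap vectorizes a function: it maps the function over an axis of its inputs in a single call, which makes it handy for batching operations, for instance.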

cos_grad_vec = jx.vmap(cos_grad)
%timeit [cos_grad(i) for i in x]
%timeit cos_grad_vec(x)
1 loop, best of 3: 1.66 s per loop
100 loops, best of 3: 2.41 ms per loop
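To confirm that vmap only changes how the work is batched, not the result, the sketch below compares the vectorized gradient with the element-by-element loop (on 100 points, to keep the slow path short) and shows that JAX transformations compose, here jit on top of vmap on top of grad.

```python
import jax as jx
import jax.numpy as jnp

cos_grad = jx.grad(jnp.cos)

x = jnp.linspace(-jnp.pi, jnp.pi, 100)

# vmap maps the scalar function cos_grad over axis 0 of x in one call.
cos_grad_vec = jx.vmap(cos_grad)
vectorized = cos_grad_vec(x)

# The same values, computed with one Python call per element (the slow path).
looped = jnp.stack([cos_grad(t) for t in x])

# Transformations compose: compile the vectorized gradient with jit as well.
fast_cos_grad = jx.jit(jx.vmap(cos_grad))
```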

Better still, you can choose which arguments to vectorize over with in_axes (None means an argument is passed through unbatched) and where the batch dimension goes in the output with out_axes, so batching is as simple as ever.

def loss(x, constant):
    return jnp.dot(x.T, x) + constant

# Map over axis 0 of x, pass `constant` unbatched, stack results along axis 0
batched_loss = jx.vmap(loss, in_axes=(0, None), out_axes=0)
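Here is a short usage sketch for batched_loss. The batch values are arbitrary: four examples with three features each. With in_axes=(0, None), each row of the batch is fed to loss separately while the same constant is shared by all of them, and out_axes=0 stacks the four scalar losses into a vector.

```python
import jax as jx
import jax.numpy as jnp

def loss(x, constant):
    # For a 1-D vector x, x.T is x itself, so this is the sum of squares.
    return jnp.dot(x.T, x) + constant

# in_axes=(0, None): map over axis 0 of x, reuse `constant` as-is.
# out_axes=0: stack the per-example losses along axis 0.
batched_loss = jx.vmap(loss, in_axes=(0, None), out_axes=0)

batch = jnp.arange(12.0).reshape(4, 3)  # 4 examples, 3 features each
losses = batched_loss(batch, 1.0)       # → [6., 51., 150., 303.]
```

Writing loss for a single example and letting vmap add the batch dimension keeps the per-example code simple and avoids manual axis bookkeeping.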

JAX: Where to go next?

We saw that JAX has many powerful features that make it a go-to tool for writing scientific Python code that runs fast. It’s also super flexible.

For a long time, what JAX lacked was a rich ecosystem. As recent projects mature, we are seeing more and more libraries and projects built around JAX, such as:

  • Haiku: a simple functional neural network library for JAX
  • Flax: an object-oriented neural network library and ecosystem for JAX, designed for flexibility
  • Trax: an end-to-end library for deep learning that focuses on clear code and speed
  • Optax: a library of optimizers and gradient transformations

Although it was first released in 2018, JAX gained traction in 2020, and its use is spreading fast. Google has been advocating for it for a while, and more and more projects are powered by it, such as the implementation of the paper “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”. So now is a good time to try it for yourself!
