JAX Criblog

Overview

JAX implements the NumPy and SciPy APIs, which lets us run numerical operations on tensor-like arrays. A JAX DeviceArray is essentially an object wrapping:

  • Numpy value
  • dtype

As with NumPy, you can access .shape directly as an attribute of the object.
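
For instance, a minimal sketch of inspecting those attributes (the array here is just a placeholder):

import jax.numpy as np

x = np.ones((2, 3))   # a DeviceArray placed on the default device
print(x.shape)        # (2, 3)
print(x.dtype)        # float32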

import jax

JAX runs on most accelerators (CPU/GPU/TPU). Ensure that you install JAX with GPU support and set your CUDA library path correctly. You should then no longer see:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

JAX dispatches operations asynchronously, so to time a function properly, call block_until_ready on the result of any JAX op, e.g.

>>> t = np.arange(1e7)
>>> %timeit np.dot(t, t).block_until_ready()
100 loops, best of 5: 16 ms per loop

RNG management

The cool part is that the user manages all the (non-)determinism!
For any jax.random op, pass in an explicit array object known as an RNG key. Good practice is to split the key and consume a fresh subkey for each random call, so results stay reproducible:

>>> key = jax.random.PRNGKey(0)
>>> key, *subkeys = jax.random.split(key, 3)
>>> jax.random.normal(subkeys[0], (1,))
DeviceArray([0.5781487], dtype=float32)
>>> jax.random.normal(subkeys[1], (1,))
DeviceArray([0.85355157], dtype=float32)

XLA-esque things

To be JAX-onic, one must apply transformations to functions without side effects (i.e. functionally pure functions). Since everything is XLA-compiled in the backend, some things you would normally do in Python are off-limits.

Notably, xx[-1, -1] = 5 now becomes:

>>> jax.ops.index_update(xx, (-1, -1), 5)

which is similar in effect to @tf.function in TensorFlow, where data structures are immutable.
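
On newer JAX versions the same functional update is written with the .at property, which jax.ops.index_update was eventually deprecated in favor of; a quick sketch:

xx = np.ones((4, 4))
yy = xx.at[-1, -1].set(5.)   # returns an updated copy; xx itself is unchanged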

import jax.numpy as np; import numpy as onp

Upon importing JAX Numpy, one can easily access all the standard operations:

>>> x = np.arange(10)
>>> x
DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
>>> x + x
DeviceArray([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18], dtype=int32)
>>> from jax.scipy.linalg import block_diag
>>> block_diag(np.ones((2, 2)), np.ones((3, 3)), np.ones((1, 1)))
DeviceArray([[1., 1., 0., 0., 0., 0.],
             [1., 1., 0., 0., 0., 0.],
             [0., 0., 1., 1., 1., 0.],
             [0., 0., 1., 1., 1., 0.],
             [0., 0., 1., 1., 1., 0.],
             [0., 0., 0., 0., 0., 1.]], dtype=float32)

If you want to define your own ops, check out jax.lax for the primitive operations that map directly onto XLA and that JAX's function transformations are built on.
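
As a small illustration of the difference (just a sketch): the np wrappers are forgiving about dtypes, while the lax primitives are stricter and correspond one-to-one to XLA operations.

from jax import lax

np.add(1, 1.0)      # the numpy wrapper promotes mixed dtypes for you
lax.add(1.0, 1.5)   # the primitive requires matching dtypes; lax.add(1, 1.0) raises a TypeError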

from jax import grad

Taking gradients of functions (e.g. in the parameter update step of gradient descent optimization) using JAX’s autodiff package is super easy and clear!
Note: grad requires a float or complex input, cast accordingly.

def fn(a, b):
    return np.sum((a - b) ** 2)

grad_fn = grad(fn)                      # d fn / d a = 2 (a - b)
value_grad_fn = jax.value_and_grad(fn)  # returns (fn(a, b), d fn / d a)
partial_fn = grad(fn, argnums=(0, 1))   # differentiate w.r.t. both a and b

Unlike GradientTape in TensorFlow or .backward() in PyTorch, you can take derivatives by simply calling the gradient function

>>> a = np.arange(1, 5).astype(float)
>>> b = a + 0.1
>>> dfn_da = grad_fn(a, b)
>>> dfn_da
DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32)
>>> value, dfn_da = value_grad_fn(a, b)
>>> value, dfn_da
(DeviceArray(0.03999995, dtype=float32), 
DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))

or decide which arguments to differentiate with respect to:

>>> del_a, del_b = partial_fn(a, b)
>>> del_a, del_b
(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32), 
DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))

To block gradients, use jax.lax.stop_gradient; no gradient signal will propagate through the wrapped value:

def loss_fn(x, y):
    # the stop_gradient'd difference is treated as a constant during differentiation
    return (jax.lax.stop_gradient(x - y) - (x * y)) ** 2
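
A quick check of what that does to the gradient, with illustrative inputs: only the -(x * y) term is seen by differentiation.

dx = grad(loss_fn)(2., 3.)   # stop_gradient(x - y) = -1 is held constant,
                             # so dx = 2 * (-1 - 6) * (-3) = 42.0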

Advanced differentiation

Aside from jax.jacobian and jax.hessian, you can take higher-order derivatives using jet from jax.experimental!
For additional control, compute forward-mode Jacobian-vector products with jax.jvp and reverse-mode vector-Jacobian products with jax.vjp, or build full Jacobians forward or in reverse with jax.jacfwd and jax.jacrev, respectively.
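
A small sketch of these four on a toy function (the function and vectors are illustrative):

f = lambda x: np.asarray([x[0] ** 2, x[0] * x[1]])
x = np.array([1., 2.])
v = np.array([1., 0.])

J_fwd = jax.jacfwd(f)(x)          # full Jacobian, built column-by-column (forward mode)
J_rev = jax.jacrev(f)(x)          # same Jacobian, built row-by-row (reverse mode)
y, Jv = jax.jvp(f, (x,), (v,))    # Jacobian-vector product J @ v, alongside f(x)
y, vjp_fn = jax.vjp(f, x)
vJ, = vjp_fn(np.array([1., 0.]))  # vector-Jacobian product v^T @ J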

from jax import jit

Here's where all the speed-up hype you have been hearing about comes from. With jit, our code is XLA-compiled and run on the accelerator of our choice!

Simply wrap a function with the @jit decorator to mark it for compilation. Example from JAX docs:

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = np.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

>>> x = onp.random.randn(3, 4)
>>> y = onp.random.randn(4)
>>> f(x, y)
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>

DeviceArray([0.25773212, 5.3623195 , 5.4032435 ], dtype=float32)

Gotchas

Because of the compilation, you cannot condition on input values, only on input shapes and dtypes. When you do need value-dependent behaviour, either branch inside the computation with lax.cond (see the sketch after the snippet below), or pass static_argnums when applying the decorator so that the argument is treated as a compile-time constant; note that each new value of a static argument triggers a re-compilation:

from functools import partial

@partial(jit, static_argnums=(3,))
def some_fn(x, y, rng, add):
    # `add` is static: Python control flow on it is allowed,
    # but each new value of `add` compiles a fresh version of some_fn
    return (x + y) if add else (x - y)
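
And a minimal lax.cond sketch for branching on run-time values inside jit (the function is just an illustration):

from jax import lax

@jit
def abs_diff(x, y):
    # both branches get traced; the predicate is only evaluated at run time
    return lax.cond(x > y,
                    lambda args: args[0] - args[1],
                    lambda args: args[1] - args[0],
                    (x, y))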

Notice as well that inside a jitted function the arrays are Traced objects, meaning you cannot print their values; you can only see the shape and dtype. This is also part of why things are so efficient: the Python code does not need to be re-executed for every new input, only the compiled XLA program.

from jax import vmap

Here's the best part: vectorized operations. You can skip the for-loops entirely and efficiently compute results for batched inputs:

def fn(x):
    return np.dot(x, x)

>>> x = np.arange(4)
>>> batched_x = np.stack([x, x])
>>> batched_fn = jax.vmap(fn)
>>> fn(x)
DeviceArray(14, dtype=int32)
>>> batched_fn(batched_x)
DeviceArray([14, 14], dtype=int32)

This replaces the non-vectorized approach: manually looping over the batch dimension, applying fn to each slice, and stacking the per-example results into a final output array.

Specify the axis on which to vectorize via in_axes and the axis on which to output the batched results via out_axes.
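
For example, a sketch that batches over the leading axis of the first argument while sharing a single second argument:

batched_dot = jax.vmap(np.dot, in_axes=(0, None), out_axes=0)
out = batched_dot(np.ones((8, 4)), np.ones(4))   # shape (8,)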

Now knowing grad, jit and vmap, get per-example (instead of accumulated) gradients of your loss function easily by composing / nesting operations: jit(vmap(grad(fn))).
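
A sketch of that composition for the squared-error loss from the grad section (sq_err, batch_a and batch_b are illustrative names):

def sq_err(a, b):
    return np.sum((a - b) ** 2)

per_example_grad = jax.jit(jax.vmap(jax.grad(sq_err), in_axes=(0, 0)))
# per_example_grad(batch_a, batch_b) returns one gradient per row of batch_a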

from jax import pmap

Similarly, there is pmap for running computations in parallel on separate devices (with implicit jit compilation). Since different slices of the batched input live on different devices, pool outputs across devices with the collective operations jax.lax.p*. You don't have to deal with host-to-host communication yourself: specify axis_name so that collective ops can refer to the axis mapped by jax.pmap (the leading axis by default; use a different name for each additional mapped axis) and do the cross-device pooling inside the function (jax.lax.psum in this example).
Note: specifying axis_name without actually calling pmap has no effect.

def normalized_convolution(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(np.dot(x[i-1:i+2], w))
    output = np.array(output)
    # normalize by the sum pooled across all devices mapped by pmap
    return output / jax.lax.psum(output, axis_name='anything')

>>> w = np.array([2., 3., 4.])
>>> n_devices = jax.local_device_count()
>>> xs = np.arange(5 * n_devices).reshape(-1, 5)
>>> ws = np.stack([w] * n_devices)
>>> jax.pmap(normalized_convolution, axis_name='anything')(xs, ws)
ShardedDeviceArray([[0.00816024, 0.01408451, 0.019437  ],
                    [0.04154303, 0.04577465, 0.04959785],
                    [0.07492582, 0.07746479, 0.07975871],
                    [0.10830861, 0.10915492, 0.10991956],
                    [0.14169139, 0.14084506, 0.14008042],
                    [0.17507419, 0.17253521, 0.17024128],
                    [0.20845698, 0.20422535, 0.20040214],
                    [0.24183977, 0.23591548, 0.23056298]], dtype=float32)

Debugging etc.

Device transfer

Plotting and some tensor post-processing are easier to do on the CPU than on the GPU. Move a DeviceArray to the host via jax.device_get(x) or onp.array(x), and back to the accelerator via jax.device_put(x) or np.asarray(x).
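
A minimal sketch of the round trip, for any DeviceArray x:

x_host = jax.device_get(x)       # DeviceArray -> numpy.ndarray on the host
x_dev = jax.device_put(x_host)   # host array -> DeviceArray on the default device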

Printing

Remember how difficult it is to print() in TensorFlow? The same thing happens in JAX: you'll just see a bunch of traced objects with no values. To get values from objects on devices inside traced operations (e.g. jit, vmap, pmap), simply add from jax.experimental.host_callback import id_print.
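
A sketch of how that might look inside a jitted function:

from jax.experimental.host_callback import id_print

@jit
def debug_fn(x):
    x = id_print(x)   # prints the run-time value on the host and passes x through unchanged
    return x * 2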

Auxiliary info

Besides the gradient itself, grad can also hand back extra information that your function returns, via has_aux=True:

def fn_with_many_returns(a, b):
    # the first output is differentiated; the second is passed through as auxiliary data
    return -np.sum(a * np.log(b)), np.mean(a == b)

>>> a = np.asarray([1., 0., 1., 1.])
>>> b = jax.nn.softmax(a)
>>> dce_da, acc = grad(fn_with_many_returns, has_aux=True)(a, b)
>>> print(dce_da, ',', acc)
[1.2142833 2.2142832 1.2142833 1.2142833] , 0.0

from jax.config import config

The JAX config object exposes useful debug flags that can be toggled as sanity checks while debugging. Be sure to turn these off when running your actual experiments so things don't get slowed down by the extra checks and host-device communication.

Numerical instability

To make the program error out the moment a nan value appears, turn on the floating point checker via: config.update("jax_debug_nans", True).

For higher floating point precision, config.update("jax_enable_x64", True).

Check jit

One sanity check that you've jitted things correctly: turn on config.update('jax_log_compiles', True) and confirm that no compile logs appear after the first iteration of your training loop. The first iteration will be slower, but once things are compiled it should be a lot faster!

If you want to inspect concrete intermediate values instead of traced ones, simply turn off jitting globally: config.update('jax_disable_jit', True).

Profiling

You can use TensorBoard profiler to debug OOM errors or visualize your program’s memory usage. Simply insert jax.profiler.start_trace(<tb_logdir>), run the code, and capture the trace with jax.profiler.stop_trace(). Refer to the instructions from the official JAX documentation.
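
A minimal sketch (the log directory is just an example path):

import jax

jax.profiler.start_trace("/tmp/jax-trace")
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 1000))
(x @ x).block_until_ready()   # make sure the work finishes before stopping the trace
jax.profiler.stop_trace()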

Resources

JAX Gotchas
Autodiff Cookbook
JIT mechanisms

Optax (Modern JAX optimizers)
Flax (nice, composable NN layers)
Haiku (non-functional API)

Blog by my friend Joao (in Portuguese!)
awesome-jax by my friend Nick
