JAX Criblog

Overview

JAX implements much of the NumPy and SciPy APIs (as jax.numpy and jax.scipy), which lets us run numerical operations on tensor-like arrays. A JAX DeviceArray is essentially an object that holds a NumPy-like value together with its dtype. As with NumPy arrays, you can read the shape from the object's .shape attribute.

import jax

JAX can run on most accelerators (CPU, GPU, TPU). To use a GPU, install JAX with GPU support and make sure your CUDA library path is set correctly. Once that is done, you should no longer see:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

JAX dispatches operations asynchronously, so a call can return before the computation has actually finished. To time a function properly, call block_until_ready on the result of any JAX op, e.g.

>>> t = np.arange(1e7)
>>> %timeit, t).block_until_ready()
100 loops, best of 5: 16 ms per loop

RNG management

The cool part is that the user manages all the (non-)determinism! Every jax.random op takes an array object known as an RNG key. A good practice is to
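Here is a minimal sketch of the array basics from the overview. It assumes a recent JAX release, where the array type is exposed as jax.Array (the successor to DeviceArray). The array values are arbitrary.

```python
import jax
import jax.numpy as jnp

# A small float32 array, placed on JAX's default device.
x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

print(isinstance(x, jax.Array))  # True on recent JAX releases
print(x.shape)                   # (2, 3)
print(x.dtype)                   # float32

# Which platform JAX picked up, e.g. "cpu" when no GPU/TPU was found.
print(jax.default_backend())
print(jax.devices())
```

If jax.default_backend() reports "cpu" on a machine with a GPU, that usually points to the installation or CUDA path issue described above.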
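Outside IPython's %timeit, the same block_until_ready pattern can be written as a plain timing loop. This is a sketch: time_it is a hypothetical helper, and jnp.dot is an arbitrary choice of op, not necessarily the one timed in the excerpt.

```python
import time

import jax.numpy as jnp


def time_it(fn, *args, repeats=5):
    """Hypothetical helper: best wall-clock time in seconds over `repeats` runs."""
    # Warm-up run so one-off costs (e.g. compilation) are not measured.
    fn(*args).block_until_ready()
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # Without block_until_ready(), this would only time the asynchronous
        # dispatch, not the computation itself.
        fn(*args).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


t = jnp.arange(1e7)
elapsed = time_it(jnp.dot, t, t)
print(f"jnp.dot on 1e7 elements: {elapsed * 1e3:.2f} ms")
```

Taking the best of several runs, as %timeit does, reduces noise from other work happening on the machine.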
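A minimal sketch of explicit RNG key handling with jax.random. The split-before-use pattern shown here is the common JAX idiom and is an assumption about where the truncated advice is heading, not a quote of it.

```python
import jax

# An explicit seed: the same seed always gives the same key.
key = jax.random.PRNGKey(0)

# Split off a fresh subkey for each random op and keep `key` for future splits,
# so no key is ever consumed twice by accident.
key, subkey1 = jax.random.split(key)
key, subkey2 = jax.random.split(key)

a = jax.random.normal(subkey1, (3,))
b = jax.random.normal(subkey2, (3,))

# Reusing a key reproduces exactly the same numbers.
a_again = jax.random.normal(subkey1, (3,))

print(a)
print(b)
print(bool((a == a_again).all()))  # True
```

Because random ops are pure functions of their key, reproducibility comes from tracking keys rather than from a hidden global seed.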

Mean Field Variational Inference

Hacking OSX Catalina in < 10.15 Commands

Sparsifying Neural Networks