JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.
Most JAX usage is through the familiar jax.numpy API, which is typically imported under the jnp alias:
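```python
import jax.numpy as jnp
```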
With this import, you can immediately use JAX in a similar manner to typical NumPy programs, including using NumPy-style array creation functions, Python functions and operators, and array attributes and methods:
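```python
# As an illustrative example: a SELU activation function written with
# jax.numpy (the constants below are approximate SELU parameters).
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)       # NumPy-style array creation
print(selu(x))            # jnp functions and Python operators compose freely
print(x.shape, x.mean())  # familiar array attributes and methods
```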
You’ll find a few differences between JAX arrays and NumPy arrays once you begin digging in; these are explored in 🔪 JAX - The Sharp Bits 🔪.
JAX runs transparently on the GPU or TPU (falling back to CPU if you don’t have one). However, in the above example, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the jax.jit() function to compile the whole sequence together using XLA.
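For example, a compiled version of the selu function above can be created with a single call (a minimal sketch; the speedup you see will depend on your backend):

```python
import jax

selu_jit = jax.jit(selu)  # traces selu once per input shape, then runs the XLA-compiled version
```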
We can use IPython’s %timeit to quickly benchmark our selu function, using block_until_ready() to account for JAX’s asynchronous dispatch (see Asynchronous dispatch):
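```python
from jax import random

key = random.key(1701)                   # arbitrary PRNG seed, chosen for illustration
x = random.normal(key, (1_000_000,))     # a large input so the timings are meaningful

%timeit selu(x).block_until_ready()      # eager, op-by-op dispatch

selu_jit(x).block_until_ready()          # warm-up: the first call triggers compilation
%timeit selu_jit(x).block_until_ready()  # the XLA-compiled version
```

The jit-compiled version is typically much faster on repeated calls, though the exact numbers vary by hardware and backend.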