The PyTorch developer's guide to JAX fundamentals


Like many PyTorch users, you may have heard great things about JAX — its high performance, the elegance of its functional programming approach, and its powerful, built-in support for parallel computation. However, you may have also struggled to find what you need to get started: a straightforward, easy-to-follow tutorial to help you understand the basics of JAX by connecting its new concepts to the PyTorch building blocks that you’re already familiar with. So, we created one! 

In this tutorial, we explore the basics of the JAX ecosystem through the lens of a PyTorch user, training a simple neural network in both frameworks for the classic machine learning (ML) task of predicting which passengers survived the Titanic disaster. Along the way, we introduce JAX by showing how each step, from model definition and instantiation to training, maps to its PyTorch equivalent.

You can follow along with full code examples in the accompanying notebook: https://www.kaggle.com/code/anfalatgoogle/pytorch-developer-s-guide-to-jax-fundamentals 

Modularity with JAX

As a PyTorch user, you might initially find JAX’s highly modular ecosystem quite different from what you are used to. JAX focuses on being a high-performance numerical computation library with support for automatic differentiation. Unlike PyTorch, it does not try to provide explicit built-in support for defining neural networks, optimizers, and so on. Instead, JAX is designed to be flexible, letting you bring in the libraries of your choice to extend its functionality.

In this tutorial, we use the Flax neural network library and the Optax optimization library, both popular and well supported. We show how to train a neural network with the new Flax NNX API for a very PyTorch-esque experience, and then show how to do the same thing with the older, but still widely used, Linen API.

Functional programming

Before we dive into our tutorial, let’s talk about JAX’s rationale for using functional programming, as opposed to the object-oriented programming that PyTorch and other frameworks use. Briefly, functional programming is built around pure functions: functions that do not mutate state, have no side effects, and always produce the same output for the same input. In JAX, this shows up as heavy use of composable functions and immutable arrays.
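To make the two ideas concrete, here is a minimal sketch of a pure function and of JAX's immutable arrays. The function name scale and the sample values are our own illustrative choices.

```python
import jax.numpy as jnp


def scale(x, factor):
    # Pure: the result depends only on the arguments and nothing outside is modified.
    return x * factor


a = jnp.array([1.0, 2.0, 3.0])
b = scale(a, 2.0)  # a is left unchanged

# JAX arrays are immutable, so in-place item assignment raises a TypeError.
try:
    a[0] = 10.0
    mutated_in_place = True
except TypeError:
    mutated_in_place = False

# The functional alternative returns a new array instead of editing the old one.
c = a.at[0].set(10.0)
```

After running this, a still holds its original values, while b and c are new arrays.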

The predictability of pure functions unlocks many benefits in JAX, such as just-in-time (JIT) compilation, where the XLA compiler can significantly optimize code for GPUs or TPUs, yielding major speed-ups. Pure functions also make sharding and parallelizing operations much easier. You can learn more from the official JAX tutorials.

Do not be deterred if you're new to functional programming — as you will soon see, Flax NNX hides much of it behind standard Pythonic idioms. 

Data loading

Data loading in JAX is very straightforward: just do what you already do in PyTorch. You can use a PyTorch dataset/dataloader with a simple collate_fn that converts each batch to the NumPy-like arrays that underlie all JAX computation.
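One way this could look is sketched below. ToyTabularDataset is a hypothetical stand-in filled with random numbers rather than the real Titanic table, and numpy_collate is our own name for the collate function.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class ToyTabularDataset(Dataset):
    """Hypothetical stand-in for a preprocessed tabular dataset (random data)."""

    def __init__(self, num_rows=32, num_features=8, seed=0):
        rng = np.random.default_rng(seed)
        self.features = rng.normal(size=(num_rows, num_features)).astype(np.float32)
        self.labels = rng.integers(0, 2, size=num_rows).astype(np.int32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


def numpy_collate(batch):
    # Stack the samples into NumPy arrays instead of torch tensors.
    features, labels = zip(*batch)
    return np.stack(features), np.asarray(labels)


loader = DataLoader(ToyTabularDataset(), batch_size=8, shuffle=False,
                    collate_fn=numpy_collate)
features, labels = next(iter(loader))  # NumPy arrays, ready to pass to JAX
```

JAX functions accept NumPy arrays directly, so each batch can be fed to the model without further conversion.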

Model definition

With Flax’s NNX API, defining your neural networks is very similar to doing so in PyTorch. Here we define a simple, two-layer multilayer perceptron in both frameworks, starting with PyTorch.
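A minimal PyTorch sketch of such a model is shown here. The class name TorchMLP and the layer sizes (8 input features, 32 hidden units, 2 classes) are illustrative assumptions, not the exact values from the notebook.

```python
import torch
from torch import nn


class TorchMLP(nn.Module):
    def __init__(self, in_features, hidden_features, num_classes):
        super().__init__()
        # Layers are declared in __init__ ...
        self.layer1 = nn.Linear(in_features, hidden_features)
        self.layer2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        # ... and wired together in forward.
        x = torch.relu(self.layer1(x))
        return self.layer2(x)


# Illustrative sizes only.
torch_model = TorchMLP(in_features=8, hidden_features=32, num_classes=2)
torch_logits = torch_model(torch.zeros(4, 8))  # one row of logits per example
```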

NNX model definitions are very similar to the PyTorch code above. Both make use of __init__ to define the layers of the model, while __call__ corresponds to forward.

Model initialization and usage

Model initialization in NNX is nearly identical to PyTorch. In both frameworks, when you create an instance of the model class, the model parameters are initialized eagerly (rather than lazily) and tied to the instance itself. The only difference in NNX is that you need to pass in a pseudorandom number generator (PRNG) key, typically wrapped in an nnx.Rngs object, when instantiating the model. In keeping with its functional nature, JAX avoids implicit global random state and requires you to pass PRNG keys explicitly. This makes random number generation easily reproducible, parallelizable, and vectorizable. See the JAX docs for more details.
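The explicit-key model can be sketched with plain JAX as follows. The seed value and the key names are arbitrary choices for illustration.

```python
import jax

key = jax.random.PRNGKey(42)                   # explicit seed, no hidden global state
init_key, other_key = jax.random.split(key)    # derive independent keys from one parent

a = jax.random.normal(init_key, (3,))
b = jax.random.normal(init_key, (3,))   # same key, so exactly the same numbers
c = jax.random.normal(other_key, (3,))  # different key, so different numbers

same_key_matches = bool((a == b).all())
different_key_matches = bool((a == c).all())
```

Reusing a key reproduces the same values, which is why fresh randomness always comes from splitting a key rather than from calling the generator again.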

Actually using the models to process a batch of data is equivalent between the two frameworks:

Training step and backpropagation 

There are some key differences in training loops between PyTorch and Flax NNX. To demonstrate, let’s build up to the full NNX training loop step by step. 

Setup

In both frameworks, we create an optimizer and are free to choose the optimization algorithm. While PyTorch requires passing in the model parameters, Flax NNX lets you pass in the model directly and handles all interactions with the underlying Optax optimizer.

Forward + backward pass

Perhaps the biggest difference between PyTorch and JAX is how you perform a full forward/backward pass. With PyTorch, you calculate the gradients with loss.backward(), which triggers autograd to walk the computation graph back from loss and compute the gradients.
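A tiny PyTorch illustration, using a toy loss J(θ) = Σ θᵢ² of our own choosing rather than the tutorial's model:

```python
import torch

theta = torch.tensor([1.0, 2.0], requires_grad=True)

loss = (theta ** 2).sum()  # toy loss: J(theta) = sum(theta_i ** 2)
loss.backward()            # autograd walks the graph back from loss

# The gradient dJ/dtheta = 2 * theta is stored on the tensor itself.
grad_list = theta.grad.tolist()
```

Note that backward() returns nothing; the result appears as a side effect on theta.grad.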

JAX’s automatic differentiation is instead much closer to the raw math, where you take gradients of functions. Specifically, nnx.value_and_grad and nnx.grad take in a function, loss_fn, and return a new function, grad_fn. Calling grad_fn then returns the gradient of the output of loss_fn with respect to its input (by default, the first argument); with nnx.value_and_grad, it also returns the output value itself.

In our example, loss_fn is doing exactly what is being done in PyTorch: first, it gets the logits from the forward pass and then calculates the familiar loss. From there, grad_fn calculates the gradient of loss with respect to the parameters of model. In mathematical terms, the grads that are returned are ∂J/∂θ. This is exactly what is happening in PyTorch under the hood: whereas PyTorch is “storing” the gradients in the tensor's .grad attribute when you do loss.backward(), JAX and Flax NNX follow the functional approach of not mutating state and just return the gradients to you directly. 

Optimizer step

In PyTorch, optimizer.step() updates the weights in place using the gradients. NNX also updates the weights in place, but requires you to pass in the grads you calculated in the backward pass. This is the same optimization step as in PyTorch, just slightly more explicit, in keeping with JAX’s underlying functional nature.

Full training loop

You now have everything you need to construct a full training loop in JAX/Flax NNX. As a reference, let’s first see the familiar PyTorch loop:
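A compact sketch of that loop follows, using synthetic data in place of the Titanic features. The layer sizes, learning rate, and epoch count are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic stand-in data: the label depends on the sign of the first feature.
features = torch.randn(64, 8)
labels = (features[:, 0] > 0).long()

model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

losses = []
for epoch in range(50):
    optimizer.zero_grad()             # clear gradients from the previous step
    logits = model(features)          # forward pass
    loss = criterion(logits, labels)
    loss.backward()                   # backward pass fills .grad on each parameter
    optimizer.step()                  # in-place weight update
    losses.append(loss.item())
```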

And now the full NNX training loop:

The key takeaway is that the training loops are very similar between PyTorch and JAX/Flax NNX, with most of the differences boiling down to object-oriented versus functional programming. Although there’s a slight learning curve to functional programming and thinking about gradients of functions, it enables many of the benefits mentioned earlier, such as JIT compilation and automatic parallelization. For example, just adding the @nnx.jit decorator to the above functions speeds up training the model for 500 epochs from 6.25 minutes to just 1.8 minutes with a P100 GPU on Kaggle! You’ll see similar speedups with the same code across CPUs, TPUs, and even non-NVIDIA GPUs.

Flax Linen reference

As previously mentioned, the JAX ecosystem is very flexible and lets you bring in your framework of choice. Although NNX is the recommended solution for new users, the Flax Linen API is still widely used today, including in powerful frameworks like MaxText and MaxDiffusion. While NNX is far more Pythonic and hides much of the complexity of state management, Linen adheres much more closely to pure functional programming. 

Being comfortable with both is greatly beneficial if you want to participate in the JAX ecosystem. To help, let’s replicate much of our NNX code with Linen, and include comments highlighting the main differences.

Next steps

With the JAX/Flax knowledge you’ve gained from this blog post, you are now ready to write your own neural network. You can get started right away in Google Colab or Kaggle. Find a challenge on Kaggle and write a brand new model with Flax NNX, or start training a large language model (LLM) with MaxText — the possibilities are endless. 

And we have just scratched the surface with JAX and Flax. To learn more about JIT, automatic vectorization, custom gradients, and more, check out the documentation for both JAX and Flax!
