JAX-CFD: Computational Fluid Dynamics in JAX

Authors: Dmitrii Kochkov, Jamie A. Smith, Stephan Hoyer

JAX-CFD is an experimental research project for exploring the potential of machine learning, automatic differentiation and hardware accelerators (GPU/TPU) for computational fluid dynamics. It is implemented in JAX.

To learn more about our general approach, read our paper Machine learning accelerated computational fluid dynamics (PNAS 2021).

Getting started

Take a look at the "demo" notebook in the "notebooks" directory for an example of how to use JAX-CFD to simulate a 2D turbulent flow.

We are currently preparing more example notebooks, including:

  • Reusing our training data and/or evaluation setup (without running JAX-CFD)
  • Simulations using our pre-trained turbulence models
  • Training a simple hybrid ML + CFD model from scratch


JAX-CFD is organized around sub-modules:

  • jax_cfd.base: core numerical methods for CFD, written in JAX.
  • jax_cfd.ml: machine learning augmented models for CFD, written in JAX and Haiku.
  • jax_cfd.data: data processing utilities for preparing, evaluating and post-processing data created with JAX-CFD, built on Xarray and Pillow.

A base install with pip install jax-cfd only requires NumPy, SciPy and JAX. To install dependencies for the other submodules, use pip install jax-cfd[ml], pip install jax-cfd[data] or pip install jax-cfd[complete].
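To check which optional features an environment is likely to support, a minimal sketch along the following lines may help. It uses only the standard library. The mapping from extras to import names is an assumption based on the libraries named above (Haiku imports as haiku, Pillow as PIL); the actual extras may pull in additional packages, and available_extras is a hypothetical helper, not part of JAX-CFD.

```python
from importlib.util import find_spec

# Assumed mapping from pip extras to the import names of the libraries
# mentioned in this README. The real extras may require more packages.
EXTRA_IMPORTS = {
    "ml": ["haiku"],           # jax-cfd[ml]
    "data": ["xarray", "PIL"],  # jax-cfd[data]
}


def available_extras():
    """Report, per extra, whether all of its assumed dependencies are importable."""
    return {
        extra: all(find_spec(name) is not None for name in names)
        for extra, names in EXTRA_IMPORTS.items()
    }


if __name__ == "__main__":
    for extra, ok in available_extras().items():
        print(f"jax-cfd[{extra}]: {'available' if ok else 'missing dependencies'}")
```

A False entry only means one of the assumed libraries could not be found; installing the matching extra should resolve it.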


JAX-CFD is currently focused on unsteady turbulent flows:

  • Spatial discretization: Finite volume/difference methods on a staggered grid (the "Arakawa C" or "MAC" grid) with pressure at the center of each cell and velocity components defined on corresponding faces.
  • Temporal discretization: Currently only first-order temporal discretization, using explicit time-stepping for advection and either implicit or explicit time-stepping for diffusion.
  • Pressure solves: Either conjugate gradient (CG) or fast diagonalization with real-valued FFTs (suitable for periodic boundary conditions).
  • Boundary conditions: Currently only periodic boundary conditions are supported.
  • Advection: We implement second-order accurate "Van Leer" schemes.
  • Closures: We currently implement Smagorinsky eddy-viscosity models.
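To make the staggered grid and the FFT-based pressure solve more concrete, here is a minimal sketch of a pressure projection on a 2D periodic MAC grid. It is written in plain NumPy so that it runs without extra setup, and it does not use or reproduce the JAX-CFD API: the function names, the array layout (u[i, j] on the right face and v[i, j] on the top face of cell (i, j)) and the grid size are all assumptions made for illustration.

```python
import numpy as np


def divergence(u, v, dx, dy):
    """Cell-centered divergence of face velocities on a periodic MAC grid."""
    # np.roll(u, 1, axis=0)[i, j] == u[i - 1, j], i.e. the left face of cell (i, j).
    return (u - np.roll(u, 1, axis=0)) / dx + (v - np.roll(v, 1, axis=1)) / dy


def gradient(p, dx, dy):
    """Gradient of a cell-centered field, evaluated on the right and top faces."""
    return (np.roll(p, -1, axis=0) - p) / dx, (np.roll(p, -1, axis=1) - p) / dy


def solve_poisson_fft(rhs, dx, dy):
    """Solve the periodic 5-point Poisson problem laplacian(q) = rhs with real FFTs."""
    nx, ny = rhs.shape
    # Eigenvalues of the 1D periodic second-difference operator along each axis.
    lam_x = (2.0 * np.cos(2.0 * np.pi * np.fft.fftfreq(nx)) - 2.0) / dx**2
    lam_y = (2.0 * np.cos(2.0 * np.pi * np.fft.rfftfreq(ny)) - 2.0) / dy**2
    eig = lam_x[:, None] + lam_y[None, :]
    eig[0, 0] = 1.0  # placeholder to avoid dividing by zero for the mean mode
    q_hat = np.fft.rfft2(rhs) / eig
    q_hat[0, 0] = 0.0  # the mean of q is arbitrary; fix it to zero
    return np.fft.irfft2(q_hat, s=rhs.shape)


def project(u, v, dx, dy):
    """Remove the divergent part of (u, v) by subtracting a pressure-like gradient."""
    q = solve_poisson_fft(divergence(u, v, dx, dy), dx, dy)
    gx, gy = gradient(q, dx, dy)
    return u - gx, v - gy


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    nx, ny = 16, 12
    dx, dy = 2 * np.pi / nx, 2 * np.pi / ny
    u, v = rng.standard_normal((nx, ny)), rng.standard_normal((nx, ny))
    u_proj, v_proj = project(u, v, dx, dy)
    print("max |div| before:", np.abs(divergence(u, v, dx, dy)).max())
    print("max |div| after: ", np.abs(divergence(u_proj, v_proj, dx, dy)).max())
```

Because the discrete divergence of the discrete gradient is exactly the 5-point Laplacian that the FFT diagonalizes, the projected field is divergence-free up to floating-point round-off. The functions used here (roll, fft.rfft2, fft.irfft2, fft.fftfreq, fft.rfftfreq) also exist in jax.numpy, so the same idea should carry over to JAX with minor changes, such as replacing the in-place assignments with functional updates.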

TODO: add a notebook explaining our numerical models in more depth.

In the long term, we're interested in expanding JAX-CFD to implement methods relevant for related research, e.g.,

  • Spectral methods
  • Colocated grids
  • Alternative boundary conditions (e.g., non-periodic boundaries and immersed boundary methods)
  • Higher-order time-stepping
  • Geometric multigrid
  • Steady state simulation (e.g., RANS)
  • Distributed simulations across multiple TPUs/GPUs

We would welcome collaboration on any of these! Please reach out (either on GitHub or by email) to coordinate before starting significant work.

Projects using JAX-CFD

Other awesome projects

Other differentiable CFD codes compatible with deep learning:

JAX for science:

Did we miss something? Please let us know!


Local development

To locally install for development:

git clone https://github.com/google/jax-cfd.git
cd jax-cfd
pip install jaxlib
pip install -e ".[complete]"

Then to manually run the test suite:

pytest -n auto jax_cfd --dist=loadfile --ignore=jax_cfd/base/validation_test.py

