# Equinox

Equinox is a JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees.

In doing so:

• We get a PyTorch-like API...
• ...that's fully compatible with native JAX transformations...
• ...with no new concepts you have to learn. (It's all just PyTrees.)

The elegance of Equinox is its selling point in a world that already has Haiku, Flax and so on.

(In other words, why should you care? Because Equinox is really simple to learn, and really simple to use.)

## Installation

```bash
pip install equinox
```

Requires Python 3.7+ and JAX 0.3.4+.

## Documentation

Available at https://docs.kidger.site/equinox.

## Quick example

Models are defined using PyTorch-like syntax:

```python
import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.numpy.ndarray
    bias: jax.numpy.ndarray

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias
```

and fully compatible with normal JAX operations:

```python
@jax.jit
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))

# Gradients are taken with respect to the model itself, since it is a PyTree.
grads = jax.grad(loss_fn)(model, x, y)
```

Finally, there's no magic behind the scenes. All `eqx.Module` does is register your class as a PyTree. From that point onwards, JAX already knows how to work with PyTrees.

## Citation

```bibtex
@article{kidger2021equinox,
    author={Patrick Kidger and Cristian Garcia},
    title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}
```

(Also consider starring the project on GitHub.)

## See also

• Numerical differential equation solvers: Diffrax.
• Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: jaxtyping.
• SymPy<->JAX conversion; train symbolic expressions via gradient descent: sympy2jax.
