Elegy

Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.


Main Features

  • Easy-to-use: Elegy provides a Keras-like high-level API that makes common tasks very easy.
  • Flexible: Elegy provides a functional, PyTorch Lightning-like low-level API that maximizes flexibility when needed.
  • Agnostic: Elegy supports various frameworks, including Flax, Haiku, and Optax, in the high-level API, and it is 100% framework-agnostic in the low-level API.
  • Compatible: Elegy can consume many familiar data sources, including TensorFlow Datasets, PyTorch DataLoaders, Python generators, and NumPy pytrees.

For more information, take a look at the Documentation.

Installation

Install Elegy using pip:

pip install elegy

For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows yet.

Quick Start: High-level API

Elegy's high-level API provides a straightforward interface that you can use in three steps:

1. Define the architecture inside a Module. We will use Flax Linen for this example:

import flax.linen as nn
import jax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x

2. Create a Model from this module and specify its losses, metrics, and optimizer:

import elegy, optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    # X_train, y_train, X_test, y_test: NumPy arrays of inputs and integer labels, loaded beforehand
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")],
)
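The loss list in step 2 combines a classification term with a weight penalty. The NumPy sketch below spells out what that sum computes, under two assumptions: SparseCategoricalCrossentropy(from_logits=True) is taken to average the negative log-softmax of the true class over the batch, and GlobalL2(l=...) is taken to add l times the sum of squared parameters, following the Keras L2 convention. The helper names and toy values are hypothetical, not Elegy's API.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row-wise max before exponentiating, for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def sparse_ce_from_logits(logits, labels):
    # Integer labels pick out the log-probability of the true class in each row.
    logp = log_softmax(logits)
    return -logp[np.arange(len(labels)), labels].mean()

def global_l2(params, l):
    # Assumed form of the penalty: l * sum of squares over every parameter array.
    return l * sum(np.sum(p ** 2) for p in params.values())

def total_loss(logits, labels, params, l=1e-5):
    return sparse_ce_from_logits(logits, labels) + global_l2(params, l)

# Hypothetical toy values: uniform logits make the cross-entropy exactly log(10).
params = {"w": np.ones((4, 10)), "b": np.zeros(10)}
logits = np.zeros((2, 10))
labels = np.array([3, 7])
print(total_loss(logits, labels, params))
```

Indexing the log-probabilities with integer labels gives the same value as the one-hot formulation used in the low-level example, without materializing the one-hot matrix.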

Quick Start: Low-level API

Elegy's low-level API lets you explicitly define what happens during training, testing, and inference. Here we define test_step to implement a linear classifier in pure JAX:

1. Compute the loss, logs, and updated states in test_step:

import jax
import jax.numpy as jnp
import numpy as np
import elegy

class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # available names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x,  # inputs
        y_true,  # labels
        states: elegy.States,  # model state
        initializing: bool,  # if True, we should initialize our parameters
    ):
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            # one bias per class; a single shared bias would cancel out in the softmax
            b = jax.random.uniform(rng.next(), shape=[10])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        # metrics
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, states.update(net_params=(w, b))

2. Instantiate our LinearClassifier with an optimizer:

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
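Only test_step is defined here, so the training update has to be derived from it: conceptually, each step of fit differentiates the loss that test_step returns with respect to the trainable parameters and passes the gradients to the optimizer. The NumPy sketch below makes that loop explicit for the same linear classifier. It is an illustration under stated assumptions, not Elegy's implementation: the gradient is written out analytically (softmax minus one-hot, the standard result for softmax cross-entropy) instead of using jax.grad, the bias has one entry per class, the update is the textbook RMSProp rule rather than optax's exact code, and the toy data and the names fit_linear and loss_and_grads are hypothetical.

```python
import numpy as np

def loss_and_grads(x, y, w, b, num_classes=10):
    # Same model as test_step: flatten, scale to [0, 1], then an affine map to logits.
    x = x.reshape(x.shape[0], -1) / 255
    logits = x @ w + b
    shifted = logits - logits.max(axis=-1, keepdims=True)
    logp = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    labels = np.eye(num_classes)[y]
    loss = -np.mean(np.sum(labels * logp, axis=-1))
    # Gradient of the mean softmax cross-entropy w.r.t. the logits: (softmax - one_hot) / batch.
    dlogits = (np.exp(logp) - labels) / x.shape[0]
    return loss, x.T @ dlogits, dlogits.sum(axis=0)

def fit_linear(x, y, epochs=20, batch_size=16, lr=1e-3, decay=0.9, eps=1e-8,
               num_classes=10, seed=0):
    rng = np.random.default_rng(seed)
    n_features = int(np.prod(x.shape[1:]))
    # Uniform initialization, as in test_step when initializing=True.
    w = rng.uniform(size=(n_features, num_classes))
    b = rng.uniform(size=num_classes)
    nu_w, nu_b = np.zeros_like(w), np.zeros_like(b)
    for _ in range(epochs):
        order = rng.permutation(len(x))  # reshuffle every epoch (shuffle=True)
        for start in range(0, len(x), batch_size):
            idx = order[start:start + batch_size]
            _, gw, gb = loss_and_grads(x[idx], y[idx], w, b, num_classes)
            # RMSProp: divide each gradient by a running RMS of its past values.
            nu_w = decay * nu_w + (1 - decay) * gw ** 2
            nu_b = decay * nu_b + (1 - decay) * gb ** 2
            w = w - lr * gw / (np.sqrt(nu_w) + eps)
            b = b - lr * gb / (np.sqrt(nu_b) + eps)
    return w, b

# Hypothetical toy data: 128 random 4x4 "images" whose labels come from a hidden linear rule.
data_rng = np.random.default_rng(42)
x_train = data_rng.integers(0, 256, size=(128, 4, 4)).astype(np.float64)
y_train = (x_train.reshape(128, -1) @ data_rng.normal(size=(16, 10))).argmax(axis=-1)

w0, b0 = fit_linear(x_train, y_train, epochs=0)  # epochs=0 returns the initial parameters
w1, b1 = fit_linear(x_train, y_train)
print(loss_and_grads(x_train, y_train, w0, b0)[0], loss_and_grads(x_train, y_train, w1, b1)[0])
```

Calling fit_linear with epochs=0 and the same seed returns the untrained parameters, which makes it easy to compare the loss before and after training.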

Using Jax Frameworks

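It is straightforward to integrate other functional JAX libraries with this low-level API. The following example assumes a Flax Linen module, such as the MLP above, was passed to the model and is available as self.module: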

import jax
import jax.numpy as jnp
import elegy

class LinearClassifier(elegy.Model):
    # assumes self.module is a Flax Linen module, e.g. LinearClassifier(module=MLP(), optimizer=...)
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            # create all Flax variables (params plus any other collections)
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            # rebuild the Flax variables dict from the stored states
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        # assumes a Flax version that returns a FrozenDict here;
        # FrozenDict.pop returns (remaining collections, popped value)
        net_states, net_params = variables.pop("params")

        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)
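Both test_step methods above declare only the arguments they need, and the first example's comments describe this as requesting parameters by name via dependency injection. The sketch below illustrates that general pattern with Python's standard inspect module: it reads a function's signature and passes only the matching values. call_with_injection and step_fn are hypothetical names; this shows the idea, not Elegy's internal code.

```python
import inspect

def call_with_injection(fn, **available):
    """Call fn with only the keyword arguments its signature asks for."""
    sig_params = list(inspect.signature(fn).parameters.values())
    named = [
        p for p in sig_params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    ]
    # A requested name with no value and no default is an error, as with a normal call.
    missing = [
        p.name for p in named
        if p.name not in available and p.default is inspect.Parameter.empty
    ]
    if missing:
        raise TypeError(f"no value available for {missing}")
    return fn(**{p.name: available[p.name] for p in named if p.name in available})

# Hypothetical step function that asks for three of the available names.
def step_fn(x, y_true, initializing, scale=1.0):
    return {"x": [v * scale for v in x], "y_true": y_true, "initializing": initializing}

available = dict(
    x=[1.0, 2.0], y_true=[0, 1], sample_weight=None,
    class_weight=None, states={}, initializing=True,
)
print(call_with_injection(step_fn, **available))
```

Since inspect.signature omits self for bound methods, the same helper also works on a method such as an instance's test_step.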

More Info

Examples

To run the examples, first install some required packages:

pip install -r examples/requirements.txt

Then run an example, such as the Flax VAE on MNIST:

python examples/flax_mnist_vae.py

Contributing

Deep Learning is evolving at an incredible pace, and there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy, open an issue or send a PR! For more information, check out our Contributing Guide.

About Us

We are a group of friends passionate about ML.

License

This project is licensed under the Apache v2.0 License.

Citing Elegy

To cite this project:

BibTeX

@software{elegy2020repository,
author = {PoetsAI},
title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
url = {https://github.com/poets-ai/elegy},
version = {0.7.4},
year = {2020},
}

The current version can be found in the Release tag or in the file elegy/__init__.py, and the year corresponds to the project's release year.
