Implementations of normalizing flows (RealNVP, Glow, MAF) in the JAX deep learning framework.
Normalizing flow models are generative models, i.e. they infer the underlying probability distribution of an observed dataset. With that distribution we can do a number of interesting things, namely sample new realistic points and query probability densities.
Why JAX? A few reasons!
1) JAX encourages a functional style. When writing a layer, I didn't want people to worry about PyTorch or TensorFlow boilerplate and how their code has to fit into "the system" (e.g. do I have to keep track of self.training here?). All you have to worry about is writing a vanilla Python function which, given an ndarray, returns the correct set of outputs. You could develop your own layers with effectively no knowledge of the encompassing framework.
2) JAX's random number generation system places reproducibility first. To get a sense for this: when you start to parallelize a system, centralized state-based PRNG models à la tf.random.set_seed() start to yield inconsistent results. Given that randomness is such a central component of work in this area, I thought that uncompromising reproducibility would be a nice feature.
3) JAX has a really flexible automatic differentiation system. So flexible, in fact, that you can (basically) write arbitrary Python functions (including for loops, if statements, etc.) and automatically compute their Jacobian with a call to jax.jacfwd. So, in theory, you could write a normalizing flow layer and automatically compute the log-determinant of its Jacobian without having to derive it manually (although we're not quite there yet).
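To illustrate the third point, here is a minimal sketch of computing a log-determinant with jax.jacfwd rather than by hand. The `affine` function and the `log_det_jacobian` helper are hypothetical names made up for this example and are not part of the library; the analytic value is included only as a sanity check.

```python
import jax
import jax.numpy as jnp

def affine(params, x):
    """Hypothetical elementwise bijection: y = x * exp(log_scale) + shift."""
    log_scale, shift = params
    return x * jnp.exp(log_scale) + shift

def log_det_jacobian(fun, params, x):
    """Log |det J| of x -> fun(params, x) at a single point x of shape (d,)."""
    # jacfwd returns the full (d, d) Jacobian matrix at x
    jac = jax.jacfwd(lambda v: fun(params, v))(x)
    # slogdet is numerically safer than log(abs(det(jac)))
    _, logabsdet = jnp.linalg.slogdet(jac)
    return logabsdet

params = (jnp.array([0.1, -0.3, 0.5]), jnp.array([1.0, 2.0, 3.0]))
x = jnp.array([0.2, -1.0, 0.7])

automatic = log_det_jacobian(affine, params, x)
analytic = jnp.sum(params[0])  # closed form for this particular layer
```

Building the dense Jacobian and taking its determinant costs roughly cubic time in the input dimension, which is one reason flow layers usually supply a closed-form log-determinant instead.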
Here's an introduction! But for a more comprehensive description, check out the documentation.
A bijection is a parameterized invertible function.
    init_fun = flows.InvertibleLinear()
    params, direct_fun, inverse_fun = init_fun(rng, input_dim=5)

    # Transform inputs
    transformed_inputs, log_det_jacobian_direct = direct_fun(params, inputs)

    # Reconstruct original inputs
    reconstructed_inputs, log_det_jacobian_inverse = inverse_fun(params, transformed_inputs)

    # Compare with a tolerance: a floating-point round trip is rarely bit-exact
    assert np.allclose(inputs, reconstructed_inputs, atol=1e-5)
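As a sketch of what a custom layer following this interface might look like, here is a small elementwise affine bijection written as plain functions. `ElementwiseAffine` is a hypothetical name invented for this example (it is not claimed to exist in the library), and the assumption that log-determinants are returned per sample, with shape (batch,), is inferred from the usage above.

```python
import jax.numpy as jnp
from jax import random

def ElementwiseAffine():
    """Hypothetical layer (illustration only): y = x * exp(log_scale) + shift."""
    def init_fun(rng, input_dim, **kwargs):
        scale_rng, shift_rng = random.split(rng)
        log_scale = 0.1 * random.normal(scale_rng, (input_dim,))
        shift = 0.1 * random.normal(shift_rng, (input_dim,))
        params = (log_scale, shift)

        def direct_fun(params, inputs, **kwargs):
            log_scale, shift = params
            outputs = inputs * jnp.exp(log_scale) + shift
            # The Jacobian is diagonal, so log|det J| is the sum of log scales
            log_det = jnp.full(inputs.shape[0], jnp.sum(log_scale))
            return outputs, log_det

        def inverse_fun(params, inputs, **kwargs):
            log_scale, shift = params
            outputs = (inputs - shift) * jnp.exp(-log_scale)
            log_det = jnp.full(inputs.shape[0], -jnp.sum(log_scale))
            return outputs, log_det

        return params, direct_fun, inverse_fun
    return init_fun

init_fun = ElementwiseAffine()
params, direct_fun, inverse_fun = init_fun(random.PRNGKey(0), input_dim=5)

inputs = random.normal(random.PRNGKey(1), (4, 5))
transformed, log_det_direct = direct_fun(params, inputs)
reconstructed, log_det_inverse = inverse_fun(params, transformed)
```

Nothing here depends on a framework base class: the layer is a closure over three ordinary functions, and the parameters are an ordinary tuple of arrays.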
We can construct a sequence of bijections using flows.Serial. The result is just another bijection and adheres to the exact same interface.
    init_fun = flows.Serial(
        flows.AffineCoupling(transformation),
        flows.InvertibleLinear(),
        flows.ActNorm(),
    )

    params, direct_fun, inverse_fun = init_fun(rng, input_dim=5)
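To show why a sequence of bijections is itself a bijection with the same interface, here is a minimal sketch of composition. Both `Scale` and `serial` are hypothetical stand-ins written for this example; this is not the library's flows.Serial, only one plausible way such a combinator could work: apply the layers in order, add up the log-determinants, and run the inverses in reverse order.

```python
import jax.numpy as jnp
from jax import random

def Scale(factor):
    """Hypothetical parameter-free bijection y = factor * x (illustration only)."""
    def init_fun(rng, input_dim, **kwargs):
        log_det = input_dim * jnp.log(jnp.abs(factor))

        def direct_fun(params, inputs, **kwargs):
            return inputs * factor, jnp.full(inputs.shape[0], log_det)

        def inverse_fun(params, inputs, **kwargs):
            return inputs / factor, jnp.full(inputs.shape[0], -log_det)

        return (), direct_fun, inverse_fun
    return init_fun

def serial(*init_funs):
    """Hypothetical composition of bijections sharing the same interface."""
    def init_fun(rng, input_dim, **kwargs):
        rngs = random.split(rng, len(init_funs))
        layers = [f(r, input_dim=input_dim) for f, r in zip(init_funs, rngs)]
        params = [p for p, _, _ in layers]
        direct_funs = [d for _, d, _ in layers]
        inverse_funs = [i for _, _, i in layers]

        def direct_fun(params, inputs, **kwargs):
            total = jnp.zeros(inputs.shape[0])
            for p, fun in zip(params, direct_funs):
                inputs, log_det = fun(p, inputs)
                total = total + log_det  # log-determinants add under composition
            return inputs, total

        def inverse_fun(params, inputs, **kwargs):
            total = jnp.zeros(inputs.shape[0])
            # Undo the layers last-to-first
            for p, fun in zip(reversed(params), reversed(inverse_funs)):
                inputs, log_det = fun(p, inputs)
                total = total + log_det
            return inputs, total

        return params, direct_fun, inverse_fun
    return init_fun

init_fun = serial(Scale(2.0), Scale(3.0))
params, direct_fun, inverse_fun = init_fun(random.PRNGKey(0), input_dim=2)

inputs = jnp.ones((4, 2))
outputs, log_det = direct_fun(params, inputs)
recovered, inverse_log_det = inverse_fun(params, outputs)
```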
A distribution is characterized by a probability density querying function, a sampling function, and its parameters.
    init_fun = flows.Normal()
    params, log_pdf, sample = init_fun(rng, input_dim=5)

    # Query probability density of points
    log_pdfs = log_pdf(params, inputs)

    # Draw new points
    samples = sample(rng, params, num_samples)
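For intuition, a distribution with this interface can be sketched in a few lines. `StandardNormal` below is a hypothetical example written for this README, not the library's flows.Normal; it assumes inputs of shape (batch, input_dim) and one log-density per row.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

def StandardNormal():
    """Hypothetical distribution (illustration only): isotropic unit Gaussian."""
    def init_fun(rng, input_dim, **kwargs):
        def log_pdf(params, inputs):
            # Dimensions are independent, so per-dimension log-densities add
            return norm.logpdf(inputs).sum(axis=1)

        def sample(rng, params, num_samples=1):
            return random.normal(rng, (num_samples, input_dim))

        return (), log_pdf, sample  # no trainable parameters
    return init_fun

init_fun = StandardNormal()
params, log_pdf, sample = init_fun(random.PRNGKey(0), input_dim=5)

log_pdfs = log_pdf(params, jnp.zeros((3, 5)))

# The key is passed explicitly, so the same key always yields the same draw
first_draw = sample(random.PRNGKey(42), params, 3)
second_draw = sample(random.PRNGKey(42), params, 3)
```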
Under this definition, a normalizing flow model is just a distribution. But to retrieve one, we have to give it a bijection and another distribution to act as a prior.
    bijection = flows.Serial(
        flows.AffineCoupling(transformation),
        flows.InvertibleLinear(),
        flows.ActNorm(),
    )

    prior = flows.Normal()

    init_fun = flows.Flow(bijection, prior)
    params, log_pdf, sample = init_fun(rng, input_dim=5)
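The reason a bijection plus a prior yields a full distribution is the change-of-variables formula. The notation below is introduced here for illustration: f_θ stands for the bijection's direct function, p_Z for the prior's density, and p_X for the resulting model density. It assumes the direct function maps data into the prior's space, which is the usual convention but is an assumption about this library.

```latex
\log p_X(x) = \log p_Z\big(f_\theta(x)\big)
            + \log \left| \det \frac{\partial f_\theta(x)}{\partial x} \right|,
\qquad
x = f_\theta^{-1}(z), \quad z \sim p_Z
```

The first expression is what a density query evaluates: the prior's log-density at the transformed point plus the log-determinant returned by the direct function. The second is sampling: draw from the prior, then push the draw through the inverse function.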
Training works the same as it always would in JAX! First, define an appropriate loss function and parameter update step.
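    from jax import grad

    # opt_update and get_params are assumed to come from a JAX optimizer triple,
    # e.g. opt_init, opt_update, get_params = jax.example_libraries.optimizers.adam(...)
    def loss(params, inputs):
        # Negative mean log-likelihood of the batch
        return -log_pdf(params, inputs).mean()

    def step(i, opt_state, inputs):
        params = get_params(opt_state)
        gradient = grad(loss)(params, inputs)
        return opt_update(i, gradient, opt_state)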
Then execute a standard JAX training loop.
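    import itertools
    import numpy.random as npr

    batch_size, itercount = 32, itertools.count()

    for epoch in range(num_epochs):
        npr.shuffle(X)  # shuffles in place, so X must be a NumPy array
        # X.shape is a tuple; iterate over the number of rows
        for batch_index in range(0, X.shape[0], batch_size):
            opt_state = step(
                next(itercount), opt_state, X[batch_index:batch_index + batch_size]
            )

    optimized_params = get_params(opt_state)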
Now that we have our trained model parameters, we can query and sample as usual.
    log_pdfs = log_pdf(optimized_params, inputs)
    samples = sample(rng, optimized_params, num_samples)
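To make the training recipe concrete without depending on the library, here is a self-contained sketch that fits a one-dimensional affine flow with a standard normal prior by maximum likelihood. Everything in it is an assumption made for illustration: the hand-written `log_pdf`, the synthetic data, the learning rate, and the use of plain full-batch gradient descent in place of an optimizer such as Adam.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

LEARNING_RATE = 0.1  # illustrative value, not a recommendation

def log_pdf(params, inputs):
    """Density of the hypothetical flow x = z * exp(log_scale) + shift, z ~ N(0, 1)."""
    log_scale, shift = params
    z = (inputs - shift) * jnp.exp(-log_scale)  # map data back to the prior's space
    # Change of variables: prior log-density plus log|det| of the data-to-latent map
    return norm.logpdf(z).sum(axis=1) - jnp.sum(log_scale)

def loss(params, inputs):
    return -log_pdf(params, inputs).mean()

@jax.jit
def step(params, inputs):
    gradient = jax.grad(loss)(params, inputs)
    # Plain gradient descent applied leaf-by-leaf to the parameter tuple
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, gradient
    )

# Synthetic data with mean 3 and standard deviation 2
X = 3.0 + 2.0 * random.normal(random.PRNGKey(0), (1000, 1))

params = (jnp.zeros(1), jnp.zeros(1))  # (log_scale, shift)
initial_loss = loss(params, X)

for _ in range(500):
    params = step(params, X)

final_loss = loss(params, X)
learned_scale = jnp.exp(params[0])
learned_shift = params[1]
```

For this simple model the maximum-likelihood solution is known in closed form (the sample mean and standard deviation), which makes it a convenient check that the loss and update step are wired correctly before moving to a real flow.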
Interested in contributing? Yay! Check out our contributing guidelines.
The implementations are modeled after the work of the following papers:
NICE: Non-linear Independent Components Estimation. Laurent Dinh, David Krueger, Yoshua Bengio. arXiv:1410.8516
Density estimation using Real NVP. Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. arXiv:1605.08803
Improving Variational Inference with Inverse Autoregressive Flow. Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling. arXiv:1606.04934
Glow: Generative Flow with Invertible 1x1 Convolutions. Diederik P. Kingma, Prafulla Dhariwal. arXiv:1807.03039
Flow++: Improving Flow-Based Generative Models with Variational Dequantization and Architecture Design. Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, Pieter Abbeel. OpenReview:Hyg74h05tX
Masked Autoregressive Flow for Density Estimation. George Papamakarios, Theo Pavlakou, Iain Murray. arXiv:1705.07057
Neural Spline Flows. Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios. arXiv:1906.04032
And by association the following surveys:
Normalizing Flows: An Introduction and Review of Current Methods. Ivan Kobyzev, Simon Prince, Marcus A. Brubaker. arXiv:1908.09257
Normalizing Flows for Probabilistic Modeling and Inference. George Papamakarios, Eric Nalisnick, Danilo Jimenez Rezende, Shakir Mohamed, Balaji Lakshminarayanan. arXiv:1912.02762