JAX can make ML programming more intuitive, systematic, and organised. Although its design differs fundamentally from TensorFlow's and PyTorch's, it could replace both.

Machine learning researchers are excited about JAX because it makes ML programming easy, structured, and clean, and because it offers a system of composable function transformations.

Today, let's take a look at some new JAX libraries:

NumPyro

NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for Pyro. It relies on JAX for automatic differentiation and JIT compilation to GPU/CPU. NumPyro is still under development, so bugs and API changes are possible as the design evolves.
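To illustrate the kind of inference a probabilistic programming library automates, here is a minimal sketch in plain NumPy (deliberately not NumPyro's API): a Beta-Bernoulli coin-flip model whose posterior happens to have a closed form by conjugacy. In NumPyro you would instead declare the model and let MCMC approximate the posterior.

```python
import numpy as np

def beta_bernoulli_posterior(flips, a=1.0, b=1.0):
    """Exact posterior over a coin's bias under a Beta(a, b) prior.

    Conjugacy gives a closed form: Beta(a + heads, b + tails).
    A probabilistic programming library handles the general case,
    where no closed form exists, via approximate inference.
    """
    flips = np.asarray(flips)
    heads = int(flips.sum())
    tails = len(flips) - heads
    return a + heads, b + tails

# Ten flips, seven heads.
a_post, b_post = beta_bernoulli_posterior([1, 1, 1, 1, 1, 1, 1, 0, 0, 0])
post_mean = a_post / (a_post + b_post)  # mean of a Beta(a, b) is a / (a + b)
```

The point of a library like NumPyro is that the same declarative model works even when, unlike here, no conjugate shortcut exists.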

Jax-unirep

The UniRep model was developed in George Church's lab. You can find the original publication here (bioRxiv) or here (Nature Methods), as well as the repository where the original model is kept.

This repository provides a self-contained, easily customisable reimplementation of the UniRep model, along with additional utility APIs that support protein engineering workflows.

TensorLy

TensorLy is a Python library that aims to make tensor learning simple and accessible. It supports tensor decomposition, tensor learning, and tensor algebra, and its backend system lets computations run on NumPy, PyTorch, JAX, MXNet, TensorFlow, or CuPy, at scale on CPU or GPU.
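Two operations at the heart of tensor algebra are the mode-n unfolding (flattening a tensor into a matrix along one mode) and the mode-n product (contracting a matrix with one mode). The sketch below implements both in plain NumPy; the unfolding follows the same moveaxis-then-reshape convention TensorLy uses, but this is an illustration, not TensorLy's code.

```python
import numpy as np

def unfold(tensor, mode):
    """Mode-`mode` unfolding: arrange the mode-`mode` fibres as columns."""
    return np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)

def mode_dot(tensor, matrix, mode):
    """Mode-`mode` product: contract `matrix` with the given tensor mode."""
    moved = np.moveaxis(tensor, mode, 0)
    out = np.tensordot(matrix, moved, axes=(1, 0))
    return np.moveaxis(out, 0, mode)

X = np.arange(24).reshape(2, 3, 4)   # a 2 x 3 x 4 tensor
U = np.ones((5, 3))                  # maps mode 1 from dimension 3 to 5

X0 = unfold(X, 0)                    # shape (2, 12)
Y = mode_dot(X, U, mode=1)           # shape (2, 5, 4)
```

Decompositions such as CP and Tucker, which TensorLy provides, are built by alternating operations like these.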

Fortuna

Fortuna is an uncertainty quantification library that makes it easy to run benchmarks and deploy uncertainty-aware models in production. Fortuna provides calibration and conformal methods that work with pre-trained models written in any framework, as well as several Bayesian inference methods for deep learning models written in Flax. Its language is intuitive for users who are not experts in uncertainty quantification, while remaining highly configurable.
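The appeal of conformal methods is that they wrap any pre-trained model. As a minimal NumPy sketch of the idea (not Fortuna's API), here is split conformal prediction for regression: calibrate a residual quantile on held-out data, then issue intervals with a marginal coverage guarantee.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic regression data: y = 2x + noise.
x = rng.uniform(0, 1, size=2000)
y = 2 * x + rng.normal(0, 0.1, size=2000)

# A deliberately simple stand-in for a pre-trained model.
predict = lambda x: 2 * x

# Split conformal: calibrate a residual quantile on held-out data.
x_cal, y_cal = x[:1000], y[:1000]
alpha = 0.1  # target 90% coverage
scores = np.abs(y_cal - predict(x_cal))
n = len(scores)
q = np.quantile(scores, np.ceil((n + 1) * (1 - alpha)) / n)

# Intervals on new points are prediction +/- q; check empirical coverage.
x_test, y_test = x[1000:], y[1000:]
coverage = (np.abs(y_test - predict(x_test)) <= q).mean()
```

Because only residuals are needed, the same recipe applies regardless of the framework the model was trained in, which is what makes conformal methods framework-agnostic.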

Cvxpylayers

cvxpylayers is a Python library for constructing differentiable convex optimisation layers in PyTorch, JAX, and TensorFlow using CVXPY. In the forward pass, a convex optimisation layer solves a parameterised convex optimisation problem; in the backward pass, it computes the derivative of the solution with respect to the parameters.

This library accompanies the authors' NeurIPS 2019 paper on differentiable convex optimisation layers.
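The forward/backward split above can be seen in a toy problem small enough to solve by hand (this is a conceptual sketch, not cvxpylayers code): a regularised least-squares layer whose solution, and hence its derivative with respect to the parameter, has a closed form.

```python
import numpy as np

def solve_layer(p, lam=0.5):
    """Forward pass: argmin_x ||x - p||^2 + lam * ||x||^2.

    Setting the gradient 2(x - p) + 2*lam*x to zero gives x* = p / (1 + lam).
    A real convex optimisation layer would call a solver here instead.
    """
    return p / (1.0 + lam)

def solve_layer_grad(p, lam=0.5):
    """Backward pass: derivative of the solution x* with respect to p."""
    return np.ones_like(p) / (1.0 + lam)

p = np.array([1.0, -2.0, 3.0])
x_star = solve_layer(p)

# Sanity-check the analytic derivative against finite differences.
eps = 1e-6
fd = (solve_layer(p + eps) - solve_layer(p)) / eps
```

cvxpylayers generalises this: the solver handles problems with no closed form, and implicit differentiation of the optimality conditions replaces the hand-derived gradient.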

Optax

Optax is a gradient processing and optimisation library for JAX. It is designed to facilitate research by providing building blocks that can easily be recombined in custom ways.

Its goals are to:

  • Provide implementations of core components that are simple, well-tested, and effective.
  • Increase research efficiency by making it easy to combine low-level components into custom optimisers (or other gradient processing components).
  • Accelerate the spread of new ideas by making it easy for anyone to contribute.
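The "building blocks" idea is that each gradient transformation exposes an init/update pair, and an optimiser is just a chain of them. The sketch below mimics that pattern in plain NumPy (it is not Optax's implementation, and the helper names are invented for illustration): gradient clipping chained with a fixed learning rate.

```python
import numpy as np

def clip_by_norm(max_norm):
    """Transformation: rescale gradients whose global norm exceeds max_norm."""
    def init(params):
        return None  # this transformation is stateless
    def update(grads, state, params=None):
        norm = np.linalg.norm(grads)
        scale_factor = min(1.0, max_norm / (norm + 1e-12))
        return grads * scale_factor, state
    return init, update

def scale(step_size):
    """Transformation: multiply updates by -step_size (plain gradient descent)."""
    def init(params):
        return None
    def update(grads, state, params=None):
        return grads * (-step_size), state
    return init, update

def chain(*transforms):
    """Compose transformations left to right, threading each one's state."""
    def init(params):
        return [t_init(params) for t_init, _ in transforms]
    def update(grads, states, params=None):
        new_states = []
        for (_, t_update), s in zip(transforms, states):
            grads, s = t_update(grads, s, params)
            new_states.append(s)
        return grads, new_states
    return init, update

# Clip to norm 1, then apply learning rate 0.1.
opt_init, opt_update = chain(clip_by_norm(1.0), scale(0.1))
state = opt_init(None)
updates, state = opt_update(np.array([3.0, 4.0]), state)  # norm 5 -> clipped
```

Because each block only sees gradients and its own state, novel optimisers can be assembled by reordering or swapping blocks rather than rewriting a monolithic update rule.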

Equivariant MLP

Equivariant MLP (EMLP) is a practical approach for building equivariant multilayer perceptrons for arbitrary matrix groups. EMLP is a JAX library for automatically constructing equivariant layers in deep learning.

It offers the following features:

  • Compute equivariant linear layers between finite-dimensional representations. You specify the symmetry group (discrete, continuous, non-compact, complex) and the representations (tensors, irreducibles, induced representations, etc.).
  • Automatic construction of full equivariant models for small data: for instance, if your inputs and outputs (and intended features) are a small collection of elements such as scalars, vectors, tensors, and irreps with a total dimension of less than 1000, you will likely be able to use EMLP as a turnkey solution for creating the model, or at the very least as a solid baseline.
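What "equivariant layer" means can be checked by hand in a special case EMLP solves automatically. For the permutation group acting on vectors, the equivariant linear maps are known to be spanned by the identity and the averaging map; the sketch below (plain NumPy, not the EMLP library) builds one such layer and verifies that it commutes with a random permutation.

```python
import numpy as np

def perm_equivariant_layer(x, a=2.0, b=0.5):
    """A map of the form a*x + b*mean(x)*ones: the general form of a
    permutation-equivariant linear layer on vectors (up to a bias term)."""
    return a * x + b * x.mean() * np.ones_like(x)

rng = np.random.default_rng(0)
x = rng.normal(size=6)
P = np.eye(6)[rng.permutation(6)]  # a random permutation matrix

# Equivariance f(P x) = P f(x): acting before or after the layer agrees,
# because the mean is invariant under permutations.
lhs = perm_equivariant_layer(P @ x)
rhs = P @ perm_equivariant_layer(x)
```

EMLP's contribution is to solve this constraint f(g·x) = g·f(x) automatically for arbitrary matrix groups and representations, where the basis of equivariant maps is not known in closed form.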
