JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is based around pure functions and functional programming. PyTorch has popularised the use of an object-oriented (OO), class-based syntax for defining parameterised functions, such as neural networks. Because this seems like a fundamental difference, current libraries for building parameterised functions in JAX have either rejected the OO approach entirely (Stax) or have introduced OO-to-functional transformations and multiple new abstractions, and have been limited in the extent to which they integrate with JAX (Flax, Haiku, Objax). Either way, this OO/functional difference has been a source of tension. Here, we introduce Equinox, a small neural network library showing how a PyTorch-like class-based approach may be admitted without sacrificing JAX-like functional programming. We provide two main ideas. One: parameterised functions are themselves represented as PyTrees, which means that the parameterisation of a function is transparent to the JAX framework. Two: we filter a PyTree to isolate just those components that should be transformed when applying jit, grad or vmap to a higher-order function of a parameterised function, such as a loss function applied to a model. Overall, Equinox resolves the above tension without introducing any new programmatic abstractions: only PyTrees and transformations, just as with regular JAX. Equinox is available at [REDACTED].
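The two ideas can be illustrated in plain JAX. The sketch below is not Equinox's implementation or API: `Linear`, `is_array`, `partition`, `combine` and `filter_grad` are hypothetical names chosen for illustration, and it assumes a recent JAX release (for `jax.Array`). It registers a PyTorch-style class as a PyTree with `jax.tree_util.register_pytree_node_class`, so every field (including a non-differentiable activation function) is a leaf, and then differentiates a loss with respect to only the array leaves of the model.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class Linear:
    """Hypothetical class-based parameterised function that is also a PyTree.

    Every field, including the non-differentiable `activation`, is a leaf.
    """

    def __init__(self, weight, bias, activation):
        self.weight = weight
        self.bias = bias
        self.activation = activation

    def __call__(self, x):
        return self.activation(self.weight @ x + self.bias)

    # PyTree protocol: expose the fields as children so JAX can see them.
    def tree_flatten(self):
        return (self.weight, self.bias, self.activation), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def is_array(x):
    return isinstance(x, jax.Array)


def partition(tree):
    """Split a PyTree's leaves into array leaves and everything else.

    Each list holds None wherever the other list holds the leaf.
    """
    leaves, treedef = tree_flatten(tree)
    dynamic = [x if is_array(x) else None for x in leaves]
    static = [None if is_array(x) else x for x in leaves]
    return dynamic, static, treedef


def combine(dynamic, static, treedef):
    """Inverse of `partition`: merge the two leaf lists back into one PyTree."""
    leaves = [d if d is not None else s for d, s in zip(dynamic, static)]
    return tree_unflatten(treedef, leaves)


def filter_grad(loss_fn):
    """Like jax.grad, but only differentiates the array leaves of the first argument."""

    def wrapper(model, *args):
        dynamic, static, treedef = partition(model)

        def loss_of_arrays(dynamic):
            # Rebuild the full model inside the differentiated function; the
            # static leaves (e.g. the activation function) are closed over.
            return loss_fn(combine(dynamic, static, treedef), *args)

        grads = jax.grad(loss_of_arrays)(dynamic)
        # Put the gradients back into the model's structure; non-array fields are None.
        return tree_unflatten(treedef, grads)

    return wrapper


# Usage: a loss function applied to a model.
model = Linear(jnp.ones((2, 3)), jnp.zeros(2), jnp.tanh)
x = jnp.array([0.1, 0.2, 0.3])


def loss(model, x):
    return jnp.sum(model(x))


grads = filter_grad(loss)(model, x)
print(grads.weight.shape, grads.bias.shape, grads.activation)  # (2, 3) (2,) None
```

Calling `jax.grad(loss)(model, x)` directly fails with a `TypeError`, because the activation function is a leaf of the PyTree but is not a valid JAX type; filtering out the non-array leaves is what lets the transformation apply cleanly. This sketch only shows `jax.grad`; the same split-then-recombine pattern is one way to restrict `jax.jit` or `jax.vmap` to part of a model as well.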
Author Information
Patrick Kidger (University of Oxford)