I’ve only been reading through the docs for a few moments, but I’m pleasantly surprised to find that the authors are using effect handlers to handle effectful computations in ML models. I was in the process of translating a model from torch to JAX using Equinox, and this makes me think Penzai could be a better choice.
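For readers new to the idea, here is a toy sketch of the effect-handler pattern applied to a single layer. It assumes nothing about Penzai's actual API: `RandomRequest`, `Dropout`, and `handle_random` are hypothetical names. The layer only declares that it needs randomness, and a separate handler later substitutes a concrete JAX PRNG key, so the layer stays a pure function of its fields and inputs.

```python
from dataclasses import dataclass, replace

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class RandomRequest:
    """Hypothetical placeholder meaning 'this layer needs randomness'."""


@dataclass(frozen=True)
class Dropout:
    """Hypothetical dropout layer whose randomness is supplied by a handler."""
    rate: float
    rng: object  # a RandomRequest until a handler swaps in a concrete PRNG key

    def __call__(self, x):
        if isinstance(self.rng, RandomRequest):
            # Calling the layer before any handler ran is an "unhandled effect".
            raise RuntimeError("unhandled effect: no randomness handler installed")
        keep = jax.random.bernoulli(self.rng, 1.0 - self.rate, x.shape)
        # Inverted dropout: rescale the kept activations by 1 / (1 - rate).
        return jnp.where(keep, x / (1.0 - self.rate), 0.0)


def handle_random(layer, key):
    """Hypothetical handler: replace a pending RandomRequest with a concrete key."""
    if isinstance(layer.rng, RandomRequest):
        return replace(layer, rng=key)
    return layer  # already handled; leave it unchanged


layer = Dropout(rate=0.5, rng=RandomRequest())
handled = handle_random(layer, jax.random.PRNGKey(0))
y = handled(jnp.ones(8))  # each entry is either 0.0 or 2.0
```

The design choice in this sketch is that the layer never owns a key; whoever installs the handler decides where the randomness comes from.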
I remember PyTorch has some pytree capability, no? So is it safe to say that any pytree-compatible modules here are already compatible w/ PyTorch?
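On the JAX side, "pytree-compatible" means a class is registered with JAX's own `jax.tree_util` registry, so transformations such as `jax.grad` and `jax.tree_util.tree_map` can see its array leaves. The minimal sketch below uses a hypothetical `Linear` class (not a Penzai or Equinox module) registered via `jax.tree_util.register_pytree_node_class`. It only illustrates the JAX registration and makes no claim about whether PyTorch's pytree utilities would recognize the same class.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """Hypothetical minimal module whose parameters are its pytree leaves."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return x @ self.w + self.b

    def tree_flatten(self):
        # (children, aux_data): the arrays are leaves; there is no static data.
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = Linear(jnp.ones((3, 2)), jnp.zeros(2))
leaves = jax.tree_util.tree_leaves(layer)  # [w, b]

# Because Linear is a pytree, jax.grad returns gradients with the same structure.
grads = jax.grad(lambda m: m(jnp.ones(3)).sum())(layer)
```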
Does anyone know if and how well Penzai can work with Diffrax [1]? I currently use Diffrax + Equinox for scientific machine learning. Penzai looks like an interesting alternative to Equinox.
I have a small YT channel that teaches JAX bit-by-bit, check it out! https://www.youtube.com/@TwoMinuteJAX
Looks great, but outside Google I do not personally know anyone who uses Jax, and I work in this space.
I like JAX, and find most of the core functionality as an "accelerated NumPy" great. Ecosystem fragmentation and difficulties in interop make adopting JAX hard though.
There's too much fragmentation within the JAX NN library space, which penzai isn't helping with. I wish everyone using JAX could agree on a single set of libraries for NN, optimization, and data loading.
PyTorch code can't be called from JAX, meaning a lot of reimplementation in JAX is needed when extending and iterating on prior work, which is the case for most research. Custom CUDA kernels are a bit fiddly too; I haven't been able to bring Gaussian Splatting to JAX yet.