

Submitted by Style Pass, 2024-04-24 07:30:05

Jax Sourceror is a Python library that allows you to recreate JAX source code from a jitted JAX function (specifically, from its jaxpr) and a set of inputs. This is useful for minimizing bug reproductions, debugging, teaching, and understanding how JAX works under the hood.
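To make "its jaxpr" concrete: a jaxpr is the intermediate representation JAX records when it traces a function, and `jax.make_jaxpr` lets you print it. This sketch uses only public JAX calls; `func1` is an arbitrary example function, not something from this project.

```python
import jax
import jax.numpy as jnp


def func1(first, second):
    temp = first + jnp.sin(second) * 3.0
    return jnp.sum(temp)


# make_jaxpr traces func1 with abstract stand-ins for its inputs (much as
# jax.jit does) and returns a ClosedJaxpr: the representation a
# source-regenerating tool works from.
closed = jax.make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8))
print(closed)

# Each equation applies one primitive to earlier variables or literals.
primitive_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(primitive_names)
```

On recent JAX versions the printed jaxpr should be a short sequence of `sin`, `mul`, `add`, and `reduce_sum` equations, though the exact text varies between releases.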

The code this generates is definitely not going to be clean or idiomatic, and sometimes it won't even be correct, but it should be a good starting point for understanding what's going on.
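Why the output tends to be unidiomatic is easier to see with a toy version. The following `toy_sourcerize` is hypothetical, not Jax Sourceror's actual API or implementation: it traces a function with `jax.make_jaxpr`, gives every jaxpr variable a throwaway name (`v0`, `v1`, ...), and emits one `jnp` call per equation. It only knows a handful of primitives and assumes the function closes over no constants.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy mapping from jaxpr primitive names to jnp functions.
# A real tool has to cover every primitive and its params; this covers a
# few elementwise ops plus reduce_sum.
_ELEMENTWISE = {
    "sin": "jnp.sin",
    "cos": "jnp.cos",
    "exp": "jnp.exp",
    "tanh": "jnp.tanh",
    "neg": "jnp.negative",
    "add": "jnp.add",
    "sub": "jnp.subtract",
    "mul": "jnp.multiply",
}


def toy_sourcerize(fn, *args, name="generated"):
    """Trace fn to a jaxpr and re-emit it as (unidiomatic) Python source."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    if jaxpr.constvars:
        raise NotImplementedError("closed-over constants are not handled")

    names = {}  # jaxpr variable -> generated Python identifier

    def fresh(var):
        names[var] = f"v{len(names)}"
        return names[var]

    def ref(atom):
        # Literals (such as the 3.0 in `x * 3.0`) carry their value inline;
        # everything else is a variable that was named earlier.
        if hasattr(atom, "val"):
            val = atom.val
            return repr(val.item() if hasattr(val, "item") else val)
        return names[atom]

    params = ", ".join(fresh(v) for v in jaxpr.invars)
    lines = [f"def {name}({params}):"]
    for eqn in jaxpr.eqns:
        prim = eqn.primitive.name
        ins = ", ".join(ref(v) for v in eqn.invars)
        if prim in _ELEMENTWISE:
            call = f"{_ELEMENTWISE[prim]}({ins})"
        elif prim == "reduce_sum":
            call = f"jnp.sum({ins}, axis={tuple(eqn.params['axes'])!r})"
        else:
            raise NotImplementedError(f"unsupported primitive: {prim}")
        (out,) = eqn.outvars  # each supported op has exactly one output
        lines.append(f"    {fresh(out)} = {call}")
    # make_jaxpr flattens outputs; a single output is returned bare.
    outs = [ref(v) for v in jaxpr.outvars]
    ret = outs[0] if len(outs) == 1 else "(" + ", ".join(outs) + ",)"
    lines.append(f"    return {ret}")
    return "\n".join(lines)


def f(first, second):
    temp = first + jnp.sin(second) * 3.0
    return jnp.sum(temp)


src = toy_sourcerize(f, jnp.zeros(8), jnp.ones(8))
print(src)
```

Every intermediate gets its own line and a meaningless name, and higher-level structure (the original `temp` variable, any helper functions) is gone, which gives a feel for why regenerated code reads the way it does. Supporting the full primitive set, with all of its params, is where the real work lies.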

I created it mostly as a learning exercise and to minimize bug reproductions from framework-heavy code (i.e., stripping away layers of Equinox or Flax abstraction to get at the underlying JAX code).
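A rough illustration of why going through the jaxpr strips framework layers, assuming hypothetical `Linear` and `MLP` classes as stand-ins for Equinox or Flax modules (neither library is used here): the Python objects only exist while tracing, so the jaxpr contains nothing but array primitives.

```python
import jax
import jax.numpy as jnp


class Linear:
    """Hypothetical stand-in for a framework layer (not Equinox or Flax)."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return x @ self.w + self.b


class MLP:
    """Hypothetical stand-in for a framework model that nests layers."""

    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jnp.tanh(layer(x))
        return self.layers[-1](x)


def apply(params, x):
    # Build the object graph inside the traced function so that only the
    # arrays (params and x) become jaxpr inputs.
    model = MLP([Linear(w, b) for w, b in params])
    return model(x)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = [
    (jax.random.normal(k1, (3, 5)), jnp.zeros(5)),
    (jax.random.normal(k2, (5, 2)), jnp.zeros(2)),
]
x = jnp.ones((4, 3))

closed = jax.make_jaxpr(apply)(params, x)
print(closed)  # no trace of Linear or MLP, only array primitives
prims = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
```

The printed jaxpr should show `dot_general`, `add`, and `tanh` (JAX may also insert broadcasting ops for the bias), with the class hierarchy flattened away entirely.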

This is more of a "submit a PR" or "fork it" repo than a "this doesn't work for me" repo, but I'm happy to help out if you're stuck.

Is this pretty code? No. Is it even readable? If you try hard enough. Is it correct? I think so. (It definitely passes my unit test!)
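One plausible shape for a round-trip unit test, assuming the generated code is available as a source string: exec it, call it on the same inputs, and compare the results with the original function. `check_roundtrip` and the `GENERATED` string are hypothetical, hand-written for illustration rather than taken from the project's tests.

```python
import jax
import jax.numpy as jnp


def check_roundtrip(fn, source, fn_name, *args, atol=1e-6):
    """Exec generated source and compare its outputs with fn's (hypothetical)."""
    namespace = {"jax": jax, "jnp": jnp}
    exec(source, namespace)
    regenerated = namespace[fn_name]
    # Compare leaf by leaf so pytree outputs (tuples, dicts) also work.
    expected = jax.tree_util.tree_leaves(fn(*args))
    actual = jax.tree_util.tree_leaves(regenerated(*args))
    return len(expected) == len(actual) and all(
        bool(jnp.allclose(e, a, atol=atol)) for e, a in zip(expected, actual)
    )


def f(first, second):
    temp = first + jnp.sin(second) * 3.0
    return jnp.sum(temp)


# Hand-written stand-in for what a source generator might emit for f.
GENERATED = """
def generated(a, b):
    c = jax.lax.sin(b)
    d = jnp.multiply(c, 3.0)
    e = jnp.add(a, d)
    return jnp.sum(e, axis=(0,))
"""

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (8,))
y = jax.random.normal(key2, (8,))
print(check_roundtrip(f, GENERATED, "generated", x, y))
```

Random inputs are used so that a generated function that happens to agree only on trivial inputs (such as all zeros) is less likely to slip through.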
