NumPyro is a lightweight probabilistic programming library that provides a NumPy backend for Pyro. We rely on JAX for automatic differentiation and JIT compilation to GPU / CPU. NumPyro is under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.
Let us explore NumPyro using a simple example. We will use the eight schools example from Gelman et al., Bayesian Data Analysis: Sec. 5.5, 2003, which studies the effect of coaching on SAT performance in eight schools.
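The data for this study can be written down as plain arrays. The values below are the eight schools estimates as reported in Bayesian Data Analysis; the variable names `J`, `y`, and `sigma` are chosen here to match the description that follows.

```python
import numpy as np

# Number of schools in the study.
J = 8
# Estimated treatment effect of coaching for each school.
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
# Standard error of each estimated treatment effect.
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
```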
In this dataset, y denotes the estimated treatment effects and sigma their standard errors. We build a hierarchical model for the study in which the group-level parameters theta for each school are sampled from a Normal distribution with unknown mean mu and standard deviation tau, while the observed data are in turn generated from a Normal distribution with mean theta (the true effect) and standard deviation sigma. This lets us estimate the population-level parameters mu and tau by pooling information from all the observations, while still allowing for individual variation among the schools through the group-level theta parameters.
Let us infer the values of the unknown parameters in our model by running MCMC using the No-U-Turn Sampler (NUTS). Note the usage of the extra_fields argument in MCMC.run. By default, we only collect samples from the target (posterior) distribution when we run inference using MCMC. However, collecting additional fields like potential energy or the acceptance probability of a sample can be easily achieved by using the extra_fields argument. For a list of possible fields that can be collected, see the HMCState object. In this example, we will additionally collect the potential_energy for each sample.