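jax==0.4.16
numpy==1.25.2
numpyro==0.11.0
# numpy==1.21.5 was dropped: it conflicts with the numpy==1.25.2 pin above
# (pip cannot satisfy both), and jax 0.4.16 requires NumPy >= 1.22.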