In [1]:
import numpy as np

In [2]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.scipy.stats as jspstats

In [3]:
import numpyro
import numpyro.distributions as dists
from numpyro.infer import MCMC, NUTS

In [4]:
def model(samples=None, shape=(60, 60)):
    sigma_s = numpyro.sample('sigma_s', dists.Normal(0., 1_000.))
    sigma_b = numpyro.sample('sigma_b', dists.Normal(0., 1_000.))
    mu_tot = numpyro.sample('mu_tot', dists.Normal(0., 50.))
    
    draws_s = numpyro.sample('draws_s', dists.Normal(0., sigma_s), sample_shape=shape[:-1])
    draws_b = numpyro.sample('draws_b', dists.Normal(0., sigma_b), sample_shape=shape[-1:])
    draws_bias = mu_tot + draws_s[:, np.newaxis] + draws_b[np.newaxis, :]
    numpyro.sample('samples', dists.Bernoulli(jspstats.norm.cdf(draws_bias)), obs=samples)

In [5]:
def generate_grid(sigma_s, sigma_b, mu_tot, grid_size=[60, 60]):
    grid = (
        mu_tot
      + (np.random.normal(size=grid_size[0]) * sigma_s)[:, np.newaxis]
      + (np.random.normal(size=grid_size[1]) * sigma_b)[np.newaxis, :]
      + np.random.normal(size=grid_size)
    )
    grid = 1 * (grid > 0)
    return grid

grid = generate_grid(sigma_s=2.4, sigma_b=0.6, mu_tot=-1.5)

In [6]:
rng_key = jrandom.PRNGKey(0)
rng_key, rng_key_ = jrandom.split(rng_key)

kernel = NUTS(model)
num_samples = 2000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
    rng_key_, samples=grid
)
mcmc.print_summary()
samples_1 = mcmc.get_samples()

I0000 00:00:1700516044.602393 1486344 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-11-20 22:34:04.630229: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 12620922880
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
sample: 100%|â–ˆ| 3000/3000 [00:19<00:00, 153.28it/s, 31 steps of size 1.17e-01. a



                 mean       std    median      5.0%     95.0%     n_eff     r_hat
 draws_b[0]      0.41      0.24      0.41      0.02      0.80   2476.09      1.00
 draws_b[1]     -0.03      0.23     -0.04     -0.43      0.32   1978.01      1.00
 draws_b[2]     -0.71      0.28     -0.69     -1.15     -0.26   2136.96      1.00
 draws_b[3]      0.20      0.24      0.20     -0.23      0.56   2107.37      1.00
 draws_b[4]      0.22      0.24      0.22     -0.17      0.59   1564.28      1.00
 draws_b[5]     -0.60      0.27     -0.59     -1.09     -0.21   1970.86      1.00
 draws_b[6]      0.66      0.23      0.66      0.28      1.03   2108.18      1.00
 draws_b[7]     -0.70      0.27     -0.69     -1.14     -0.27   2018.21      1.00
 draws_b[8]     -0.80      0.28     -0.80     -1.25     -0.36   2037.66      1.00
 draws_b[9]     -0.08      0.25     -0.08     -0.51      0.30   1846.97      1.00
draws_b[10]      0.08      0.25      0.09     -0.36      0.46   2327.05      1.00
draws_b[11]    