In [1]:
import numpy as np

In [2]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.scipy.stats as jspstats

In [3]:
import numpyro
import numpyro.distributions as dists
from numpyro.infer import MCMC, NUTS

In [4]:
def model(samples=None, shape=(60, 60)):
    """Hierarchical probit model for a 2-D grid of binary observations.

    Cell (i, j) is Bernoulli with success probability
    Phi(mu_tot + draws_s[i] + draws_b[j]), where Phi is the standard normal
    CDF, ``draws_s`` holds one effect per row and ``draws_b`` one per column.

    Args:
        samples: Optional 2-D array of observed 0/1 outcomes. When given, its
            shape is used for the latent effects (overriding ``shape``), so the
            row/column effects always line up with the data.
        shape: ``(n_rows, n_cols)`` of the grid; only used when ``samples`` is
            None (e.g. prior-predictive sampling).

    Raises:
        ValueError: if the grid shape is not 2-D.
    """
    if samples is not None:
        # Derive the latent dimensions from the data so a grid that differs
        # from the default shape cannot silently mis-broadcast.
        shape = tuple(jnp.shape(samples))
    if len(shape) != 2:
        raise ValueError(f"model expects a 2-D grid, got shape {shape}")
    n_rows, n_cols = shape

    # Scale parameters must be positive: a Normal prior would also propose
    # negative scales, which give NaN log-densities for the effects below.
    sigma_s = numpyro.sample('sigma_s', dists.HalfNormal(1_000.))
    sigma_b = numpyro.sample('sigma_b', dists.HalfNormal(1_000.))
    mu_tot = numpyro.sample('mu_tot', dists.Normal(0., 50.))

    draws_s = numpyro.sample('draws_s', dists.Normal(0., sigma_s), sample_shape=(n_rows,))
    draws_b = numpyro.sample('draws_b', dists.Normal(0., sigma_b), sample_shape=(n_cols,))
    # Broadcast row effects down columns and column effects across rows.
    draws_bias = mu_tot + draws_s[:, np.newaxis] + draws_b[np.newaxis, :]

    # Probit link written as logits: logit(Phi(x)) = log Phi(x) - log Phi(-x).
    # Working with log-CDFs avoids Phi(x) rounding to exactly 0 or 1 for large
    # |x|, which would make the observed-data log-likelihood -inf.
    probit_logits = jspstats.norm.logcdf(draws_bias) - jspstats.norm.logcdf(-draws_bias)
    numpyro.sample('samples', dists.Bernoulli(logits=probit_logits), obs=samples)

In [5]:
# Ground-truth parameters for the synthetic data; the MCMC fit should recover
# them. A fixed seed makes the data (and therefore the posterior) reproducible.
DATA_SEED = 0
sigma_s = 2.5   # std of the per-row effects
sigma_b = 0.6   # std of the per-column effects
mu_tot = -1.    # global offset on the latent (probit) scale

grid_size = (60, 60)
data_rng = np.random.default_rng(DATA_SEED)
row_effects = data_rng.normal(size=grid_size[0]) * sigma_s
col_effects = data_rng.normal(size=grid_size[1]) * sigma_b
# Latent Gaussian with unit-variance noise: P(grid > 0) = Phi(mu_tot + row + col),
# so thresholding at 0 yields data from a probit model with row/column effects.
grid = (
    mu_tot
    + row_effects[:, np.newaxis]
    + col_effects[np.newaxis, :]
    + data_rng.normal(size=grid_size)
)
thresholded_grid = (grid > 0).astype(float)

In [6]:
# Fit the probit model to the synthetic grid with NUTS.
NUM_WARMUP = 1000
num_samples = 2000

rng_key = jrandom.PRNGKey(0)
rng_key, rng_key_ = jrandom.split(rng_key)

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=NUM_WARMUP, num_samples=num_samples)
# Pass the grid's shape explicitly so the latent row/column effects always
# match the observed data instead of relying on the model's default shape.
mcmc.run(rng_key_, samples=thresholded_grid, shape=tuple(thresholded_grid.shape))
mcmc.print_summary()
samples_1 = mcmc.get_samples()

I0000 00:00:1700508193.049216 1423603 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-11-20 20:23:13.233156: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 12620922880
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
sample: 100%|█| 3000/3000 [00:24<00:00, 124.55it/s, 31 steps of size 9.43e-02. a



                 mean       std    median      5.0%     95.0%     n_eff     r_hat
 draws_b[0]     -0.12      0.25     -0.12     -0.55      0.26   1518.82      1.00
 draws_b[1]      0.01      0.23      0.01     -0.32      0.40   1139.98      1.00
 draws_b[2]      0.06      0.24      0.07     -0.32      0.47   1207.20      1.00
 draws_b[3]      0.67      0.24      0.67      0.26      1.05   1372.13      1.00
 draws_b[4]     -0.12      0.25     -0.13     -0.55      0.27    958.14      1.00
 draws_b[5]     -0.85      0.27     -0.84     -1.32     -0.45   1090.23      1.00
 draws_b[6]      0.08      0.24      0.08     -0.34      0.44   1572.24      1.00
 draws_b[7]      1.06      0.24      1.05      0.65      1.45   1080.15      1.00
 draws_b[8]     -0.39      0.25     -0.39     -0.80      0.02   1451.30      1.00
 draws_b[9]      0.45      0.25      0.45      0.04      0.86   1228.30      1.00
draws_b[10]     -0.19      0.26     -0.19     -0.61      0.23   1105.58      1.00
draws_b[11]    