You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
257 lines
15 KiB
Plaintext
257 lines
15 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "33c24830-7601-4c52-8fa0-66330801d6a9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "985106e5-7a80-4d4d-b1c9-54a873f0b429",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import jax.numpy as jnp\n",
|
|
"import jax.random as jrandom\n",
|
|
"import jax.scipy.stats as jspstats"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "8a8cc9de-9802-4761-9bba-7b77300d0bb8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpyro\n",
|
|
"import numpyro.distributions as dists\n",
|
|
"from numpyro.infer import MCMC, NUTS"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "9b4c0d3e-2889-4bf3-9871-6648b5de3508",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def model(samples=None, shape=None):\n",
"    \"\"\"Hierarchical probit model for a 2-D grid of binary outcomes.\n",
"\n",
"    Cell (i, j) is Bernoulli with success probability\n",
"    Phi(mu_tot + draws_s[i] + draws_b[j]), Phi being the standard normal CDF.\n",
"\n",
"    Args:\n",
"        samples: observed 0/1 grid, or None to draw from the prior predictive.\n",
"        shape: grid shape; draws_s gets shape[:-1], draws_b gets shape[-1:].\n",
"            Defaults to samples.shape when samples is given, else (60, 60).\n",
"    \"\"\"\n",
"    if shape is None:\n",
"        shape = (60, 60) if samples is None else samples.shape\n",
"    # Scale parameters must be positive. A Normal prior lets NUTS propose\n",
"    # negative scales (NaN log-density); HalfNormal(1000) is the same vague\n",
"    # prior restricted to the positive half-line.\n",
"    sigma_s = numpyro.sample('sigma_s', dists.HalfNormal(1_000.))\n",
"    sigma_b = numpyro.sample('sigma_b', dists.HalfNormal(1_000.))\n",
"    mu_tot = numpyro.sample('mu_tot', dists.Normal(0., 50.))\n",
"\n",
"    # One effect per row (draws_s) and one per column (draws_b).\n",
"    draws_s = numpyro.sample('draws_s', dists.Normal(0., sigma_s), sample_shape=shape[:-1])\n",
"    draws_b = numpyro.sample('draws_b', dists.Normal(0., sigma_b), sample_shape=shape[-1:])\n",
"    draws_bias = mu_tot + draws_s[:, np.newaxis] + draws_b[np.newaxis, :]\n",
"    # logit(Phi(x)) = log Phi(x) - log Phi(-x). Staying in log space avoids\n",
"    # Phi(x) rounding to exactly 0 or 1 (and flat gradients) for large |x|.\n",
"    logits = jspstats.norm.logcdf(draws_bias) - jspstats.norm.logcdf(-draws_bias)\n",
"    numpyro.sample('samples', dists.Bernoulli(logits=logits), obs=samples)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "caaf0f54-205e-4d13-84b6-464c1583d432",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Ground-truth parameters of the simulated data (the fit should recover them).\n",
"s_s = 2.5   # std of the per-row effects\n",
"s_b = 0.6   # std of the per-column effects\n",
"m_t = -1.   # global offset added before thresholding\n",
"\n",
"grid_size = [60, 60]\n",
"# Seeded generator so the simulated data (and hence the fit) is reproducible.\n",
"data_rng = np.random.default_rng(0)\n",
"row_effects = data_rng.normal(size=grid_size[0]) * s_s\n",
"col_effects = data_rng.normal(size=grid_size[1]) * s_b\n",
"# Unit-variance latent noise corresponds to a probit link (standard normal CDF).\n",
"noise = data_rng.normal(size=grid_size)\n",
"grid = row_effects[:, np.newaxis] + col_effects[np.newaxis, :] + noise\n",
"thresholded_grid = (grid + m_t > 0).astype(float)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "06b670cf-73bb-40fd-b152-b111b13b9d50",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
|
|
"I0000 00:00:1700507813.984381 1419943 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n",
|
|
"sample: 100%|█| 3000/3000 [00:42<00:00, 70.76it/s, 31 steps of size 1.05e-01. ac\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
" mean std median 5.0% 95.0% n_eff r_hat\n",
|
|
" draws_b[0] 0.54 0.27 0.54 0.07 0.95 1726.16 1.00\n",
|
|
" draws_b[1] 0.32 0.25 0.31 -0.09 0.71 1369.21 1.00\n",
|
|
" draws_b[2] 0.68 0.26 0.69 0.25 1.09 1518.55 1.00\n",
|
|
" draws_b[3] -0.22 0.27 -0.22 -0.65 0.22 1533.48 1.00\n",
|
|
" draws_b[4] 0.23 0.26 0.23 -0.20 0.65 1078.14 1.00\n",
|
|
" draws_b[5] 0.10 0.27 0.10 -0.35 0.52 1411.39 1.00\n",
|
|
" draws_b[6] -1.09 0.31 -1.09 -1.59 -0.58 1922.28 1.00\n",
|
|
" draws_b[7] 0.58 0.25 0.58 0.19 1.00 1254.93 1.00\n",
|
|
" draws_b[8] -0.22 0.27 -0.22 -0.67 0.19 1412.75 1.00\n",
|
|
" draws_b[9] 0.55 0.25 0.55 0.13 0.96 1191.63 1.00\n",
|
|
"draws_b[10] 0.26 0.27 0.26 -0.18 0.68 1384.75 1.00\n",
|
|
"draws_b[11] 0.34 0.25 0.34 -0.06 0.75 1880.85 1.00\n",
|
|
"draws_b[12] 0.02 0.26 0.02 -0.42 0.43 1484.88 1.00\n",
|
|
"draws_b[13] 0.10 0.28 0.10 -0.34 0.55 1793.79 1.00\n",
|
|
"draws_b[14] 0.00 0.27 0.00 -0.43 0.46 1575.12 1.00\n",
|
|
"draws_b[15] 0.10 0.25 0.11 -0.28 0.54 1463.69 1.00\n",
|
|
"draws_b[16] 0.37 0.25 0.37 -0.03 0.78 1831.50 1.00\n",
|
|
"draws_b[17] 0.79 0.25 0.79 0.40 1.22 1244.41 1.00\n",
|
|
"draws_b[18] -0.19 0.27 -0.18 -0.62 0.27 1036.48 1.00\n",
|
|
"draws_b[19] 0.49 0.25 0.49 0.11 0.95 1760.80 1.00\n",
|
|
"draws_b[20] -0.79 0.29 -0.78 -1.26 -0.33 1345.88 1.00\n",
|
|
"draws_b[21] 0.01 0.26 0.01 -0.41 0.45 1503.33 1.00\n",
|
|
"draws_b[22] 0.79 0.25 0.79 0.40 1.20 1519.65 1.00\n",
|
|
"draws_b[23] -1.59 0.35 -1.58 -2.20 -1.05 1248.30 1.00\n",
|
|
"draws_b[24] 1.19 0.24 1.20 0.78 1.58 1363.44 1.00\n",
|
|
"draws_b[25] 0.06 0.26 0.05 -0.34 0.51 1344.98 1.00\n",
|
|
"draws_b[26] 0.11 0.26 0.11 -0.31 0.53 1329.94 1.00\n",
|
|
"draws_b[27] 0.13 0.25 0.13 -0.30 0.52 1586.63 1.00\n",
|
|
"draws_b[28] -0.70 0.29 -0.69 -1.15 -0.21 1783.14 1.00\n",
|
|
"draws_b[29] 1.32 0.23 1.33 0.94 1.71 1445.95 1.00\n",
|
|
"draws_b[30] 0.04 0.26 0.04 -0.38 0.46 1281.60 1.00\n",
|
|
"draws_b[31] -0.46 0.28 -0.46 -0.91 -0.01 1579.55 1.00\n",
|
|
"draws_b[32] -1.09 0.31 -1.08 -1.57 -0.57 1433.98 1.00\n",
|
|
"draws_b[33] -1.19 0.29 -1.18 -1.67 -0.74 1349.72 1.00\n",
|
|
"draws_b[34] 0.57 0.25 0.57 0.19 1.01 1176.79 1.00\n",
|
|
"draws_b[35] 0.45 0.25 0.45 0.06 0.89 1484.68 1.00\n",
|
|
"draws_b[36] -0.20 0.26 -0.20 -0.65 0.20 1629.53 1.00\n",
|
|
"draws_b[37] 0.34 0.26 0.35 -0.13 0.74 1505.92 1.00\n",
|
|
"draws_b[38] 0.15 0.25 0.15 -0.23 0.59 1443.92 1.00\n",
|
|
"draws_b[39] 0.34 0.26 0.34 -0.07 0.79 1775.66 1.00\n",
|
|
"draws_b[40] 0.23 0.26 0.24 -0.17 0.68 1457.14 1.00\n",
|
|
"draws_b[41] -0.58 0.27 -0.57 -1.01 -0.15 1313.66 1.00\n",
|
|
"draws_b[42] -0.49 0.27 -0.48 -0.94 -0.04 1490.91 1.00\n",
|
|
"draws_b[43] -0.57 0.27 -0.56 -1.01 -0.11 1483.20 1.00\n",
|
|
"draws_b[44] 0.23 0.25 0.23 -0.15 0.66 981.31 1.00\n",
|
|
"draws_b[45] -0.31 0.26 -0.31 -0.70 0.15 1435.06 1.00\n",
|
|
"draws_b[46] -0.33 0.26 -0.33 -0.77 0.10 1224.35 1.00\n",
|
|
"draws_b[47] -0.62 0.27 -0.61 -1.05 -0.19 1631.06 1.00\n",
|
|
"draws_b[48] -0.48 0.27 -0.48 -0.92 -0.05 1106.08 1.00\n",
|
|
"draws_b[49] 1.16 0.26 1.16 0.74 1.58 1318.62 1.00\n",
|
|
"draws_b[50] -0.33 0.27 -0.33 -0.77 0.10 1336.15 1.00\n",
|
|
"draws_b[51] 0.35 0.25 0.34 -0.07 0.78 1383.73 1.00\n",
|
|
"draws_b[52] -0.32 0.26 -0.32 -0.76 0.09 1547.82 1.00\n",
|
|
"draws_b[53] 0.43 0.25 0.43 0.02 0.81 1491.33 1.00\n",
|
|
"draws_b[54] 0.10 0.26 0.10 -0.34 0.52 1530.15 1.00\n",
|
|
"draws_b[55] -0.58 0.28 -0.58 -1.07 -0.16 1241.15 1.00\n",
|
|
"draws_b[56] -0.48 0.28 -0.47 -0.95 -0.02 2004.81 1.00\n",
|
|
"draws_b[57] -0.45 0.28 -0.45 -0.91 -0.00 1592.65 1.00\n",
|
|
"draws_b[58] -0.22 0.29 -0.23 -0.68 0.24 1622.34 1.00\n",
|
|
"draws_b[59] -0.37 0.26 -0.37 -0.82 0.03 1411.27 1.00\n",
|
|
" draws_s[0] 1.59 0.41 1.58 0.99 2.31 91.66 1.01\n",
|
|
" draws_s[1] 2.00 0.42 1.97 1.32 2.67 92.51 1.01\n",
|
|
" draws_s[2] -0.81 0.47 -0.82 -1.60 -0.08 117.14 1.01\n",
|
|
" draws_s[3] 1.44 0.41 1.43 0.79 2.09 93.06 1.01\n",
|
|
" draws_s[4] 1.30 0.42 1.28 0.69 2.03 86.22 1.01\n",
|
|
" draws_s[5] -3.25 1.59 -2.94 -5.69 -1.06 801.05 1.00\n",
|
|
" draws_s[6] -0.52 0.45 -0.54 -1.30 0.16 108.95 1.01\n",
|
|
" draws_s[7] 1.96 0.42 1.93 1.29 2.62 94.58 1.01\n",
|
|
" draws_s[8] 2.37 0.43 2.35 1.69 3.07 98.58 1.01\n",
|
|
" draws_s[9] -3.24 1.55 -2.93 -5.45 -0.80 899.89 1.00\n",
|
|
"draws_s[10] 5.06 1.33 4.78 3.02 7.12 436.17 1.00\n",
|
|
"draws_s[11] -0.62 0.46 -0.66 -1.36 0.10 111.43 1.01\n",
|
|
"draws_s[12] -0.62 0.47 -0.65 -1.35 0.17 120.13 1.01\n",
|
|
"draws_s[13] -3.21 1.47 -2.95 -5.44 -0.99 1044.95 1.00\n",
|
|
"draws_s[14] -3.16 1.49 -2.89 -5.45 -0.95 954.54 1.00\n",
|
|
"draws_s[15] 5.03 1.29 4.83 3.12 7.02 603.50 1.00\n",
|
|
"draws_s[16] -0.83 0.48 -0.83 -1.62 -0.08 126.28 1.01\n",
|
|
"draws_s[17] -0.27 0.44 -0.28 -1.04 0.37 101.52 1.01\n",
|
|
"draws_s[18] -0.75 0.48 -0.77 -1.50 0.03 117.77 1.01\n",
|
|
"draws_s[19] -3.20 1.49 -2.97 -5.40 -0.87 866.56 1.00\n",
|
|
"draws_s[20] 1.54 0.41 1.52 0.90 2.21 90.24 1.01\n",
|
|
"draws_s[21] 0.01 0.43 -0.01 -0.73 0.65 96.74 1.01\n",
|
|
"draws_s[22] -1.05 0.51 -1.05 -1.86 -0.19 140.02 1.00\n",
|
|
"draws_s[23] -3.26 1.60 -2.94 -5.74 -0.85 956.12 1.00\n",
|
|
"draws_s[24] 0.65 0.41 0.63 0.02 1.36 92.65 1.01\n",
|
|
"draws_s[25] 5.03 1.29 4.79 3.24 7.05 699.49 1.00\n",
|
|
"draws_s[26] -0.10 0.43 -0.12 -0.75 0.62 97.31 1.01\n",
|
|
"draws_s[27] -0.64 0.46 -0.65 -1.39 0.10 116.88 1.01\n",
|
|
"draws_s[28] 0.40 0.42 0.38 -0.31 1.05 89.63 1.01\n",
|
|
"draws_s[29] 3.80 0.62 3.76 2.83 4.79 201.10 1.00\n",
|
|
"draws_s[30] 1.73 0.41 1.72 1.04 2.37 90.90 1.01\n",
|
|
"draws_s[31] -3.22 1.52 -2.91 -5.73 -1.11 1008.01 1.00\n",
|
|
"draws_s[32] -0.80 0.49 -0.80 -1.61 -0.05 125.12 1.01\n",
|
|
"draws_s[33] 1.91 0.42 1.89 1.23 2.57 97.31 1.01\n",
|
|
"draws_s[34] 1.09 0.41 1.06 0.41 1.74 90.62 1.01\n",
|
|
"draws_s[35] 1.04 0.41 1.01 0.35 1.69 90.74 1.01\n",
|
|
"draws_s[36] -1.53 0.61 -1.52 -2.50 -0.47 220.52 1.00\n",
|
|
"draws_s[37] -0.86 0.50 -0.87 -1.64 -0.06 124.72 1.01\n",
|
|
"draws_s[38] -1.41 0.58 -1.40 -2.40 -0.50 165.96 1.01\n",
|
|
"draws_s[39] -1.00 0.50 -1.00 -1.86 -0.26 134.26 1.01\n",
|
|
"draws_s[40] 0.43 0.42 0.40 -0.32 1.03 91.86 1.01\n",
|
|
"draws_s[41] -3.17 1.51 -2.88 -5.51 -0.94 802.41 1.00\n",
|
|
"draws_s[42] -3.17 1.51 -2.90 -5.58 -1.03 703.27 1.01\n",
|
|
"draws_s[43] -3.23 1.51 -2.94 -5.49 -1.06 1047.16 1.00\n",
|
|
"draws_s[44] -0.69 0.47 -0.70 -1.49 0.04 111.69 1.01\n",
|
|
"draws_s[45] -3.24 1.51 -2.96 -5.43 -0.87 809.34 1.00\n",
|
|
"draws_s[46] -1.14 0.52 -1.14 -1.96 -0.27 133.51 1.01\n",
|
|
"draws_s[47] 1.50 0.41 1.48 0.86 2.15 88.51 1.01\n",
|
|
"draws_s[48] -3.26 1.54 -2.99 -5.62 -0.90 776.73 1.00\n",
|
|
"draws_s[49] 2.81 0.45 2.79 2.06 3.48 110.23 1.01\n",
|
|
"draws_s[50] -1.50 0.60 -1.49 -2.48 -0.50 198.81 1.00\n",
|
|
"draws_s[51] 3.75 0.62 3.70 2.78 4.75 191.96 1.00\n",
|
|
"draws_s[52] 1.32 0.41 1.30 0.63 1.96 89.34 1.01\n",
|
|
"draws_s[53] -0.64 0.47 -0.64 -1.42 0.06 117.17 1.01\n",
|
|
"draws_s[54] -3.25 1.57 -2.90 -5.81 -1.11 894.61 1.00\n",
|
|
"draws_s[55] 5.01 1.26 4.76 3.33 6.99 449.27 1.00\n",
|
|
"draws_s[56] 5.05 1.32 4.82 3.05 6.98 642.23 1.00\n",
|
|
"draws_s[57] 4.99 1.27 4.72 3.12 7.02 663.59 1.00\n",
|
|
"draws_s[58] -3.14 1.44 -2.88 -5.37 -1.10 808.01 1.00\n",
|
|
"draws_s[59] 1.50 0.42 1.48 0.84 2.18 89.42 1.01\n",
|
|
" mu_tot -1.15 0.38 -1.13 -1.74 -0.53 76.53 1.01\n",
|
|
" sigma_b 0.65 0.07 0.64 0.53 0.77 935.24 1.00\n",
|
|
" sigma_s 2.75 0.37 2.70 2.18 3.36 425.85 1.00\n",
|
|
"\n",
|
|
"Number of divergences: 0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Fixed PRNG key so the MCMC run is reproducible.\n",
"rng_key = jrandom.PRNGKey(0)\n",
"rng_key, rng_key_ = jrandom.split(rng_key)\n",
"\n",
"num_warmup = 1000\n",
"num_samples = 2000\n",
"kernel = NUTS(model)\n",
"mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)\n",
"# Pass the grid shape explicitly so the latent row/column effects always\n",
"# match the observed data instead of relying on model's default shape.\n",
"mcmc.run(rng_key_, samples=thresholded_grid, shape=thresholded_grid.shape)\n",
"mcmc.print_summary()\n",
"samples_1 = mcmc.get_samples()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|