# Provenance: commit 4d6583e270e3483dd78bf905a51172280e747f5e ("Initial commit"),
# author imperator, Mon Nov 20 21:56:00 2023 +0100.
# That commit added two files:
#   requirements.txt                   -> numpy==1.26.2, scipy==1.11.4
#   sample_statistics_inference.ipynb  -> the code cells reproduced below
#                                         (Python 3 ipykernel, Python 3.10.12)
import numpy as np
import scipy as sp
import scipy.stats as spstats
import scipy.optimize as spopt


def generate_grid(sigma_s, sigma_b, mu_tot, grid_size=(60, 60), rng=None):
    """Draw a random 0/1 grid from a thresholded additive Gaussian model.

    The latent value of cell (i, j) is
    ``mu_tot + sigma_s * s_i + sigma_b * b_j + e_ij`` where ``s_i``, ``b_j``
    and ``e_ij`` are independent standard normal draws; the cell is 1 when
    the latent value is strictly positive and 0 otherwise.

    Args:
        sigma_s: Scale of the per-row (axis 0) effect.
        sigma_b: Scale of the per-column (axis 1) effect.
        mu_tot: Constant offset added to every cell.
        grid_size: Two-element sequence ``(rows, columns)``.
        rng: Optional object exposing ``normal(size=...)`` (for example a
            ``numpy.random.Generator``). When omitted, the global
            ``numpy.random`` stream is used, as before.

    Returns:
        Integer array of shape ``grid_size`` containing only 0 and 1.
    """
    source = np.random if rng is None else rng
    # Draw order (row effects, column effects, cell noise) is kept fixed so a
    # seeded stream reproduces the same grid.
    row_effect = source.normal(size=grid_size[0]) * sigma_s
    column_effect = source.normal(size=grid_size[1]) * sigma_b
    noise = source.normal(size=tuple(grid_size))
    latent = (
        mu_tot
        + row_effect[:, np.newaxis]
        + column_effect[np.newaxis, :]
        + noise
    )
    return 1 * (latent > 0)


def grid_statistics(grid):
    """Summarise a grid by three numbers.

    Returns:
        Array ``[v0, v1, m]`` where ``v0`` is the unbiased (``ddof=1``)
        variance taken along axis 0 and averaged over the remaining axis,
        ``v1`` is the same along axis 1, and ``m`` is the overall mean.
    """
    return np.array([
        grid.var(axis=0, ddof=1).mean(),
        grid.var(axis=1, ddof=1).mean(),
        grid.mean(),
    ])


def integrate(f, domain=(-20, 20), num_samples=10_000):
    """Integrate ``f`` over ``domain`` with the composite trapezoid rule.

    Args:
        f: Vectorised callable; receives the array of sample points and
            returns the integrand values (a scalar is broadcast).
        domain: Two-element sequence ``(lower, upper)``.
        num_samples: Number of equally spaced sample points, endpoints
            included. Must be at least 2.

    Returns:
        The approximate integral.

    Raises:
        ValueError: If ``num_samples`` is smaller than 2.
    """
    if num_samples < 2:
        raise ValueError("num_samples must be at least 2")
    lower, upper = domain
    # retstep gives the true spacing (upper - lower) / (num_samples - 1);
    # dividing by num_samples instead biases the sum by O(1 / num_samples).
    xs, dx = np.linspace(lower, upper, num_samples, retstep=True)
    ys = np.broadcast_to(np.asarray(f(xs)), xs.shape)
    # Trapezoid rule: the two endpoints carry half weight.
    return dx * (ys.sum() - 0.5 * (ys[0] + ys[-1]))


def expected_variance(sigma, mu):
    """Return E[Phi(X) * (1 - Phi(X))] for X ~ Normal(mu, sigma**2).

    ``Phi`` is the standard normal CDF. Only the magnitude of ``sigma``
    matters; ``sigma == 0`` yields ``Phi(mu) * (1 - Phi(mu))``.

    Args:
        sigma: Standard deviation of X (sign is ignored).
        mu: Mean of X.

    Returns:
        The expectation, evaluated by numerical integration.
    """
    sigma = abs(sigma)
    if sigma <= 1.0:
        # Narrow X: integrate over the standardised variable z so the
        # Gaussian factor is always resolved by the fixed sample grid, even
        # when sigma is tiny or exactly zero.
        def integrand(z):
            x = mu + sigma * z
            return spstats.norm.cdf(x) * spstats.norm.sf(x) * spstats.norm.pdf(z)
    else:
        # Wide X: Phi * (1 - Phi) is the narrow factor, so integrate over x.
        def integrand(x):
            return (
                spstats.norm.cdf(x)
                * spstats.norm.sf(x)
                * spstats.norm.pdf((x - mu) / sigma)
                / sigma
            )
    return integrate(integrand)


def L(parameters, sample_statistics):
    """Objective: log of the squared distance between model and sample statistics.

    Args:
        parameters: Sequence ``(sigma_s, sigma_b, mu_tot)``.
        sample_statistics: Sequence of three observed values in the order
            produced by ``grid_statistics``.

    Returns:
        ``log(sum((expected - observed) ** 2))``. This is ``-inf`` when the
        three values match exactly.
    """
    sigma_s, sigma_b, mu_tot = parameters
    scale_s = np.sqrt(1 + sigma_s ** 2)
    scale_b = np.sqrt(1 + sigma_b ** 2)

    expected = np.array([
        expected_variance(sigma=sigma_b / scale_s, mu=mu_tot / scale_s),
        expected_variance(sigma=sigma_s / scale_b, mu=mu_tot / scale_b),
        spstats.norm.cdf(mu_tot / np.sqrt(1 + sigma_s ** 2 + sigma_b ** 2)),
    ])
    residuals = expected - np.asarray(sample_statistics)
    return np.log((residuals ** 2.).sum())


if __name__ == "__main__":
    grid = generate_grid(sigma_s=2.4, sigma_b=0.6, mu_tot=-1.5)
    sample_statistics = grid_statistics(grid)

    inferred_parameters = spopt.minimize(
        lambda parameters: L(parameters, sample_statistics),
        x0=np.ones(3),
    ).x
    # The objective depends on the two sigmas only through their magnitude,
    # so report the non-negative representative.
    inferred_parameters[:2] = np.abs(inferred_parameters[:2])

    # Output recorded in the original notebook for one (unseeded) run:
    # array([ 2.29071775,  0.57796239, -1.57644119])
    print(repr(inferred_parameters))