diff --git a/6_iton_behavioral_models_neural.ipynb b/6_iton_behavioral_models_neural.ipynb
index c1d24ec..3562fd3 100644
--- a/6_iton_behavioral_models_neural.ipynb
+++ b/6_iton_behavioral_models_neural.ipynb
@@ -1,21 +1,5 @@
{
"cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "wd9no3mxCn7b",
- "metadata": {
- "id": "wd9no3mxCn7b"
- },
- "outputs": [],
- "source": [
- "!wget https://github.com/ManteLab/Iton_notebooks_public/raw/refs/heads/main/utils_ex6/utils.py -O utils.py\n",
- "!wget https://raw.githubusercontent.com/ManteLab/Iton_notebooks_public/refs/heads/main/data_ex6/neural_data.mat -O neural_data.mat\n",
- "!wget https://raw.githubusercontent.com/ManteLab/Iton_notebooks_public/refs/heads/main/data_ex6/dataset1.mat -O dataset1.mat\n",
- "!wget https://raw.githubusercontent.com/ManteLab/Iton_notebooks_public/refs/heads/main/data_ex6/dataset3.mat -O dataset3.mat\n",
- "!pip3 install --quiet hdf5storage ipympl"
- ]
- },
{
"cell_type": "markdown",
"id": "ec7211ca-a104-4c3d-b528-102841bfd937",
@@ -166,8 +150,7 @@
},
"outputs": [],
"source": [
- "import hdf5storage\n",
- "from utils import plot_model_free_analysis_conditions_vs_baseline\n",
+ "from utils import plot_model_free_analysis_conditions_vs_baseline, loadmat\n",
"\n",
"dataset_1 = hdf5storage.loadmat('dataset1.mat')\n",
"\n",
@@ -185,8 +168,7 @@
},
"outputs": [],
"source": [
- "import hdf5storage\n",
- "from utils import plot_model_free_analysis_conditions_vs_baseline\n",
+ "from utils import plot_model_free_analysis_conditions_vs_baseline, loadmat\n",
"\n",
"dataset_3 = hdf5storage.loadmat('dataset3.mat')\n",
"\n",
@@ -232,9 +214,9 @@
},
"outputs": [],
"source": [
- "import hdf5storage\n",
+ "from utils import loadmat\n",
"\n",
- "neural_data = hdf5storage.loadmat('neural_data.mat')"
+ "neural_data = loadmat('neural_data.mat')"
]
},
{
diff --git a/dataset1.mat b/dataset1.mat
new file mode 100644
index 0000000..7d253c5
Binary files /dev/null and b/dataset1.mat differ
diff --git a/dataset3.mat b/dataset3.mat
new file mode 100644
index 0000000..e522a59
Binary files /dev/null and b/dataset3.mat differ
diff --git a/neural_data.mat b/neural_data.mat
new file mode 100644
index 0000000..e3fc4b1
--- /dev/null
+++ b/neural_data.mat
@@ -0,0 +1,178 @@
+
+
+
+
+
+
+ polybox
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/requirements.txt b/requirements.txt
index 07807ba..0b60a40 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
scipy
numpy
matplotlib
-hdf5storage
ipympl
+mat73
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..2b5fa3b
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,781 @@
+from itertools import product
+import numpy as np
+from scipy.integrate import trapezoid
+import hdf5storage
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+from IPython.display import display
+from ipywidgets import interact, interact_manual, IntSlider, FloatSlider, IntRangeSlider, ToggleButton, ToggleButtons, Layout
+from scipy.io import loadmat as sp_loadmat
+from mat73 import loadmat as mat73_loadmat
+
+
def in_colab():
    """Return True when running inside Google Colab, False otherwise."""
    try:
        import google.colab  # noqa: F401 -- only present on Colab runtimes
    except ImportError:
        return False
    return True
+
+
# Detect the runtime once at import time; widget behaviour differs between a
# local Jupyter session and Google Colab.
is_colab = in_colab()
# Continuous slider updates are too slow on Colab's static backend.
continuous_update = not is_colab
if is_colab:
    # Colab requires the custom widget manager for ipywidgets to render.
    from google.colab import output
    output.enable_custom_widget_manager()
+
+
def setup_matplotlib_magic():
    """Select the matplotlib backend: static 'inline' on Colab, interactive 'widget' locally."""
    get_ipython().run_line_magic('matplotlib', 'inline' if is_colab else 'widget')
+
+
def draw_figure(fig):
    """Render *fig*: show it on Colab, otherwise request an idle redraw of the live canvas."""
    if is_colab:
        plt.show()
    else:
        fig.canvas.draw_idle()
+
+
def maybe_setup(setup_fun, state):
    """Re-create the figure on Colab before a widget callback runs.

    Locally the interactive backend keeps one live figure, so this is a no-op.
    On Colab (static inline backend) the first callback only flags *state*;
    every later callback rebuilds the figure by merging ``setup_fun()`` into
    *state* so the new plot is drawn into a fresh output.
    """
    if not is_colab:
        return
    elif 'needs_setup' not in state:
        # First callback after a fresh setup(): the existing figure is still
        # valid for this draw, just mark that future calls must rebuild.
        state['needs_setup'] = True
    else:
        # Subsequent callbacks: rebuild the figure from scratch.
        state.update(setup_fun())
+
+
def loadmat(mat_file):
    """Load a MATLAB ``.mat`` file, handling both storage formats.

    scipy reads files up to MATLAB v7.2; v7.3 files are HDF5-based and are
    read with ``mat73`` instead.

    Args:
        mat_file: path to the ``.mat`` file.

    Returns:
        dict mapping MATLAB variable names to their values.

    Raises:
        FileNotFoundError: if *mat_file* does not exist.
    """
    try:
        return sp_loadmat(mat_file)
    except FileNotFoundError:
        # A missing file is not a format problem -- don't mask it by retrying
        # with mat73, which would only raise a confusing secondary error.
        raise
    except Exception:
        # scipy raises NotImplementedError (among others) for v7.3 files;
        # fall back to the HDF5-based reader.
        return mat73_loadmat(mat_file)
+
+
def generate_sims(C, k, alpha, sigma_a, sigma_s, lambda_, n_sim=100, tau=100, dt_total=11 / 85):
    """Simulate a noisy evidence accumulator with Euler integration.

    Args:
        C: stimulus time course, one value per coarse time step; its sign
            gives the pulse direction.
        k: coherence scaling -- a scalar applied to all simulations, or an
            array with one value per simulation (then ``n_sim`` is ignored).
        alpha: std of the slow, per-trial directional noise.
        sigma_a: std of the fast accumulator (diffusion) noise.
        sigma_s: std of the fast multiplicative input noise.
        lambda_: leak/instability coefficient on the accumulated evidence.
        n_sim: number of simulated trials (when ``k`` is a scalar).
        tau: fine integration steps per coarse time step.
        dt_total: duration of one coarse time step.

    Returns:
        Tuple ``(a, mE, tau, dt)``: accumulator trajectories of shape
        ``(n_sim, len(C) * tau)``, momentary evidence of shape
        ``(n_sim, len(C) * tau + 1)`` (index 0 stays zero), and the
        integration constants.
    """
    dt = dt_total / tau

    # discretize C: upsample each coarse step into `tau` fine steps
    if isinstance(k, np.ndarray):
        C_scaled = np.repeat(C * k[:, np.newaxis], tau, axis=1)
        n_sim = len(k)
    else:
        C_scaled = np.repeat(C * k, tau)[np.newaxis, :]

    T = C_scaled.shape[-1]

    # noise terms: slow per-trial offsets applied only while evidence is
    # present in the matching direction, plus fast Wiener/input noise
    xiR = np.random.randn(n_sim) * alpha / k
    xiL = np.random.randn(n_sim) * alpha / k
    directional_noise = (
        xiR[:, np.newaxis] * (C_scaled > 0) +
        xiL[:, np.newaxis] * (C_scaled < 0)
    )
    dW = np.sqrt(dt) * np.random.randn(n_sim, T)
    eta = 1 + np.random.randn(n_sim, T) * (sigma_s * np.sqrt(tau))

    # accumulated evidence
    a = np.zeros((n_sim, T + 1))
    mE = np.zeros((n_sim, T + 1))
    for t in range(T):
        a[:, t + 1] = a[:, t] + (
            directional_noise[:, t] * C_scaled[:, t] * (dt_total / tau) +
            lambda_ * a[:, t] * (dt_total / tau) +
            sigma_a * dW[:, t] +
            eta[:, t] * C_scaled[:, t] * (dt_total / tau)
        )

        # momentary evidence at every step. BUG FIX: this assignment was
        # dedented out of the loop, so only the final step was ever written;
        # downstream code samples mE over the whole trajectory.
        mE[:, t + 1] = eta[:, t] * C_scaled[:, t] * (dt_total / tau) + lambda_ * a[:, t] * (dt_total / tau)

    return a[:, 1:], mE, tau, dt
+
+
def generate_sims_conditions(ks, directions, sim_parameters, num_sims_per_condition):
    """Run `generate_sims` for every (coherence, direction) combination.

    Args:
        ks: iterable of coherence values.
        directions: iterable of stimulus signs (1 = right, -1 = left).
        sim_parameters: keyword arguments for `generate_sims`; must include
            'C', whose sign is flipped per direction.
        num_sims_per_condition: trials simulated per combination.

    Returns:
        Tuple ``(time, a_all, mE_all, k_idx_all, choices, is_correct)`` with
        all conditions stacked along the trial axis, trajectories subsampled
        back to the coarse time grid.
    """
    simulation_combinations = list(product(ks, directions))

    a_all = []
    mE_all = []
    k_idx_all = []
    direction_all = []
    for idx, (k, direction) in enumerate(simulation_combinations):
        # flip the evidence pulse according to the trial direction
        C = sim_parameters['C'] * direction
        dir_label = 1 if direction == 1 else 0  # 1 = right, 0 = left

        a_temp, mE_temp, tau, dt = generate_sims(**{
            **sim_parameters,
            'C': C,
            'k': k,
            'n_sim': num_sims_per_condition
        })

        # subsample at every tau steps (back to one sample per coarse step);
        # momentary evidence is rescaled by dt to a rate
        a_sampled = a_temp[:, tau-1::tau]
        mE_sampled = mE_temp[:, tau-1::tau] / dt

        a_all.append(a_sampled)
        mE_all.append(mE_sampled)
        k_idx_all.extend([k] * num_sims_per_condition)
        direction_all.extend([dir_label] * num_sims_per_condition)

    a_all = np.vstack(a_all)
    mE_all = np.vstack(mE_all)
    k_idx_all = np.array(k_idx_all)
    direction_all = np.array(direction_all)

    choices = (a_all > 0).astype(int) # 1 is right, 0 is left
    is_correct = (choices == direction_all[:, np.newaxis]).astype(int)

    # NOTE(review): uses C from the final loop iteration; all conditions share
    # the same stimulus length, so len(C) is identical either way.
    time = np.arange(len(C))

    return time, a_all, mE_all, k_idx_all, choices, is_correct
+
+
def plot_sims(C_size=11, num_sims=30 if not is_colab else 5):
    """Interactive widget: simulate and plot accumulator traces for a pulse stimulus.

    Args:
        C_size: number of coarse time steps in the stimulus.
        num_sims: number of traces to draw (fewer on Colab, where redraws are slow).
    """
    setup_matplotlib_magic()

    def setup():
        # Build the figure and empty line artists once; the widget callback
        # only updates their data.
        fig, axes = plt.subplots(figsize=(6.5, 5))

        evidence_line = axes.plot([], [], color='C2', alpha=1)[0]
        sim_lines = []
        for i in range(num_sims):
            sim_line = axes.plot([], [], color='C0', alpha=0.3)[0]
            sim_lines += [sim_line]

        axes.set(
            title=f"{num_sims} Simulations",
            ylabel="value",
            xlabel="time $t$",
            xlim=(0, 11),
            ylim=(-1.5, 1.5)
        )

        plt.axhline(0., color='black', alpha=0.3)
        plt.tight_layout()

        legend_elements = [
            Line2D([], [], color='C2', label='evidence pulse'),
            Line2D([], [], color='C0', label='accumulator $a$ (decision: right)'),
            Line2D([], [], color='C1', label='accumulator $a$ (decision: left)')
        ]
        axes.legend(handles=legend_elements, loc='upper right')

        return {'fig': fig, 'axes': axes, 'evidence_line': evidence_line, 'sim_lines': sim_lines}

    state = setup()
    state['random_seed'] = 42

    def update_plot(C_dir, C, k, alpha, sigma_a, sigma_s, lambda_, fixed_noise):
        maybe_setup(setup, state)

        # Reseed on every call: with 'fix noise' the same draws are replayed so
        # only parameter changes affect the plot; 'redraw noise' picks a new seed.
        if fixed_noise == 'redraw noise':
            state['random_seed'] = np.random.randint(0, 2**32)
        np.random.seed(state['random_seed'])

        # Build a boxcar evidence pulse from the (start, stop) range slider.
        C = np.concatenate([np.zeros(C[0]), np.ones(C[1] - C[0]), np.zeros(C_size - C[1])])
        C *= 1 if C_dir == 'pulse right' else -1

        sims, *_ = generate_sims(C, k, alpha, sigma_a, sigma_s, lambda_, n_sim=num_sims)

        # Color each trace by the sign of its final accumulator value (= choice).
        for sim, sim_line in zip(sims, state['sim_lines']):
            sim_line.set_data(np.linspace(0, len(C), len(sim)), sim)
            sim_line.set_color('C0' if sim[-1] > 0 else 'C1')

        state['evidence_line'].set_data(np.linspace(0., len(C), len(C) * 1_000), np.repeat(C, 1_000) * k)

        draw_figure(state['fig'])

    style = {'description_width': '150px'}
    layout = Layout(width='600px')
    sliders = {
        'C_dir': ToggleButtons(options=['pulse left', 'pulse right'], value='pulse right', description=' '),
        'C': IntRangeSlider(min=0, max=C_size, value=[3, 7], description='evidence pulse timing', style=style, layout=layout, continuous_update=continuous_update),
        'k': FloatSlider(min=1e-6, max=1., step=0.01, value=0.5, description='coherence', style=style, layout=layout, continuous_update=continuous_update),
        'sigma_s': FloatSlider(min=0, max=3, step=0.01, value=0., description='fast noise (input)', style=style, layout=layout, continuous_update=continuous_update),
        'alpha': FloatSlider(min=0, max=1, step=0.01, value=0., description='slow noise (brain)', style=style, layout=layout, continuous_update=continuous_update),
        'sigma_a': FloatSlider(min=0, max=1, step=0.01, value=0., description='fast inner noise (brain)', style=style, layout=layout, continuous_update=continuous_update),
        'lambda_': FloatSlider(min=-5, max=5, step=0.01, value=0., description='leakiness', style=style, layout=layout, continuous_update=continuous_update),
        'fixed_noise': ToggleButtons(options=['fix noise', 'redraw noise'], value='fix noise', description=' '),
    }

    interact(update_plot, **sliders)
+
+
def update_errorbar(err_container, x, y, yerr):
    """Update a matplotlib errorbar container in place.

    Sets the central line to (x, y) and rebuilds the vertical error segments
    as y +/- yerr for each point.
    """
    err_container.lines[0].set_data(x, y)
    vertical_collection = err_container.lines[2][0]
    vertical_collection.set_segments([
        [[xi, yi - ei], [xi, yi + ei]]
        for xi, yi, ei in zip(x, y, yerr)
    ])
+
+
def plot_model_free_analysis_conditions(C, ks, num_sims_per_condition=2_000):
    """Interactive widget: accuracy and psychophysical kernel per coherence level.

    Args:
        C: stimulus time course shared by all conditions.
        ks: coherence levels to simulate, one curve per level.
        num_sims_per_condition: trials per (coherence, direction) pair.
    """
    setup_matplotlib_magic()

    def setup():
        # One errorbar container (accuracy) and one line (kernel) per coherence.
        fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharex=True)

        accuracy_lines = [axes[0].errorbar([], [], yerr=[], label=f'$k = {k}$') for k in ks]
        kernel_lines = [axes[1].plot([], [], label=f'$k = {k}$')[0] for k in ks]

        axes[0].set(
            title="accuracy",
            xlabel="$t$",
            xlim=(0, len(C) - 1),
            ylim=(0, 1)
        )
        axes[0].legend(loc='lower right', fontsize='small')

        axes[1].set(
            title="psychophysical kernel",
            xlabel="$t$",
            ylim=(-3, 3)
        )
        axes[1].legend(loc='lower left', fontsize='small')

        fig.tight_layout()

        return {'fig': fig, 'axes': axes, 'accuracy_lines': accuracy_lines, 'kernel_lines': kernel_lines}

    # On Colab the figure is created lazily by maybe_setup on the first callback.
    state = setup() if not is_colab else {'needs_setup': True}

    def update_plot(sigma_s, alpha, sigma_a, lambda_):
        maybe_setup(setup, state)

        sim_parameters = {
            'C': C,
            'sigma_s': sigma_s,
            'alpha': alpha,
            'sigma_a': sigma_a,
            'lambda_': lambda_
        }

        directions = [1, -1]
        time, a_all, mE_all, k_idx_all, choices, is_correct = generate_sims_conditions(
            ks, directions, sim_parameters, num_sims_per_condition
        )

        for i, k in enumerate(ks):
            mask = (k_idx_all == k)
            is_corr_k = is_correct[mask, :]
            perf = is_corr_k.mean(axis=0)
            # 95% confidence interval of the mean over trials
            ci95 = 1.96 * is_corr_k.std(axis=0) / np.sqrt(mask.sum())

            update_errorbar(state['accuracy_lines'][i], time, perf, yerr=ci95)

            # kernel: mean momentary evidence on right-choice trials minus
            # left-choice trials; the final time step defines the choice
            psy_kernel = (
                mE_all[ (choices[:, -1] == 1) & mask ].mean(axis=0) -
                mE_all[ (choices[:, -1] != 1) & mask ].mean(axis=0)
            )

            state['kernel_lines'][i].set_data(time, psy_kernel)

        state['fig'].tight_layout()
        draw_figure(state['fig'])


    style = {'description_width': '150px'}
    layout = Layout(width='600px')
    sliders = {
        'sigma_s': FloatSlider(min=0, max=5, step=0.01, value=0., description='fast noise (input)', style=style, layout=layout),
        'alpha': FloatSlider(min=0, max=1, step=0.01, value=0., description='slow noise (brain)', style=style, layout=layout),
        'sigma_a': FloatSlider(min=0, max=2, step=0.01, value=0., description='fast inner noise (brain)', style=style, layout=layout),
        'lambda_': FloatSlider(min=-5, max=5, step=0.01, value=0., description='leakiness', style=style, layout=layout)
    }

    # Simulations are expensive, so require an explicit button press.
    interact_manual.options(manual_name='run simulations')(
        update_plot,
        **sliders
    )
+
+
def model_free_analysis(dataset):
    """Compute accuracy curves and psychophysical kernels per coherence index.

    Args:
        dataset: dict-like .mat contents with keys 'choices', 'direction',
            'a', 'mE', 'kIdx'.
            NOTE(review): the indexing below implies 'choices' is laid out as
            (time, trials) while 'mE' is (trials, time) -- confirm against the
            actual .mat layout.

    Returns:
        Tuple ``(time, perfs, ci95s, psy_kernels)``; the three lists hold one
        entry per coherence index in ``[1, 2, 3]``.
    """
    is_correct = dataset['choices'] == dataset['direction'].flatten()
    time = np.arange(dataset['a'].shape[1])

    perfs = []
    ci95s = []
    psy_kernels = []
    for k_idx in [1, 2, 3]:
        mask = (dataset['kIdx'].flatten() == k_idx)
        is_corr_k = is_correct[:, mask]
        perf = is_corr_k.mean(axis=1)
        # 95% confidence interval of the mean over trials
        ci95 = 1.96 * is_corr_k.std(axis=1) / np.sqrt(mask.sum())

        # kernel: mean momentary evidence on right-choice trials minus
        # left-choice trials; the final time step defines the choice
        psy_kernel = (
            dataset['mE'][ (dataset['choices'][-1, :] == 1) & mask ].mean(axis=0) -
            dataset['mE'][ (dataset['choices'][-1, :] != 1) & mask ].mean(axis=0)
        )

        perfs += [perf]
        ci95s += [ci95]
        psy_kernels += [psy_kernel]

    return time, perfs, ci95s, psy_kernels
+
+
def plot_model_free_analysis_conditions_vs_baseline(baseline_data, num_sims_per_condition=2_000):
    """Interactive widget: simulated accuracy/kernels overlaid on a recorded baseline.

    The baseline dataset's curves are drawn once (dashed, faded) via
    `model_free_analysis`; the solid curves are re-simulated with the slider
    parameters for a fixed step stimulus and coherences [0.2, 0.4, 0.8].

    Args:
        baseline_data: dict-like .mat contents accepted by `model_free_analysis`.
        num_sims_per_condition: trials per (coherence, direction) pair.
    """
    setup_matplotlib_magic()

    # Fixed stimulus: one silent step, then evidence on for 10 steps.
    C = np.concatenate(([0], np.ones(10)))
    ks = [0.2, 0.4, 0.8]

    def setup():
        fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharex=True)

        accuracy_lines = [axes[0].errorbar([], [], yerr=[], label=f'$k = {k}$') for k in ks]
        kernel_lines = [axes[1].plot([], [], label=f'$k = {k}$')[0] for k in ks]

        axes[0].set(
            title="accuracy",
            xlabel="$t$",
            xlim=(0, len(C) - 1),
            ylim=(0, 1)
        )

        axes[1].set(
            title="psychophysical kernel",
            xlabel="$t$",
            ylim=(-3, 3)
        )

        # Draw the static baseline curves once, dashed and faded.
        time, perfs, ci95s, psy_kernels = model_free_analysis(baseline_data)
        for i, (perf, ci95, psy_kernel) in enumerate(zip(perfs, ci95s, psy_kernels, strict=True)):
            axes[0].errorbar(time, perf, yerr=ci95, color=f'C{i}', label=f'$k = {ks[i]}$ (baseline)', linestyle='--', alpha=0.3)
            axes[1].plot(time, psy_kernel, color=f'C{i}', label=f'$k = {ks[i]}$ (baseline)', linestyle='--', alpha=0.3)

        axes[0].legend(loc='lower right', fontsize='small')
        axes[1].legend(loc='lower left', fontsize='small')

        fig.tight_layout()

        return {'fig': fig, 'axes': axes, 'accuracy_lines': accuracy_lines, 'kernel_lines': kernel_lines}

    # On Colab the figure is created lazily by maybe_setup on the first callback.
    state = setup() if not is_colab else {'needs_setup': True}

    def update_plot(sigma_s, alpha, sigma_a, lambda_):
        maybe_setup(setup, state)

        sim_parameters = {
            'C': C,
            'sigma_s': sigma_s,
            'alpha': alpha,
            'sigma_a': sigma_a,
            'lambda_': lambda_
        }

        directions = [1, -1]
        time, a_all, mE_all, k_idx_all, choices, is_correct = generate_sims_conditions(
            ks, directions, sim_parameters, num_sims_per_condition
        )

        for i, k in enumerate(ks):
            mask = (k_idx_all == k)
            is_corr_k = is_correct[mask, :]
            perf = is_corr_k.mean(axis=0)
            # 95% confidence interval of the mean over trials
            ci95 = 1.96 * is_corr_k.std(axis=0) / np.sqrt(mask.sum())

            update_errorbar(state['accuracy_lines'][i], time, perf, yerr=ci95)

            # kernel: right-choice minus left-choice mean momentary evidence
            psy_kernel = (
                mE_all[ (choices[:, -1] == 1) & mask ].mean(axis=0) -
                mE_all[ (choices[:, -1] != 1) & mask ].mean(axis=0)
            )

            state['kernel_lines'][i].set_data(time, psy_kernel)

        state['fig'].tight_layout()
        draw_figure(state['fig'])


    style = {'description_width': '150px'}
    layout = Layout(width='600px')
    sliders = {
        'sigma_s': FloatSlider(min=0, max=5, step=0.01, value=0., description='fast noise (input)', style=style, layout=layout),
        'alpha': FloatSlider(min=0, max=1, step=0.01, value=0., description='slow noise (brain)', style=style, layout=layout),
        'sigma_a': FloatSlider(min=0, max=2, step=0.01, value=0., description='fast inner noise (brain)', style=style, layout=layout),
        'lambda_': FloatSlider(min=-5, max=5, step=0.01, value=0., description='leakiness', style=style, layout=layout)
    }

    # Simulations are expensive, so require an explicit button press.
    interact_manual.options(manual_name='run simulations')(
        update_plot,
        **sliders
    )
+
+
def bin_spikes(raw_spike_matrix, bin_size=50):
    """Sum spike counts into non-overlapping time bins.

    Args:
        raw_spike_matrix: array of shape (trials, time, neurons).
        bin_size: consecutive time samples per bin; trailing samples that do
            not fill a whole bin are discarded.

    Returns:
        Array of shape (trials, time // bin_size, neurons) with per-bin sums.
    """
    n_trials, n_samples, n_neurons = raw_spike_matrix.shape
    num_bins = n_samples // bin_size

    # Drop the incomplete tail, then fold the time axis into (bins, bin_size)
    # and sum within each bin.
    usable = raw_spike_matrix[:, :num_bins * bin_size, :]
    return usable.reshape(n_trials, num_bins, bin_size, n_neurons).sum(axis=2)
+
+
def get_binned_spike_matrix(mat_data):
    """Extract, bin, and variance-stabilize the raw spike matrix.

    Takes time samples 149:1000 of 'RawSpikeMatrix1', sums them into bins of
    50 samples, and applies a square-root transform.

    Args:
        mat_data: dict-like .mat contents containing 'RawSpikeMatrix1' with
            shape (trials, time, neurons).

    Returns:
        Tuple ``(time, binned)``: bin positions in units of 50 samples, and
        the sqrt-transformed binned matrix of shape (trials, bins, neurons).
    """
    window = mat_data['RawSpikeMatrix1'][:, 149:1000, :]
    binned = np.sqrt(bin_spikes(window))
    time = 50 * np.arange(binned.shape[1])
    return time, binned
+
+
def plot_single_neuron(mat_data):
    """Interactive widget: trial-averaged binned activity of one selectable neuron.

    Args:
        mat_data: dict-like .mat contents accepted by `get_binned_spike_matrix`.
    """
    setup_matplotlib_magic()

    time, binned_spike_matrix = get_binned_spike_matrix(mat_data)

    def setup():
        # Build the figure and an empty line once; callbacks only update data.
        fig, axes = plt.subplots(figsize=(6.5, 4.5))

        neuron_line = axes.plot([], [])[0]

        axes.set(
            ylabel=r'$\sqrt{N_\mathrm{spikes}}$',
            xlabel='time [ms]',
            xlim=(0, 800)
        )

        return {'fig': fig, 'axes': axes, 'neuron_line': neuron_line}

    state = setup()

    def update_plot(neuron_idx):
        maybe_setup(setup, state)

        # Average over trials (axis 0), then select the chosen neuron.
        state['neuron_line'].set_data(time, binned_spike_matrix.mean(axis=0)[:, neuron_idx])

        state['axes'].relim()
        state['axes'].autoscale(axis='y')
        state['axes'].set_title(f'Neuron #{neuron_idx}', fontsize='small')
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    sliders = {
        'neuron_idx': IntSlider(min=0, max=binned_spike_matrix.shape[2] - 1, description='neuron #', layout=Layout(width='800px'), continuous_update=continuous_update)
    }

    interact(update_plot, **sliders)
+
+
def plot_neuron_by_choice(mat_data):
    """Interactive widget: one neuron's average activity split by choice and correctness.

    Left panel shows correct trials, right panel incorrect trials; within each
    panel, trials are split by the monkey's choice (targ_cho == 1 is treated as
    a right choice).

    Args:
        mat_data: dict-like .mat contents with 'targ_cho', 'targ_cor' and the
            keys required by `get_binned_spike_matrix`.
    """
    setup_matplotlib_magic()

    time, binned_spike_matrix = get_binned_spike_matrix(mat_data)

    # Trial masks: correct when chosen target equals the correct target.
    correct_trials_mask = (mat_data['targ_cho'].flatten() == mat_data['targ_cor'].flatten())
    right_choice = (mat_data['targ_cho'].flatten() == 1)

    def setup():
        fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True)

        choices = ['right choice', 'left choice']
        correct_lines = []
        for choice in choices:
            correct_line = axes[0].plot([], [], label=choice)[0]
            correct_lines += [correct_line]

        incorrect_lines = []
        for choice in choices:
            incorrect_line = axes[1].plot([], [], label=choice)[0]
            incorrect_lines += [incorrect_line]

        axes[0].set(
            title='correct trials',
            ylabel=r'$\sqrt{N_\mathrm{spikes}}$',
            xlabel='time [ms]',
            xlim=(0, 800)
        )
        axes[1].set(
            title='incorrect trials',
            xlabel='time [ms]'
        )
        axes[0].legend(loc='upper right')
        axes[1].legend(loc='upper right')

        return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines}

    state = setup()

    def update_plot(neuron_idx):
        maybe_setup(setup, state)

        # Trial-average within each (correctness x choice) cell for this neuron.
        state['correct_lines'][0].set_data(time, binned_spike_matrix[correct_trials_mask & right_choice].mean(axis=0)[:, neuron_idx])
        state['correct_lines'][1].set_data(time, binned_spike_matrix[correct_trials_mask & ~right_choice].mean(axis=0)[:, neuron_idx])
        state['incorrect_lines'][0].set_data(time, binned_spike_matrix[~correct_trials_mask & right_choice].mean(axis=0)[:, neuron_idx])
        state['incorrect_lines'][1].set_data(time, binned_spike_matrix[~correct_trials_mask & ~right_choice].mean(axis=0)[:, neuron_idx])

        state['axes'][0].relim()
        state['axes'][1].relim()
        state['axes'][0].autoscale(axis='y')
        state['axes'][1].autoscale(axis='y')
        state['fig'].suptitle(f'Neuron #{neuron_idx}', fontsize='small')
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    sliders = {
        'neuron_idx': IntSlider(min=0, max=binned_spike_matrix.shape[2] - 1, description='neuron #', layout=Layout(width='800px'), continuous_update=continuous_update)
    }

    interact(update_plot, **sliders)
+
+
def plot_neuron_by_coherence(mat_data):
    """Interactive widget: one neuron's average activity split by coherence and correctness.

    Three of the sorted unique coherence levels (indices 0, 3 and 5) are shown;
    left panel is correct trials, right panel incorrect trials.

    Args:
        mat_data: dict-like .mat contents with 'targ_cho', 'targ_cor',
            'dot_coh' and the keys required by `get_binned_spike_matrix`.
    """
    setup_matplotlib_magic()

    time, binned_spike_matrix = get_binned_spike_matrix(mat_data)

    correct_trials_mask = (mat_data['targ_cho'].flatten() == mat_data['targ_cor'].flatten())
    coherences = np.sort(
        np.unique(mat_data['dot_coh'])
    )
    # Pick a subset of the available coherence levels to keep the plot legible.
    coherences = coherences[[0, 3, 5]]

    def setup():
        fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True)

        # NOTE(review): `choices` is unused here; lines are labelled by coherence.
        choices = ['right choice', 'left choice']
        correct_lines = []
        for coherence in coherences:
            correct_line = axes[0].plot([], [], label=f'{coherence = :.1%}')[0]
            correct_lines += [correct_line]

        incorrect_lines = []
        for coherence in coherences:
            incorrect_line = axes[1].plot([], [], label=f'{coherence = :.1%}')[0]
            incorrect_lines += [incorrect_line]

        axes[0].set(
            title='correct trials',
            ylabel=r'$\sqrt{N_\mathrm{spikes}}$',
            xlabel='time [ms]',
            xlim=(0, 800)
        )
        axes[1].set(
            title='incorrect trials',
            xlabel='time [ms]'
        )
        axes[0].legend(loc='upper right')
        axes[1].legend(loc='upper right')

        return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines}

    state = setup()


    def update_plot(neuron_idx):
        maybe_setup(setup, state)

        # Trial-average per coherence level for this neuron, split by correctness.
        for i, coherence in enumerate(coherences):
            coherence_mask = (mat_data['dot_coh'].flatten() == coherence)
            state['correct_lines'][i].set_data(time, binned_spike_matrix[correct_trials_mask & coherence_mask].mean(axis=0)[:, neuron_idx])
            state['incorrect_lines'][i].set_data(time, binned_spike_matrix[~correct_trials_mask & coherence_mask].mean(axis=0)[:, neuron_idx])

        state['axes'][0].relim()
        state['axes'][1].relim()
        state['axes'][0].autoscale(axis='y')
        state['axes'][1].autoscale(axis='y')
        state['fig'].suptitle(f'Neuron #{neuron_idx}', fontsize='small')
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    sliders = {
        'neuron_idx': IntSlider(min=0, max=binned_spike_matrix.shape[2] - 1, description='neuron #', layout=Layout(width='800px'), continuous_update=continuous_update)
    }

    interact(update_plot, **sliders)
+
+
def calculate_deltas(mat_data):
    """Per-neuron choice selectivity Delta.

    For each neuron, the trial-averaged binned activity is integrated over
    time (trapezoid rule) separately for right-choice (targ_cho == 1) and
    left-choice trials; Delta is the difference of the two integrals.

    Args:
        mat_data: dict-like .mat contents accepted by `get_binned_spike_matrix`,
            additionally containing 'targ_cho'.

    Returns:
        1-D array with one Delta value per neuron.
    """
    _, binned = get_binned_spike_matrix(mat_data)
    chose_right = (mat_data['targ_cho'].flatten() == 1)

    avg_right = binned[chose_right].mean(axis=0)
    avg_left = binned[~chose_right].mean(axis=0)

    return trapezoid(avg_right, axis=0) - trapezoid(avg_left, axis=0)
+
+
def plot_deltas(deltas):
    """Histogram the per-neuron selectivities: signed Delta (left) and |Delta| (right)."""
    setup_matplotlib_magic()

    fig, (ax_signed, ax_abs) = plt.subplots(1, 2, figsize=(8, 4), sharey=True)

    ax_signed.hist(deltas, bins=16, range=(-4, 4))
    ax_abs.hist(np.abs(deltas), bins=15, range=(0, 4.2))

    ax_signed.set(
        ylabel='counts',
        xlabel=r'$\Delta$'
    )
    ax_abs.set(
        xlabel=r'|$\Delta$|'
    )
    plt.tight_layout()
+
+
def plot_aggregated_neurons(mat_data):
    """Interactive widget: population-average activity of selective neurons.

    Neurons with |Delta| above an adjustable threshold are averaged, after
    multiplying each neuron by sign(Delta) (presumably to align preferred
    directions before averaging -- see `calculate_deltas`).

    Args:
        mat_data: dict-like .mat contents accepted by `get_binned_spike_matrix`,
            additionally containing 'targ_cho'.
    """
    setup_matplotlib_magic()

    time, binned_spike_matrix = get_binned_spike_matrix(mat_data)
    right_choice = (mat_data['targ_cho'].flatten() == 1)
    mean_spikes_right = binned_spike_matrix[right_choice].mean(axis=0)
    mean_spikes_left = binned_spike_matrix[~right_choice].mean(axis=0)

    deltas = calculate_deltas(mat_data)

    def setup():
        fig, axes = plt.subplots()

        lines = [
            axes.plot([], [], label='right choice')[0],
            axes.plot([], [], label='left choice')[0]
        ]

        axes.set(
            ylabel=r'$\sqrt{N_\mathrm{spikes}}$',
            xlabel='time [ms]',
            xlim=(0, 800)
        )
        axes.legend(loc='upper right')

        return {'fig': fig, 'axes': axes, 'lines': lines}

    state = setup()

    def update_plot(delta_threshold):
        maybe_setup(setup, state)

        # Sign-flip each neuron by its Delta, keep only |Delta| > threshold,
        # then average over the remaining neurons.
        state['lines'][0].set_data(time, (mean_spikes_right * np.sign(deltas))[:, np.abs(deltas) > delta_threshold].mean(axis=1))
        state['lines'][1].set_data(time, (mean_spikes_left * np.sign(deltas))[:, np.abs(deltas) > delta_threshold].mean(axis=1))

        state['axes'].relim()
        state['axes'].autoscale(axis='y')
        state['axes'].set(
            title=f'|Δ| > {delta_threshold:.2f}'
        )
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    sliders = {
        'delta_threshold': FloatSlider(min=0, max=np.abs(deltas).max() - 1e-3, description='threshold |Δ|', layout=Layout(width='800px'), continuous_update=continuous_update)
    }

    interact(update_plot, **sliders)
+
+
def simulate_conditions(mat_data, alpha, sigma_a, sigma_s, lambda_):
    """Simulate the accumulator once per recorded trial and average correct trials.

    Runs `generate_sims` with each trial's experimental coherence and motion
    direction using a step stimulus, then averages the correct-trial
    accumulator trajectories per motion direction.

    Args:
        mat_data: dict-like .mat contents with 'dot_coh', 'dot_dir' (0 or 180
            degrees) and 'targ_cor' (1 = right, 2 = left).
        alpha, sigma_a, sigma_s, lambda_: noise/leak parameters forwarded to
            `generate_sims`.

    Returns:
        List of mean trajectories, one per unique direction (np.unique order,
        i.e. -1/left first, then 1/right).
    """
    # .flatten() returns copies, so mutating dot_coh below does not touch mat_data.
    dot_coh = mat_data['dot_coh'].flatten()
    dot_dir = mat_data['dot_dir'].flatten()
    targ_cor = mat_data['targ_cor'].flatten()

    # step stimulus: one silent step, then evidence on for 16 steps
    C = np.array([0] + [1]*16)

    # avoid division by zero inside generate_sims (alpha / k)
    dot_coh[dot_coh == 0] = 1e-12
    # NOTE(review): `k` is unused below; per-trial coherences are passed directly.
    k = np.unique(dot_coh)

    # map directions: 0 -> 1 (right), 180 -> -1 (left)
    d = np.copy(dot_dir)
    d[dot_dir == 0] = 1
    d[dot_dir == 180] = -1

    # one simulation per trial: row i gets stimulus d[i] * C scaled by dot_coh[i]
    a, _, tau, dt = generate_sims(np.outer(d, C), dot_coh, alpha, sigma_a, sigma_s, lambda_)
    a = a[:, tau-1::tau]

    # determine choices and correctness
    cho = (a[:, -1] > 0).astype(int)
    cho[cho == 0] = 2 # 2 is left, 1 is right
    isCorr = cho == targ_cor

    # separate correct and incorrect trials
    a_Cor = a[isCorr, :]
    d_Cor = d[isCorr]
    cho_Cor = cho[isCorr]
    coh_Cor = dot_coh[isCorr]

    a_Inc = a[~isCorr, :]
    d_Inc = d[~isCorr]
    cho_Inc = cho[~isCorr]
    coh_Inc = dot_coh[~isCorr]

    # plot average accumulation for correct trials by direction
    unq_dir = np.unique(d)

    means_a = []
    for dir_ in unq_dir:
        mean_a = np.mean(a_Cor[d_Cor == dir_, :], axis=0)
        means_a += [mean_a]

    return means_a
+
+
def plot_sims_conditions(mat_data):
    """Interactive widget: mean simulated accumulator per direction vs. noise settings.

    Simulates one trial per recorded trial via `simulate_conditions` and plots
    the correct-trial mean accumulator trajectory for each motion direction.

    Args:
        mat_data: dict-like .mat contents accepted by `simulate_conditions`.
    """
    setup_matplotlib_magic()

    def setup():
        fig, axes = plt.subplots(figsize=(6.5, 5))

        # NOTE(review): evidence_line is created but never stored or updated.
        evidence_line = axes.plot([], [], color='C2', alpha=1)[0]
        sim_lines = []
        for choice in ['right choice', 'left choice']:
            sim_line = axes.plot([], [], label=choice)[0]
            sim_lines += [sim_line]

        axes.set(
            ylabel="mean $a$",
            xlabel="time $t$",
            xlim=(0, 800),
            ylim=(-0.5, .5)
        )

        axes.legend(loc='upper right')
        plt.tight_layout()

        return {'fig': fig, 'axes': axes, 'sim_lines': sim_lines}

    state = setup()
    state['random_seed'] = 42

    def update_plot(alpha, sigma_a, sigma_s, lambda_, fixed_noise):
        maybe_setup(setup, state)

        # Reseed every call: 'fix noise' replays identical draws, 'redraw
        # noise' picks a fresh seed first.
        if fixed_noise == 'redraw noise':
            state['random_seed'] = np.random.randint(0, 2**32)
        np.random.seed(state['random_seed'])

        means_a = simulate_conditions(mat_data, alpha, sigma_a, sigma_s, lambda_)

        # means_a is ordered -1 (left) then 1 (right); reverse so the first
        # line ('right choice') gets the rightward mean. x-axis in ms (50 ms bins).
        for mean_a, line in zip(means_a[::-1], state['sim_lines'], strict=True):
            line.set_data(np.arange(len(mean_a)) * 50, mean_a)

        state['axes'].relim()
        state['axes'].autoscale(axis='y')
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    style = {'description_width': '150px'}
    layout = Layout(width='600px')
    sliders = {
        'sigma_s': FloatSlider(min=0, max=3, step=0.01, value=0., description='fast noise (input)', style=style, layout=layout, continuous_update=continuous_update),
        'alpha': FloatSlider(min=0, max=1, step=0.01, value=0., description='slow noise (brain)', style=style, layout=layout, continuous_update=continuous_update),
        'sigma_a': FloatSlider(min=0, max=1, step=0.01, value=0., description='fast inner noise (brain)', style=style, layout=layout, continuous_update=continuous_update),
        'lambda_': FloatSlider(min=-5, max=5, step=0.01, value=0., description='leakiness', style=style, layout=layout, continuous_update=continuous_update),
        'fixed_noise': ToggleButtons(options=['fix noise', 'redraw noise'], value='fix noise', description=' '),
    }

    interact(update_plot, **sliders)