You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

822 lines
28 KiB
Python

from itertools import product
import numpy as np
from scipy.integrate import trapezoid
from scipy.stats import sem
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from IPython.display import display
from ipywidgets import interact, interact_manual, IntSlider, FloatSlider, IntRangeSlider, ToggleButton, ToggleButtons, Layout
from scipy.io import loadmat as sp_loadmat
from mat73 import loadmat as mat73_loadmat
def in_colab():
    """Return ``True`` when running inside Google Colab, ``False`` otherwise.

    Detection is done by trying to import the ``google.colab`` package, which
    is normally only importable on a Colab runtime.
    """
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True
# Runtime flags consulted by every plotting helper below.
is_colab = in_colab()
# Passed to the sliders as `continuous_update`: live updates are switched off on Colab.
continuous_update = False if is_colab else True
if is_colab:
    # Turn on Colab's custom widget manager; the interactive plots rely on ipywidgets.
    from google.colab import output
    output.enable_custom_widget_manager()
def setup_matplotlib_magic():
    """Select the matplotlib backend through IPython's ``%matplotlib`` magic.

    Colab gets the static ``inline`` backend, every other environment the
    interactive ``widget`` backend. Relies on the ``get_ipython`` builtin that
    IPython injects into the namespace (it is not imported in this module).
    """
    backend = 'inline' if is_colab else 'widget'
    get_ipython().run_line_magic('matplotlib', backend)
def draw_figure(fig):
    """Render ``fig`` in a way that suits the current environment.

    On Colab (inline backend) the figure is shown again with ``plt.show()``;
    elsewhere a redraw of the existing interactive canvas is requested.
    """
    if is_colab:
        plt.show()
        return
    fig.canvas.draw_idle()
def update_errorbar(err_container, x, y, yerr):
    """Update the data of an existing errorbar artist in place.

    ``err_container`` is the object returned by ``Axes.errorbar``; its
    ``lines`` attribute is unpacked as ``(data_line, caplines, barlinecols)``.
    The data line, the first vertical bar collection and (when present) the
    lower/upper cap markers are moved to the new values.

    Parameters
    ----------
    err_container : matplotlib ErrorbarContainer
        Container created with symmetric ``yerr`` only.
    x, y : array_like
        New positions; lists and arrays are both accepted.
    yerr : array_like or float
        Symmetric half-width of the bars; a scalar is broadcast to ``y``'s shape.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Broadcasting lets callers pass a scalar error, as Axes.errorbar allows.
    yerr = np.broadcast_to(np.asarray(yerr), y.shape)
    lower = y - yerr
    upper = y + yerr

    data_line, caplines, barlinecols = err_container.lines
    data_line.set_data(x, y)
    barlinecols[0].set_segments(
        [[[xi, lo], [xi, hi]] for xi, lo, hi in zip(x, lower, upper)]
    )
    # Caps exist only when the errorbar was created with capsize > 0.
    if len(caplines) == 2:
        lower_caps, upper_caps = caplines
        lower_caps.set_data(x, lower)
        upper_caps.set_data(x, upper)
def maybe_setup(setup_fun, state):
    """Rebuild the figure on Colab before every update except the first.

    Outside Colab this is a no-op: the interactive figure persists and is
    simply redrawn. On Colab:

    * if ``state`` has no ``'needs_setup'`` key, the key is set and the
      figure already stored in ``state`` is reused for this call;
    * otherwise a fresh ``setup_fun()`` result is merged into ``state``
      (so a state pre-seeded with ``'needs_setup'`` is built on first use).
    """
    if not is_colab:
        return
    if 'needs_setup' in state:
        state.update(setup_fun())
    else:
        state['needs_setup'] = True
def loadmat(mat_file):
    """Load a MATLAB ``.mat`` file, whatever its format version.

    SciPy's reader is tried first. SciPy cannot read MATLAB v7.3
    (HDF5-based) files and rejects them with ``NotImplementedError``; a header
    it does not recognise raises ``ValueError``. Only in those cases is the
    ``mat73`` reader used as a fallback.

    Parameters
    ----------
    mat_file : str or path-like
        Path to the ``.mat`` file.

    Returns
    -------
    dict
        Variable name -> value, as returned by the reader that succeeded.

    Raises
    ------
    OSError
        E.g. ``FileNotFoundError``. Raised directly instead of being masked
        by a second, unrelated error from the fallback reader.
    """
    try:
        return sp_loadmat(mat_file)
    except (NotImplementedError, ValueError):
        # Unsupported/unknown format for SciPy: try the v7.3 (HDF5) reader.
        # If that fails too, Python chains both exceptions in the traceback.
        return mat73_loadmat(mat_file)
def generate_sims(C, k, alpha, sigma_a, sigma_s, lambda_, n_sim=100, tau=100, dt_total=11 / 85):
    """Simulate trajectories of a noisy, leaky evidence accumulator.

    Every entry of ``C`` is held for ``tau`` Euler sub-steps of length
    ``dt = dt_total / tau``; the accumulator starts at 0 and each sub-step adds

        (xi * Cs + lambda_ * a + eta * Cs) * dt + sigma_a * dW

    where ``Cs`` is the coherence-scaled input ``C * k``, ``xi`` is a per-trial
    offset (``xiR`` where ``Cs > 0``, ``xiL`` where ``Cs < 0``), ``eta`` is
    per-sub-step multiplicative input noise around 1 and ``dW`` is a Wiener
    increment with variance ``dt``.

    All random numbers come from NumPy's global legacy RNG and are drawn in
    the order xiR, xiL, dW, eta, so ``np.random.seed`` reproduces a run.

    Parameters
    ----------
    C : array_like
        Input per entry/frame. With a scalar ``k`` it is flattened
        (``np.repeat`` without ``axis``) into one row shared by all trials.
        With an ndarray ``k`` it must broadcast against ``k[:, np.newaxis]``,
        e.g. shape ``(n_frames,)`` or ``(len(k), n_frames)``.
    k : float or ndarray
        Coherence scaling of ``C``. If an ndarray, one value per trial and
        ``n_sim`` is replaced by ``len(k)``. Must be non-zero (``alpha / k``
        is computed); callers substitute a tiny value for 0.
    alpha : float
        The per-trial offsets have std ``alpha / k``; because they are
        multiplied by ``C * k``, their effective scale is ``alpha * C``.
    sigma_a : float
        Scale of the Wiener-noise term added to the accumulator.
    sigma_s : float
        ``eta`` has std ``sigma_s * sqrt(tau)`` per sub-step, i.e. its
        average over the ``tau`` sub-steps of one entry has std ``sigma_s``.
    lambda_ : float
        Feedback gain on ``a``: negative values pull ``a`` toward 0 (leak),
        positive values push it away.
    n_sim : int
        Number of trials when ``k`` is a scalar.
    tau : int
        Sub-steps per entry of ``C``.
    dt_total : float
        Duration represented by one entry of ``C``.

    Returns
    -------
    a : ndarray, shape (n_sim, T)
        Accumulator after each sub-step (initial zero dropped), where
        ``T = C_scaled.shape[-1]`` (``tau`` times the number of entries).
    mE : ndarray, shape (n_sim, T + 1)
        "Momentary evidence" ``eta * Cs * dt + lambda_ * a * dt`` (the
        ``xi`` and ``sigma_a`` terms are not included). Column 0 is zero and
        column ``t + 1`` belongs to sub-step ``t``, i.e. it is offset by one
        column relative to ``a``.
    tau : int
        The ``tau`` argument, echoed back.
    dt : float
        Sub-step length ``dt_total / tau``.
    """
    dt = dt_total / tau
    # discretize C: hold every entry for tau sub-steps, scaled by the coherence
    if isinstance(k, np.ndarray):
        # one coherence per trial: row i is scaled by k[i]
        C_scaled = np.repeat(C * k[:, np.newaxis], tau, axis=1)
        n_sim = len(k)
    else:
        # scalar coherence: a single (flattened) row, broadcast over trials below
        C_scaled = np.repeat(C * k, tau)[np.newaxis, :]
    T = C_scaled.shape[-1]
    # noise terms
    # per-trial offsets for positive / negative input; the 1/k cancels the k in C_scaled
    xiR = np.random.randn(n_sim) * alpha / k
    xiL = np.random.randn(n_sim) * alpha / k
    directional_noise = (
        xiR[:, np.newaxis] * (C_scaled > 0) +
        xiL[:, np.newaxis] * (C_scaled < 0)
    )
    # Wiener increments with variance dt
    dW = np.sqrt(dt) * np.random.randn(n_sim, T)
    # multiplicative input noise around 1 (sqrt(tau) keeps the per-entry average at std sigma_s)
    eta = 1 + np.random.randn(n_sim, T) * (sigma_s * np.sqrt(tau))
    # accumulated evidence
    a = np.zeros((n_sim, T + 1))
    mE = np.zeros((n_sim, T + 1))
    for t in range(T):
        # Euler step; (dt_total / tau) equals dt
        a[:, t + 1] = a[:, t] + (
            directional_noise[:, t] * C_scaled[:, t] * (dt_total / tau) +
            lambda_ * a[:, t] * (dt_total / tau) +
            sigma_a * dW[:, t] +
            eta[:, t] * C_scaled[:, t] * (dt_total / tau)
        )
        # momentary evidence (noisy input + feedback only) of sub-step t, stored at column t + 1
        mE[:, t+1] = eta[:, t] * C_scaled[:, t] * (dt_total / tau) + lambda_ * a[:, t] * (dt_total / tau)
    # a drops its initial zero column; mE keeps it (see Returns)
    return a[:, 1:], mE, tau, dt
def generate_sims_conditions(ks, directions, sim_parameters, num_sims_per_condition):
    """Simulate every (coherence, direction) pair and pool the trials.

    Parameters
    ----------
    ks : iterable of float
        Coherence values; each is passed to ``generate_sims`` as ``k``.
    directions : iterable of numbers
        Multipliers applied to ``sim_parameters['C']``. A direction equal to 1
        is labelled "right" (1); any other value is labelled "left" (0).
    sim_parameters : dict
        Keyword arguments for ``generate_sims``; must contain ``'C'``. The
        entries ``'C'``, ``'k'`` and ``'n_sim'`` are overridden per condition.
    num_sims_per_condition : int
        Trials simulated for each pair.

    Returns
    -------
    time : ndarray, shape (n_frames,)
        Frame indices ``0 .. n_frames - 1``.
    a_all : ndarray, shape (n_conditions * num_sims_per_condition, n_frames)
        Accumulator value at the last sub-step of every frame.
    mE_all : ndarray, same shape as ``a_all``
        Momentary evidence of that same sub-step, divided by ``dt``.
    k_idx_all : ndarray
        The coherence value ``k`` of every row (a value, not an index).
    choices : ndarray of int
        ``a_all > 0`` per frame: 1 is right, 0 is left.
    is_correct : ndarray of int
        1 where ``choices`` matches the row's direction label.

    Raises
    ------
    ValueError
        If ``ks`` or ``directions`` is empty.
    """
    base_C = np.asarray(sim_parameters['C'])
    a_all, mE_all, k_idx_all, direction_all = [], [], [], []
    for k, direction in product(ks, directions):
        a_temp, mE_temp, tau, dt = generate_sims(**{
            **sim_parameters,
            'C': base_C * direction,
            'k': k,
            'n_sim': num_sims_per_condition,
        })
        # Column i of `a_temp` is the state after sub-step i, whereas column i + 1
        # of `mE_temp` holds sub-step i (its column 0 is the initial zero).
        # Sampling a at tau-1::tau and mE at tau::tau picks the same (last)
        # sub-step of every frame, and never the leading zero even when tau == 1.
        a_all.append(a_temp[:, tau - 1::tau])
        mE_all.append(mE_temp[:, tau::tau] / dt)
        k_idx_all.extend([k] * num_sims_per_condition)
        direction_all.extend([1 if direction == 1 else 0] * num_sims_per_condition)
    if not a_all:
        raise ValueError("ks and directions must both be non-empty")
    a_all = np.vstack(a_all)
    mE_all = np.vstack(mE_all)
    k_idx_all = np.array(k_idx_all)
    direction_all = np.array(direction_all)
    choices = (a_all > 0).astype(int)  # 1 is right, 0 is left
    is_correct = (choices == direction_all[:, np.newaxis]).astype(int)
    # Derive the time axis from the sampled data rather than a leaked loop variable.
    time = np.arange(a_all.shape[1])
    return time, a_all, mE_all, k_idx_all, choices, is_correct
def plot_sims(C_size=11, num_sims=30 if not is_colab else 5):
    """Interactively plot accumulator trajectories for a single evidence pulse.

    A pulse of ones (sign chosen by a toggle) is embedded in ``C_size``
    frames of zeros, ``num_sims`` trials are simulated with ``generate_sims``,
    and each trajectory is coloured by the sign of its final value (C0 when
    positive, labelled "right"; C1 otherwise). The noise seed stays fixed
    unless "redraw noise" is selected, so parameter changes can be compared
    on identical noise.

    Parameters
    ----------
    C_size : int
        Number of frames in the stimulus. Also sets the x-range and the upper
        bound of the pulse-timing slider.
    num_sims : int
        Number of trajectories (30 by default, 5 on Colab).
    """
    setup_matplotlib_magic()

    def setup():
        # Build the figure with empty artists; update_plot fills in the data.
        fig, axes = plt.subplots(figsize=(6.5, 5))
        evidence_line = axes.plot([], [], color='C2', alpha=1)[0]
        sim_lines = [axes.plot([], [], color='C0', alpha=0.3)[0] for _ in range(num_sims)]
        axes.set(
            title=f"{num_sims} Simulations",
            ylabel="value",
            xlabel="time $t$",
            xlim=(0, C_size),  # follow the stimulus length instead of a fixed 11
            ylim=(-1.5, 1.5)
        )
        axes.axhline(0., color='black', alpha=0.3)
        fig.tight_layout()
        legend_elements = [
            Line2D([], [], color='C2', label='evidence pulse'),
            Line2D([], [], color='C0', label='accumulator $a$ (decision: right)'),
            Line2D([], [], color='C1', label='accumulator $a$ (decision: left)')
        ]
        axes.legend(handles=legend_elements, loc='upper right')
        return {'fig': fig, 'axes': axes, 'evidence_line': evidence_line, 'sim_lines': sim_lines}

    state = setup()
    state['random_seed'] = 42

    def update_plot(C_dir, C, k, alpha, sigma_a, sigma_s, lambda_, fixed_noise):
        maybe_setup(setup, state)
        if fixed_noise == 'redraw noise':
            # Explicit int64: the platform default integer may be 32-bit
            # (e.g. NumPy 1.x on Windows), where a high bound of 2**32 overflows.
            state['random_seed'] = np.random.randint(0, 2**32, dtype=np.int64)
        np.random.seed(state['random_seed'])
        # The range slider yields the (start, stop) frame indices of the pulse.
        start, stop = C
        C = np.concatenate([np.zeros(start), np.ones(stop - start), np.zeros(C_size - stop)])
        C *= 1 if C_dir == 'pulse right' else -1
        sims, *_ = generate_sims(C, k, alpha, sigma_a, sigma_s, lambda_, n_sim=num_sims)
        for sim, sim_line in zip(sims, state['sim_lines']):
            sim_line.set_data(np.linspace(0, len(C), len(sim)), sim)
            # Colour by the decision, i.e. the sign of the final accumulator value.
            sim_line.set_color('C0' if sim[-1] > 0 else 'C1')
        state['evidence_line'].set_data(np.linspace(0., len(C), len(C) * 1_000), np.repeat(C, 1_000) * k)
        draw_figure(state['fig'])

    style = {'description_width': '150px'}
    layout = Layout(width='600px')
    sliders = {
        'C_dir': ToggleButtons(options=['pulse left', 'pulse right'], value='pulse right', description=' '),
        'C': IntRangeSlider(min=0, max=C_size, value=[3, 7], description='evidence pulse timing', style=style, layout=layout, continuous_update=continuous_update),
        'k': FloatSlider(min=1e-6, max=1., step=0.01, value=0.5, description='coherence', style=style, layout=layout, continuous_update=continuous_update),
        'sigma_s': FloatSlider(min=0, max=3, step=0.01, value=0., description='fast noise (input)', style=style, layout=layout, continuous_update=continuous_update),
        'alpha': FloatSlider(min=0, max=1, step=0.01, value=0., description='slow noise (brain)', style=style, layout=layout, continuous_update=continuous_update),
        'sigma_a': FloatSlider(min=0, max=1, step=0.01, value=0., description='fast inner noise (brain)', style=style, layout=layout, continuous_update=continuous_update),
        'lambda_': FloatSlider(min=-5, max=5, step=0.01, value=0., description='leakiness', style=style, layout=layout, continuous_update=continuous_update),
        'fixed_noise': ToggleButtons(options=['fix noise', 'redraw noise'], value='fix noise', description=' '),
    }
    interact(update_plot, **sliders)
def plot_model_free_analysis_conditions(C, ks, num_sims_per_condition=2_000):
    """Interactive accuracy / psychophysical-kernel plot for simulated data.

    For every coherence in ``ks`` both directions (+1 and -1) are simulated
    with ``generate_sims_conditions`` each time "run simulations" is pressed.
    The left panel shows the fraction of correct choices per frame with a
    ``1.96 * std / sqrt(n)`` band. The right panel shows the mean momentary
    evidence on trials whose final choice is 1 minus the mean on all other
    trials.

    Parameters
    ----------
    C : array_like
        Stimulus template handed to the simulator (multiplied by each direction).
    ks : sequence of float
        Coherence values; one curve each.
    num_sims_per_condition : int
        Trials per (coherence, direction) pair.
    """
    setup_matplotlib_magic()

    def setup():
        fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharex=True)
        accuracy_ax, kernel_ax = axes
        accuracy_lines = [accuracy_ax.errorbar([], [], yerr=[], label=f'$k = {k}$') for k in ks]
        kernel_lines = [kernel_ax.plot([], [], label=f'$k = {k}$')[0] for k in ks]
        accuracy_ax.set(title="accuracy", xlabel="$t$", xlim=(0, len(C) - 1), ylim=(0, 1))
        accuracy_ax.legend(loc='lower right', fontsize='small')
        kernel_ax.set(title="psychophysical kernel", xlabel="$t$", ylim=(-3, 3))
        kernel_ax.legend(loc='lower left', fontsize='small')
        fig.tight_layout()
        return dict(fig=fig, axes=axes, accuracy_lines=accuracy_lines, kernel_lines=kernel_lines)

    # On Colab the figure is created lazily by maybe_setup on the first run.
    state = {'needs_setup': True} if is_colab else setup()

    def update_plot(sigma_s, alpha, sigma_a, lambda_):
        maybe_setup(setup, state)
        params = dict(C=C, sigma_s=sigma_s, alpha=alpha, sigma_a=sigma_a, lambda_=lambda_)
        time, _, mE_all, k_of_trial, choices, is_correct = generate_sims_conditions(
            ks, [1, -1], params, num_sims_per_condition
        )
        ended_right = choices[:, -1] == 1
        for acc_line, kernel_line, k in zip(state['accuracy_lines'], state['kernel_lines'], ks):
            in_condition = k_of_trial == k
            correct = is_correct[in_condition, :]
            ci95 = 1.96 * correct.std(axis=0) / np.sqrt(in_condition.sum())
            update_errorbar(acc_line, time, correct.mean(axis=0), yerr=ci95)
            kernel_line.set_data(
                time,
                mE_all[ended_right & in_condition].mean(axis=0)
                - mE_all[~ended_right & in_condition].mean(axis=0),
            )
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    style = {'description_width': '150px'}
    layout = Layout(width='600px')

    def float_slider(lo, hi, description):
        return FloatSlider(min=lo, max=hi, step=0.01, value=0., description=description, style=style, layout=layout)

    sliders = {
        'sigma_s': float_slider(0, 5, 'fast noise (input)'),
        'alpha': float_slider(0, 1, 'slow noise (brain)'),
        'sigma_a': float_slider(0, 2, 'fast inner noise (brain)'),
        'lambda_': float_slider(-5, 5, 'leakiness'),
    }
    run_manually = interact_manual.options(manual_name='run simulations')
    run_manually(update_plot, **sliders)
def model_free_analysis(dataset, k_indices=(1, 2, 3)):
    """Accuracy curves and psychophysical kernels per coherence condition.

    Parameters
    ----------
    dataset : mapping
        Must provide ``'choices'``, ``'direction'``, ``'a'``, ``'kIdx'`` and
        ``'mE'``. The indexing below treats ``'choices'`` as (time, trial)
        and ``'mE'``/``'a'`` as (trial, time); ``'direction'`` and ``'kIdx'``
        are flattened to one value per trial. (Layout inferred from the
        indexing; confirm against the data file.)
    k_indices : iterable
        Values of ``dataset['kIdx']`` to analyse, in output order. Defaults
        to ``(1, 2, 3)``, the labels previously hard-coded.

    Returns
    -------
    time : ndarray
        ``arange(dataset['a'].shape[1])``.
    perfs, ci95s, psy_kernels : list of ndarray
        One entry per value in ``k_indices``. ``perfs`` is the fraction of
        trials whose choice equals the direction at each time point;
        ``ci95s`` is ``1.96 * std / sqrt(n_trials)``; ``psy_kernels`` is the
        mean ``mE`` of trials whose last choice is 1 minus that of all others.
    """
    choices = dataset['choices']
    mE = dataset['mE']
    is_correct = choices == dataset['direction'].flatten()
    condition_labels = dataset['kIdx'].flatten()
    final_right = choices[-1, :] == 1
    time = np.arange(dataset['a'].shape[1])
    perfs, ci95s, psy_kernels = [], [], []
    for k_idx in k_indices:
        mask = condition_labels == k_idx
        is_corr_k = is_correct[:, mask]
        perfs.append(is_corr_k.mean(axis=1))
        ci95s.append(1.96 * is_corr_k.std(axis=1) / np.sqrt(mask.sum()))
        psy_kernels.append(
            mE[final_right & mask].mean(axis=0) - mE[~final_right & mask].mean(axis=0)
        )
    return time, perfs, ci95s, psy_kernels
def plot_model_free_analysis_conditions_vs_baseline(baseline_data, num_sims_per_condition=2_000):
    """Simulated accuracy / psychophysical kernels overlaid on a baseline dataset.

    The baseline curves (dashed, faint) come from ``model_free_analysis``
    applied to ``baseline_data``. Simulations use a fixed stimulus (one zero
    frame followed by ten ones) and coherences 0.2, 0.4 and 0.8, re-run
    whenever "run simulations" is pressed.

    Parameters
    ----------
    baseline_data : mapping
        Dataset accepted by ``model_free_analysis``; it must yield exactly
        three conditions (``zip(..., strict=True)`` against ``ks``).
    num_sims_per_condition : int
        Trials per (coherence, direction) pair.
    """
    setup_matplotlib_magic()
    C = np.concatenate(([0], np.ones(10)))
    ks = [0.2, 0.4, 0.8]
    # The baseline does not depend on the sliders: analyse it once here rather
    # than inside setup(), which Colab re-runs on every update.
    baseline_time, baseline_perfs, baseline_ci95s, baseline_kernels = model_free_analysis(baseline_data)

    def setup():
        fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharex=True)
        accuracy_lines = [axes[0].errorbar([], [], yerr=[], label=f'$k = {k}$') for k in ks]
        kernel_lines = [axes[1].plot([], [], label=f'$k = {k}$')[0] for k in ks]
        axes[0].set(
            title="accuracy",
            xlabel="$t$",
            xlim=(0, len(C) - 1),
            ylim=(0, 1)
        )
        axes[1].set(
            title="psychophysical kernel",
            xlabel="$t$",
            ylim=(-3, 3)
        )
        baseline = zip(baseline_perfs, baseline_ci95s, baseline_kernels, strict=True)
        for i, (perf, ci95, psy_kernel) in enumerate(baseline):
            axes[0].errorbar(baseline_time, perf, yerr=ci95, color=f'C{i}', label=f'$k = {ks[i]}$ (baseline)', linestyle='--', alpha=0.3)
            axes[1].plot(baseline_time, psy_kernel, color=f'C{i}', label=f'$k = {ks[i]}$ (baseline)', linestyle='--', alpha=0.3)
        axes[0].legend(loc='lower right', fontsize='small')
        axes[1].legend(loc='lower left', fontsize='small')
        fig.tight_layout()
        return {'fig': fig, 'axes': axes, 'accuracy_lines': accuracy_lines, 'kernel_lines': kernel_lines}

    state = setup() if not is_colab else {'needs_setup': True}

    def update_plot(sigma_s, alpha, sigma_a, lambda_):
        maybe_setup(setup, state)
        sim_parameters = {
            'C': C,
            'sigma_s': sigma_s,
            'alpha': alpha,
            'sigma_a': sigma_a,
            'lambda_': lambda_
        }
        directions = [1, -1]
        time, a_all, mE_all, k_idx_all, choices, is_correct = generate_sims_conditions(
            ks, directions, sim_parameters, num_sims_per_condition
        )
        for i, k in enumerate(ks):
            mask = (k_idx_all == k)
            is_corr_k = is_correct[mask, :]
            perf = is_corr_k.mean(axis=0)
            ci95 = 1.96 * is_corr_k.std(axis=0) / np.sqrt(mask.sum())
            update_errorbar(state['accuracy_lines'][i], time, perf, yerr=ci95)
            # kernel: mean evidence on final-"right" trials minus all other trials
            psy_kernel = (
                mE_all[(choices[:, -1] == 1) & mask].mean(axis=0) -
                mE_all[(choices[:, -1] != 1) & mask].mean(axis=0)
            )
            state['kernel_lines'][i].set_data(time, psy_kernel)
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    style = {'description_width': '150px'}
    layout = Layout(width='600px')
    sliders = {
        'sigma_s': FloatSlider(min=0, max=5, step=0.01, value=0., description='fast noise (input)', style=style, layout=layout),
        'alpha': FloatSlider(min=0, max=1, step=0.01, value=0., description='slow noise (brain)', style=style, layout=layout),
        'sigma_a': FloatSlider(min=0, max=2, step=0.01, value=0., description='fast inner noise (brain)', style=style, layout=layout),
        'lambda_': FloatSlider(min=-5, max=5, step=0.01, value=0., description='leakiness', style=style, layout=layout)
    }
    interact_manual.options(manual_name='run simulations')(
        update_plot,
        **sliders
    )
def bin_spikes(raw_spike_matrix, bin_size=50):
    """Sum a 3-D spike array over consecutive, non-overlapping bins of axis 1.

    Samples beyond the last complete bin are dropped.

    Parameters
    ----------
    raw_spike_matrix : ndarray, 3-D
        Axis 1 is binned. Callers in this module index it as
        (trial, sample, neuron) — presumably; confirm against the data file.
    bin_size : int
        Number of samples per bin; must be positive.

    Returns
    -------
    ndarray, shape (shape[0], shape[1] // bin_size, shape[2])

    Raises
    ------
    ValueError
        If ``bin_size`` is not positive.
    """
    if bin_size <= 0:
        raise ValueError(f"bin_size must be positive, got {bin_size}")
    n_rows, n_samples, n_cols = raw_spike_matrix.shape
    num_bins = n_samples // bin_size
    truncated = raw_spike_matrix[:, :num_bins * bin_size, :]
    # Spell out bin_size instead of -1: with num_bins == 0 the array is empty
    # and NumPy cannot infer a -1 dimension.
    return truncated.reshape(n_rows, num_bins, bin_size, n_cols).sum(axis=2)
def get_binned_spike_matrix(mat_data, bin_size=50):
    """Cut out, bin and square-root-transform the raw spike matrix.

    Samples 149..999 along axis 1 of ``mat_data['RawSpikeMatrix1']`` are
    summed into bins of ``bin_size`` samples with ``bin_spikes`` and then
    square-rooted.

    Parameters
    ----------
    mat_data : mapping
        Loaded ``.mat`` contents with a 3-D ``'RawSpikeMatrix1'``, presumably
        (trial, sample, neuron); confirm against the dataset.
    bin_size : int
        Samples per bin (default 50, the value used so far). It is used both
        for binning and for the time axis, so the two always agree.

    Returns
    -------
    time : ndarray
        Start of each bin in samples, relative to sample 149. Callers label
        this axis in ms, which assumes one sample per ms.
    binned_spike_matrix : ndarray
        sqrt of the binned counts, shape (shape[0], n_bins, shape[2]).
    """
    raw_spike_matrix = mat_data['RawSpikeMatrix1'][:, 149:1000, :]
    binned_spike_matrix = np.sqrt(bin_spikes(raw_spike_matrix, bin_size=bin_size))
    time = np.arange(binned_spike_matrix.shape[1]) * bin_size
    return time, binned_spike_matrix
def plot_single_neuron(mat_data):
    """Interactive plot of one neuron's trial-averaged sqrt spike count.

    A slider selects the neuron (index along axis 2 of the binned matrix);
    the trace is the mean over axis 0 with a ``1.96 * SEM`` error band.
    """
    setup_matplotlib_magic()
    time, binned_spike_matrix = get_binned_spike_matrix(mat_data)

    def setup():
        fig, axes = plt.subplots(figsize=(6.5, 4.5))
        neuron_line = axes.errorbar(
            [np.nan], [np.nan], yerr=[np.nan], capsize=4., label='mean firing rate with 95% CI'
        )
        axes.set(ylabel=r'$\sqrt{N_\mathrm{spikes}}$', xlabel='time [ms]', xlim=(0, 800))
        axes.legend(loc='upper right', fontsize='small')
        return {'fig': fig, 'axes': axes, 'neuron_line': neuron_line}

    state = setup()

    def update_plot(neuron_idx):
        maybe_setup(setup, state)
        responses = binned_spike_matrix[:, :, neuron_idx]
        update_errorbar(
            state['neuron_line'], time, responses.mean(axis=0), yerr=1.96 * sem(responses, axis=0)
        )
        ax = state['axes']
        ax.relim()
        ax.autoscale(axis='y')
        ax.set_title(f'Neuron #{neuron_idx}')
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    neuron_slider = IntSlider(
        min=0,
        max=binned_spike_matrix.shape[2] - 1,
        description='neuron #',
        layout=Layout(width='800px'),
        continuous_update=continuous_update,
    )
    interact(update_plot, neuron_idx=neuron_slider)
def plot_neuron_by_choice(mat_data):
    """Interactive per-neuron response split by choice and by outcome.

    Left panel: correct trials (``targ_cho == targ_cor``); right panel:
    incorrect trials. Each panel has one trace for ``targ_cho == 1``
    (labelled "right choice") and one for all other trials (labelled "left
    choice"), each showing the trial-mean sqrt spike count with a
    ``1.96 * SEM`` band.
    """
    setup_matplotlib_magic()
    time, binned_spike_matrix = get_binned_spike_matrix(mat_data)
    chosen = mat_data['targ_cho'].flatten()
    correct_trials_mask = (chosen == mat_data['targ_cor'].flatten())
    right_choice = (chosen == 1)
    labels = ['right choice (95% CI)', 'left choice (95% CI)']

    def setup():
        fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True)

        def empty_errorbars(ax):
            # Placeholder artists; update_plot fills in the data.
            return [
                ax.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=label)
                for label in labels
            ]

        correct_lines = empty_errorbars(axes[0])
        incorrect_lines = empty_errorbars(axes[1])
        axes[0].set(
            title='correct trials',
            ylabel=r'$\sqrt{N_\mathrm{spikes}}$',
            xlabel='time [ms]',
            xlim=(0, 800)
        )
        axes[1].set(title='incorrect trials', xlabel='time [ms]')
        for ax in axes:
            ax.legend(loc='upper right', fontsize='small')
        return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines}

    state = setup()

    # (line group in state, index within the group, trials to average)
    traces = (
        ('correct_lines', 0, correct_trials_mask & right_choice),
        ('correct_lines', 1, correct_trials_mask & ~right_choice),
        ('incorrect_lines', 0, ~correct_trials_mask & right_choice),
        ('incorrect_lines', 1, ~correct_trials_mask & ~right_choice),
    )

    def update_plot(neuron_idx):
        maybe_setup(setup, state)
        for group, position, trial_mask in traces:
            responses = binned_spike_matrix[trial_mask][:, :, neuron_idx]
            update_errorbar(
                state[group][position],
                time,
                responses.mean(axis=0),
                sem(responses, axis=0) * 1.96
            )
        for ax in state['axes']:
            ax.relim()
        for ax in state['axes']:
            ax.autoscale(axis='y')
        state['fig'].suptitle(f'Neuron #{neuron_idx}')
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    neuron_slider = IntSlider(
        min=0,
        max=binned_spike_matrix.shape[2] - 1,
        description='neuron #',
        layout=Layout(width='800px'),
        continuous_update=continuous_update,
    )
    interact(update_plot, neuron_idx=neuron_slider)
def plot_neuron_by_coherence(mat_data):
    """Interactive per-neuron response split by motion coherence and outcome.

    Three coherence levels are shown: indices 0, 3 and 5 of the sorted
    unique ``mat_data['dot_coh']`` values, so at least six distinct levels
    are required. Each level is drawn as the trial-mean sqrt spike count
    with a ``1.96 * SEM`` band, separately for correct (left panel,
    ``targ_cho == targ_cor``) and incorrect (right panel) trials.
    """
    setup_matplotlib_magic()
    time, binned_spike_matrix = get_binned_spike_matrix(mat_data)
    correct_trials_mask = (mat_data['targ_cho'].flatten() == mat_data['targ_cor'].flatten())
    dot_coh = mat_data['dot_coh'].flatten()
    # np.unique already returns the values sorted.
    coherences = np.unique(dot_coh)[[0, 3, 5]]

    def setup():
        fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True)
        correct_lines = [
            axes[0].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=f'{coherence = :.1%} (95% CI)')
            for coherence in coherences
        ]
        incorrect_lines = [
            axes[1].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=f'{coherence = :.1%} (95% CI)')
            for coherence in coherences
        ]
        axes[0].set(
            title='correct trials',
            ylabel=r'$\sqrt{N_\mathrm{spikes}}$',
            xlabel='time [ms]',
            xlim=(0, 800)
        )
        axes[1].set(
            title='incorrect trials',
            xlabel='time [ms]'
        )
        axes[0].legend(loc='upper right', fontsize='small')
        axes[1].legend(loc='upper right', fontsize='small')
        return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines}

    state = setup()

    def update_plot(neuron_idx):
        maybe_setup(setup, state)
        for i, coherence in enumerate(coherences):
            coherence_mask = dot_coh == coherence
            panels = (
                ('correct_lines', correct_trials_mask & coherence_mask),
                ('incorrect_lines', ~correct_trials_mask & coherence_mask),
            )
            for line_key, trial_mask in panels:
                responses = binned_spike_matrix[trial_mask][:, :, neuron_idx]
                # Both panels plot the mean with a 1.96 * SEM error band.
                update_errorbar(
                    state[line_key][i],
                    time,
                    responses.mean(axis=0),
                    sem(responses, axis=0) * 1.96
                )
        state['axes'][0].relim()
        state['axes'][1].relim()
        state['axes'][0].autoscale(axis='y')
        state['axes'][1].autoscale(axis='y')
        state['fig'].suptitle(f'Neuron #{neuron_idx}')
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    sliders = {
        'neuron_idx': IntSlider(min=0, max=binned_spike_matrix.shape[2] - 1, description='neuron #', layout=Layout(width='800px'), continuous_update=continuous_update)
    }
    interact(update_plot, **sliders)
def calculate_deltas(mat_data):
    """Per-neuron choice preference Δ.

    Δ is the trapezoid-integrated (over bins, unit spacing) trial-mean sqrt
    spike count on trials with ``targ_cho == 1`` minus the same quantity on
    all other trials. Positive Δ means a stronger response on ``targ_cho == 1``
    trials.

    Returns
    -------
    ndarray
        One Δ per index along the last axis of the binned spike matrix.
    """
    _, binned_spike_matrix = get_binned_spike_matrix(mat_data)
    chose_right = mat_data['targ_cho'].flatten() == 1
    area_right = trapezoid(binned_spike_matrix[chose_right].mean(axis=0), axis=0)
    area_left = trapezoid(binned_spike_matrix[~chose_right].mean(axis=0), axis=0)
    return area_right - area_left
def plot_deltas(deltas):
    """Histogram the choice preferences Δ (left) and their magnitudes |Δ| (right).

    Fixed histogram ranges are used: 16 bins over [-4, 4] for Δ and 15 bins
    over [0, 4.2] for |Δ|; values outside those ranges are not counted.
    """
    setup_matplotlib_magic()
    fig, (signed_ax, abs_ax) = plt.subplots(1, 2, figsize=(8, 4), sharey=True)
    signed_ax.hist(deltas, bins=16, range=(-4, 4))
    abs_ax.hist(np.abs(deltas), bins=15, range=(0, 4.2))
    signed_ax.set(ylabel='counts', xlabel=r'$\Delta$')
    abs_ax.set(xlabel=r'|$\Delta$|')
    fig.tight_layout()
def plot_aggregated_neurons(mat_data):
    """Interactive population average of choice-conditioned responses.

    Each neuron's trial-mean traces for ``targ_cho == 1`` ("right choice")
    and all other trials ("left choice") are multiplied by ``sign(Δ)``
    (Δ from ``calculate_deltas``). Both traces are then averaged over the
    neurons whose ``|Δ|`` exceeds the slider threshold.

    NOTE(review): multiplying by sign(Δ) makes the traces of neurons with
    Δ < 0 negative instead of swapping their right/left traces; confirm this
    is the intended alignment.
    """
    setup_matplotlib_magic()
    time, binned_spike_matrix = get_binned_spike_matrix(mat_data)
    right_choice = (mat_data['targ_cho'].flatten() == 1)
    mean_spikes_right = binned_spike_matrix[right_choice].mean(axis=0)
    mean_spikes_left = binned_spike_matrix[~right_choice].mean(axis=0)
    deltas = calculate_deltas(mat_data)
    abs_deltas = np.abs(deltas)
    # The sign flip does not depend on the slider, so apply it once.
    signed_traces = (
        mean_spikes_right * np.sign(deltas),
        mean_spikes_left * np.sign(deltas),
    )

    def setup():
        fig, axes = plt.subplots()
        lines = [axes.plot([], [], label=label)[0] for label in ('right choice', 'left choice')]
        axes.set(ylabel=r'$\sqrt{N_\mathrm{spikes}}$', xlabel='time [ms]', xlim=(0, 800))
        axes.legend(loc='upper right')
        return {'fig': fig, 'axes': axes, 'lines': lines}

    state = setup()

    def update_plot(delta_threshold):
        maybe_setup(setup, state)
        selected = abs_deltas > delta_threshold
        for line, traces in zip(state['lines'], signed_traces):
            line.set_data(time, traces[:, selected].mean(axis=1))
        ax = state['axes']
        ax.relim()
        ax.autoscale(axis='y')
        ax.set(title=f'|Δ| > {delta_threshold:.2f}')
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    # Keep the maximum just below the largest |Δ| so at least one neuron stays selected.
    threshold_slider = FloatSlider(
        min=0,
        max=abs_deltas.max() - 1e-3,
        description='threshold |Δ|',
        layout=Layout(width='800px'),
        continuous_update=continuous_update,
    )
    interact(update_plot, delta_threshold=threshold_slider)
def simulate_conditions(mat_data, alpha, sigma_a, sigma_s, lambda_):
    """Simulate the recorded trial set; average correct-trial accumulators by direction.

    One simulated trial is run per recorded trial, using that trial's
    coherence (``dot_coh``) and direction (``dot_dir``). The stimulus is one
    zero frame followed by 16 frames of evidence. The simulated choice is 1
    when the final accumulator value is positive and 2 otherwise, and it is
    compared with ``targ_cor`` to mark correct trials.

    Parameters
    ----------
    mat_data : mapping
        Needs ``'dot_coh'``, ``'dot_dir'`` (0 maps to +1, 180 to -1; other
        codes pass through unchanged) and ``'targ_cor'`` (1 or 2).
    alpha, sigma_a, sigma_s, lambda_ : float
        Noise and leak parameters forwarded to ``generate_sims``.

    Returns
    -------
    list of ndarray
        Mean accumulator trajectory over correct trials (one value per
        frame), one per distinct direction code in ascending order (so
        -1 / left comes before +1 / right).
    """
    # astype(float) copies, so mat_data is never modified, and it allows the
    # fractional / negative values assigned below even for integer-typed MAT data.
    dot_coh = mat_data['dot_coh'].flatten().astype(float)
    dot_dir = mat_data['dot_dir'].flatten()
    targ_cor = mat_data['targ_cor'].flatten()
    C = np.array([0] + [1] * 16)
    # generate_sims divides by k; a tiny stand-in keeps 0%-coherence trials finite.
    dot_coh[dot_coh == 0] = 1e-12
    # map directions: 0 -> 1 (right), 180 -> -1 (left)
    d = dot_dir.astype(float)
    d[dot_dir == 0] = 1
    d[dot_dir == 180] = -1
    a, _, tau, _ = generate_sims(np.outer(d, C), dot_coh, alpha, sigma_a, sigma_s, lambda_)
    # keep the value at the last sub-step of every frame
    a = a[:, tau - 1::tau]
    cho = np.where(a[:, -1] > 0, 1, 2)  # 1 is right, 2 is left
    is_correct = cho == targ_cor
    a_correct = a[is_correct, :]
    d_correct = d[is_correct]
    return [np.mean(a_correct[d_correct == direction, :], axis=0) for direction in np.unique(d)]
def plot_sims_conditions(mat_data):
    """Interactively plot simulated mean accumulator traces for the recorded trials.

    ``simulate_conditions`` is re-run with the slider parameters on every
    change. Its per-direction means (ascending direction code) are reversed
    and drawn as the "right choice" and "left choice" lines. The noise seed
    stays fixed unless "redraw noise" is selected.
    """
    setup_matplotlib_magic()

    def setup():
        fig, axes = plt.subplots(figsize=(6.5, 5))
        sim_lines = [axes.plot([], [], label=choice)[0] for choice in ['right choice', 'left choice']]
        axes.set(
            ylabel="mean $a$",
            xlabel="time $t$",
            xlim=(0, 800),
            ylim=(-0.5, .5)
        )
        axes.legend(loc='upper right')
        fig.tight_layout()
        return {'fig': fig, 'axes': axes, 'sim_lines': sim_lines}

    state = setup()
    state['random_seed'] = 42

    def update_plot(alpha, sigma_a, sigma_s, lambda_, fixed_noise):
        maybe_setup(setup, state)
        if fixed_noise == 'redraw noise':
            # Explicit int64: the platform default integer may be 32-bit
            # (e.g. NumPy 1.x on Windows), where a high bound of 2**32 overflows.
            state['random_seed'] = np.random.randint(0, 2**32, dtype=np.int64)
        np.random.seed(state['random_seed'])
        means_a = simulate_conditions(mat_data, alpha, sigma_a, sigma_s, lambda_)
        # Reverse so the +1 (right) direction goes to the first ("right choice") line.
        for mean_a, line in zip(means_a[::-1], state['sim_lines'], strict=True):
            # x spacing of 50 per frame, presumably to match the 50 ms bins of the neural plots
            line.set_data(np.arange(len(mean_a)) * 50, mean_a)
        state['axes'].relim()
        state['axes'].autoscale(axis='y')
        state['fig'].tight_layout()
        draw_figure(state['fig'])

    style = {'description_width': '150px'}
    layout = Layout(width='600px')
    sliders = {
        'sigma_s': FloatSlider(min=0, max=3, step=0.01, value=0., description='fast noise (input)', style=style, layout=layout, continuous_update=continuous_update),
        'alpha': FloatSlider(min=0, max=1, step=0.01, value=0., description='slow noise (brain)', style=style, layout=layout, continuous_update=continuous_update),
        'sigma_a': FloatSlider(min=0, max=1, step=0.01, value=0., description='fast inner noise (brain)', style=style, layout=layout, continuous_update=continuous_update),
        'lambda_': FloatSlider(min=-5, max=5, step=0.01, value=0., description='leakiness', style=style, layout=layout, continuous_update=continuous_update),
        'fixed_noise': ToggleButtons(options=['fix noise', 'redraw noise'], value='fix noise', description=' '),
    }
    interact(update_plot, **sliders)