Add errorbars

imperator 1 year ago
parent 0f2a06a74b
commit 5aa7ad89de

@ -1,6 +1,7 @@
from itertools import product from itertools import product
import numpy as np import numpy as np
from scipy.integrate import trapezoid from scipy.integrate import trapezoid
from scipy.stats import sem
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.lines import Line2D from matplotlib.lines import Line2D
from IPython.display import display from IPython.display import display
@ -36,6 +37,22 @@ def draw_figure(fig):
plt.show() plt.show()
def update_errorbar(err_container, x, y, yerr):
err_container.lines[0].set_data(x, y)
linecol = err_container.lines[2][0]
segments = []
for xi, yi, yerri in zip(x, y, yerr):
segments.append([[xi, yi - yerri], [xi, yi + yerri]])
linecol.set_segments(segments)
if len(err_container.lines[1]) == 2:
lower_caps, upper_caps = err_container.lines[1]
lower_caps.set_data(x, y - yerr)
upper_caps.set_data(x, y + yerr)
def maybe_setup(setup_fun, state): def maybe_setup(setup_fun, state):
if not is_colab: if not is_colab:
return return
@ -202,17 +219,6 @@ def plot_sims(C_size=11, num_sims=30 if not is_colab else 5):
interact(update_plot, **sliders) interact(update_plot, **sliders)
def update_errorbar(err_container, x, y, yerr):
err_container.lines[0].set_data(x, y)
linecol = err_container.lines[2][0]
segments = []
for xi, yi, yerri in zip(x, y, yerr):
segments.append([[xi, yi - yerri], [xi, yi + yerri]])
linecol.set_segments(segments)
def plot_model_free_analysis_conditions(C, ks, num_sims_per_condition=2_000): def plot_model_free_analysis_conditions(C, ks, num_sims_per_condition=2_000):
setup_matplotlib_magic() setup_matplotlib_magic()
@ -437,7 +443,7 @@ def plot_single_neuron(mat_data):
def setup(): def setup():
fig, axes = plt.subplots(figsize=(6.5, 4.5)) fig, axes = plt.subplots(figsize=(6.5, 4.5))
neuron_line = axes.plot([], [])[0] neuron_line = axes.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label='mean firing rate with 95% CI')
axes.set( axes.set(
ylabel=r'$\sqrt{N_\mathrm{spikes}}$', ylabel=r'$\sqrt{N_\mathrm{spikes}}$',
@ -445,6 +451,8 @@ def plot_single_neuron(mat_data):
xlim=(0, 800) xlim=(0, 800)
) )
axes.legend(loc='upper right', fontsize='small')
return {'fig': fig, 'axes': axes, 'neuron_line': neuron_line} return {'fig': fig, 'axes': axes, 'neuron_line': neuron_line}
state = setup() state = setup()
@ -452,11 +460,16 @@ def plot_single_neuron(mat_data):
def update_plot(neuron_idx): def update_plot(neuron_idx):
maybe_setup(setup, state) maybe_setup(setup, state)
state['neuron_line'].set_data(time, binned_spike_matrix.mean(axis=0)[:, neuron_idx]) update_errorbar(
state['neuron_line'],
time,
binned_spike_matrix[:, :, neuron_idx].mean(axis=0),
yerr=sem(binned_spike_matrix[:, :, neuron_idx], axis=0) * 1.96
)
state['axes'].relim() state['axes'].relim()
state['axes'].autoscale(axis='y') state['axes'].autoscale(axis='y')
state['axes'].set_title(f'Neuron #{neuron_idx}', fontsize='small') state['axes'].set_title(f'Neuron #{neuron_idx}')
state['fig'].tight_layout() state['fig'].tight_layout()
draw_figure(state['fig']) draw_figure(state['fig'])
@ -478,15 +491,15 @@ def plot_neuron_by_choice(mat_data):
def setup(): def setup():
fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True) fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True)
choices = ['right choice', 'left choice'] choices = ['right choice (95% CI)', 'left choice (95% CI)']
correct_lines = [] correct_lines = []
for choice in choices: for choice in choices:
correct_line = axes[0].plot([], [], label=choice)[0] correct_line = axes[0].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=choice)
correct_lines += [correct_line] correct_lines += [correct_line]
incorrect_lines = [] incorrect_lines = []
for choice in choices: for choice in choices:
incorrect_line = axes[1].plot([], [], label=choice)[0] incorrect_line = axes[1].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=choice)
incorrect_lines += [incorrect_line] incorrect_lines += [incorrect_line]
axes[0].set( axes[0].set(
@ -499,8 +512,8 @@ def plot_neuron_by_choice(mat_data):
title='incorrect trials', title='incorrect trials',
xlabel='time [ms]' xlabel='time [ms]'
) )
axes[0].legend(loc='upper right') axes[0].legend(loc='upper right', fontsize='small')
axes[1].legend(loc='upper right') axes[1].legend(loc='upper right', fontsize='small')
return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines} return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines}
@ -509,16 +522,36 @@ def plot_neuron_by_choice(mat_data):
def update_plot(neuron_idx): def update_plot(neuron_idx):
maybe_setup(setup, state) maybe_setup(setup, state)
state['correct_lines'][0].set_data(time, binned_spike_matrix[correct_trials_mask & right_choice].mean(axis=0)[:, neuron_idx]) update_errorbar(
state['correct_lines'][1].set_data(time, binned_spike_matrix[correct_trials_mask & ~right_choice].mean(axis=0)[:, neuron_idx]) state['correct_lines'][0],
state['incorrect_lines'][0].set_data(time, binned_spike_matrix[~correct_trials_mask & right_choice].mean(axis=0)[:, neuron_idx]) time,
state['incorrect_lines'][1].set_data(time, binned_spike_matrix[~correct_trials_mask & ~right_choice].mean(axis=0)[:, neuron_idx]) binned_spike_matrix[correct_trials_mask & right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[correct_trials_mask & right_choice][:, :, neuron_idx], axis=0) * 1.96
)
update_errorbar(
state['correct_lines'][1],
time,
binned_spike_matrix[correct_trials_mask & ~right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[correct_trials_mask & ~right_choice][:, :, neuron_idx], axis=0) * 1.96
)
update_errorbar(
state['incorrect_lines'][0],
time,
binned_spike_matrix[~correct_trials_mask & right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[~correct_trials_mask & right_choice][:, :, neuron_idx], axis=0) * 1.96
)
update_errorbar(
state['incorrect_lines'][1],
time,
binned_spike_matrix[~correct_trials_mask & ~right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[~correct_trials_mask & ~right_choice][:, :, neuron_idx], axis=0) * 1.96
)
state['axes'][0].relim() state['axes'][0].relim()
state['axes'][1].relim() state['axes'][1].relim()
state['axes'][0].autoscale(axis='y') state['axes'][0].autoscale(axis='y')
state['axes'][1].autoscale(axis='y') state['axes'][1].autoscale(axis='y')
state['fig'].suptitle(f'Neuron #{neuron_idx}', fontsize='small') state['fig'].suptitle(f'Neuron #{neuron_idx}')
state['fig'].tight_layout() state['fig'].tight_layout()
draw_figure(state['fig']) draw_figure(state['fig'])
@ -546,12 +579,12 @@ def plot_neuron_by_coherence(mat_data):
choices = ['right choice', 'left choice'] choices = ['right choice', 'left choice']
correct_lines = [] correct_lines = []
for coherence in coherences: for coherence in coherences:
correct_line = axes[0].plot([], [], label=f'{coherence = :.1%}')[0] correct_line = axes[0].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=f'{coherence = :.1%} (95% CI)')
correct_lines += [correct_line] correct_lines += [correct_line]
incorrect_lines = [] incorrect_lines = []
for coherence in coherences: for coherence in coherences:
incorrect_line = axes[1].plot([], [], label=f'{coherence = :.1%}')[0] incorrect_line = axes[1].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=f'{coherence = :.1%} (95% CI)')
incorrect_lines += [incorrect_line] incorrect_lines += [incorrect_line]
axes[0].set( axes[0].set(
@ -564,8 +597,8 @@ def plot_neuron_by_coherence(mat_data):
title='incorrect trials', title='incorrect trials',
xlabel='time [ms]' xlabel='time [ms]'
) )
axes[0].legend(loc='upper right') axes[0].legend(loc='upper right', fontsize='small')
axes[1].legend(loc='upper right') axes[1].legend(loc='upper right', fontsize='small')
return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines} return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines}
@ -577,14 +610,25 @@ def plot_neuron_by_coherence(mat_data):
for i, coherence in enumerate(coherences): for i, coherence in enumerate(coherences):
coherence_mask = (mat_data['dot_coh'].flatten() == coherence) coherence_mask = (mat_data['dot_coh'].flatten() == coherence)
state['correct_lines'][i].set_data(time, binned_spike_matrix[correct_trials_mask & coherence_mask].mean(axis=0)[:, neuron_idx]) update_errorbar(
state['incorrect_lines'][i].set_data(time, binned_spike_matrix[~correct_trials_mask & coherence_mask].mean(axis=0)[:, neuron_idx]) state['correct_lines'][i],
time,
binned_spike_matrix[correct_trials_mask & coherence_mask][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[correct_trials_mask & coherence_mask][:, :, neuron_idx], axis=0) * 1.96
)
update_errorbar(
state['incorrect_lines'][i],
time,
binned_spike_matrix[~correct_trials_mask & coherence_mask][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[~correct_trials_mask & coherence_mask][:, :, neuron_idx], axis=0) * 1.96
)
state['axes'][0].relim() state['axes'][0].relim()
state['axes'][1].relim() state['axes'][1].relim()
state['axes'][0].autoscale(axis='y') state['axes'][0].autoscale(axis='y')
state['axes'][1].autoscale(axis='y') state['axes'][1].autoscale(axis='y')
state['fig'].suptitle(f'Neuron #{neuron_idx}', fontsize='small') state['fig'].suptitle(f'Neuron #{neuron_idx}')
state['fig'].tight_layout() state['fig'].tight_layout()
draw_figure(state['fig']) draw_figure(state['fig'])
@ -637,13 +681,14 @@ def plot_aggregated_neurons(mat_data):
mean_spikes_left = binned_spike_matrix[~right_choice].mean(axis=0) mean_spikes_left = binned_spike_matrix[~right_choice].mean(axis=0)
deltas = calculate_deltas(mat_data) deltas = calculate_deltas(mat_data)
print(deltas.shape)
def setup(): def setup():
fig, axes = plt.subplots() fig, axes = plt.subplots()
lines = [ lines = [
axes.plot([], [], label='right choice')[0], axes.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label='right choice'),
axes.plot([], [], label='left choice')[0] axes.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label='left choice')
] ]
axes.set( axes.set(
@ -660,8 +705,18 @@ def plot_aggregated_neurons(mat_data):
def update_plot(delta_threshold): def update_plot(delta_threshold):
maybe_setup(setup, state) maybe_setup(setup, state)
state['lines'][0].set_data(time, (mean_spikes_right * np.sign(deltas))[:, np.abs(deltas) > delta_threshold].mean(axis=1)) update_errorbar(
state['lines'][1].set_data(time, (mean_spikes_left * np.sign(deltas))[:, np.abs(deltas) > delta_threshold].mean(axis=1)) state['lines'][0],
time,
(binned_spike_matrix[right_choice] * np.sign(deltas))[:, :, np.abs(deltas) > delta_threshold].mean(axis=(0, 2)),
sem((binned_spike_matrix[right_choice] * np.sign(deltas))[:, :, np.abs(deltas) > delta_threshold], axis=(0, 2)) * 1.96
)
update_errorbar(
state['lines'][1],
time,
(binned_spike_matrix[~right_choice] * np.sign(deltas))[:, :, np.abs(deltas) > delta_threshold].mean(axis=(0, 2)),
sem((binned_spike_matrix[~right_choice] * np.sign(deltas))[:, :, np.abs(deltas) > delta_threshold], axis=(0, 2)) * 1.96
)
state['axes'].relim() state['axes'].relim()
state['axes'].autoscale(axis='y') state['axes'].autoscale(axis='y')
@ -716,11 +771,14 @@ def simulate_conditions(mat_data, alpha, sigma_a, sigma_s, lambda_):
unq_dir = np.unique(d) unq_dir = np.unique(d)
means_a = [] means_a = []
sems_a = []
for dir_ in unq_dir: for dir_ in unq_dir:
mean_a = np.mean(a_Cor[d_Cor == dir_, :], axis=0) mean_a = np.mean(a_Cor[d_Cor == dir_, :], axis=0)
sem_a = sem(a_Cor[d_Cor == dir_, :], axis=0)
means_a += [mean_a] means_a += [mean_a]
sems_a += [sem_a]
return means_a return means_a, sems_a
def plot_sims_conditions(mat_data): def plot_sims_conditions(mat_data):
@ -729,10 +787,9 @@ def plot_sims_conditions(mat_data):
def setup(): def setup():
fig, axes = plt.subplots(figsize=(6.5, 5)) fig, axes = plt.subplots(figsize=(6.5, 5))
evidence_line = axes.plot([], [], color='C2', alpha=1)[0]
sim_lines = [] sim_lines = []
for choice in ['right choice', 'left choice']: for choice in ['right choice', 'left choice']:
sim_line = axes.plot([], [], label=choice)[0] sim_line = axes.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=choice)
sim_lines += [sim_line] sim_lines += [sim_line]
axes.set( axes.set(
@ -757,10 +814,15 @@ def plot_sims_conditions(mat_data):
state['random_seed'] = np.random.randint(0, 2**32) state['random_seed'] = np.random.randint(0, 2**32)
np.random.seed(state['random_seed']) np.random.seed(state['random_seed'])
means_a = simulate_conditions(mat_data, alpha, sigma_a, sigma_s, lambda_) means_a, sems_a = simulate_conditions(mat_data, alpha, sigma_a, sigma_s, lambda_)
for mean_a, line in zip(means_a[::-1], state['sim_lines'], strict=True): for mean_a, sem_a, line in zip(means_a[::-1], sems_a[::-1], state['sim_lines'], strict=True):
line.set_data(np.arange(len(mean_a)) * 50, mean_a) update_errorbar(
line,
np.arange(len(mean_a)) * 50,
mean_a,
yerr=sem_a
)
state['axes'].relim() state['axes'].relim()
state['axes'].autoscale(axis='y') state['axes'].autoscale(axis='y')

Loading…
Cancel
Save