diff --git a/utils.py b/utils.py index bba64b5..3aa4fea 100644 --- a/utils.py +++ b/utils.py @@ -1,6 +1,7 @@ from itertools import product import numpy as np from scipy.integrate import trapezoid +from scipy.stats import sem import matplotlib.pyplot as plt from matplotlib.lines import Line2D from IPython.display import display @@ -36,6 +37,22 @@ def draw_figure(fig): plt.show() +def update_errorbar(err_container, x, y, yerr): + err_container.lines[0].set_data(x, y) + linecol = err_container.lines[2][0] + + segments = [] + for xi, yi, yerri in zip(x, y, yerr): + segments.append([[xi, yi - yerri], [xi, yi + yerri]]) + + linecol.set_segments(segments) + + if len(err_container.lines[1]) == 2: + lower_caps, upper_caps = err_container.lines[1] + lower_caps.set_data(x, y - yerr) + upper_caps.set_data(x, y + yerr) + + def maybe_setup(setup_fun, state): if not is_colab: return @@ -202,17 +219,6 @@ def plot_sims(C_size=11, num_sims=30 if not is_colab else 5): interact(update_plot, **sliders) -def update_errorbar(err_container, x, y, yerr): - err_container.lines[0].set_data(x, y) - linecol = err_container.lines[2][0] - - segments = [] - for xi, yi, yerri in zip(x, y, yerr): - segments.append([[xi, yi - yerri], [xi, yi + yerri]]) - - linecol.set_segments(segments) - - def plot_model_free_analysis_conditions(C, ks, num_sims_per_condition=2_000): setup_matplotlib_magic() @@ -437,13 +443,15 @@ def plot_single_neuron(mat_data): def setup(): fig, axes = plt.subplots(figsize=(6.5, 4.5)) - neuron_line = axes.plot([], [])[0] + neuron_line = axes.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label='mean firing rate with 95% CI') axes.set( ylabel=r'$\sqrt{N_\mathrm{spikes}}$', xlabel='time [ms]', xlim=(0, 800) ) + + axes.legend(loc='upper right', fontsize='small') return {'fig': fig, 'axes': axes, 'neuron_line': neuron_line} @@ -451,12 +459,17 @@ def plot_single_neuron(mat_data): def update_plot(neuron_idx): maybe_setup(setup, state) - - state['neuron_line'].set_data(time, binned_spike_matrix.mean(axis=0)[:, neuron_idx]) + + update_errorbar( + state['neuron_line'], + time, + binned_spike_matrix[:, :, neuron_idx].mean(axis=0), + yerr=sem(binned_spike_matrix[:, :, neuron_idx], axis=0) * 1.96 + ) state['axes'].relim() state['axes'].autoscale(axis='y') - state['axes'].set_title(f'Neuron #{neuron_idx}', fontsize='small') + state['axes'].set_title(f'Neuron #{neuron_idx}') state['fig'].tight_layout() draw_figure(state['fig']) @@ -478,15 +491,15 @@ def plot_neuron_by_choice(mat_data): def setup(): fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True) - choices = ['right choice', 'left choice'] + choices = ['right choice (95% CI)', 'left choice (95% CI)'] correct_lines = [] for choice in choices: - correct_line = axes[0].plot([], [], label=choice)[0] + correct_line = axes[0].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=choice) correct_lines += [correct_line] incorrect_lines = [] for choice in choices: - incorrect_line = axes[1].plot([], [], label=choice)[0] + incorrect_line = axes[1].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=choice) incorrect_lines += [incorrect_line] axes[0].set( @@ -499,8 +512,8 @@ def plot_neuron_by_choice(mat_data): title='incorrect trials', xlabel='time [ms]' ) - axes[0].legend(loc='upper right') - axes[1].legend(loc='upper right') + axes[0].legend(loc='upper right', fontsize='small') + axes[1].legend(loc='upper right', fontsize='small') return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines} @@ -509,16 +522,36 @@ def plot_neuron_by_choice(mat_data): def update_plot(neuron_idx): maybe_setup(setup, state) - state['correct_lines'][0].set_data(time, binned_spike_matrix[correct_trials_mask & right_choice].mean(axis=0)[:, neuron_idx]) - state['correct_lines'][1].set_data(time, binned_spike_matrix[correct_trials_mask & ~right_choice].mean(axis=0)[:, neuron_idx]) - state['incorrect_lines'][0].set_data(time, binned_spike_matrix[~correct_trials_mask & right_choice].mean(axis=0)[:, neuron_idx]) - state['incorrect_lines'][1].set_data(time, binned_spike_matrix[~correct_trials_mask & ~right_choice].mean(axis=0)[:, neuron_idx]) + update_errorbar( + state['correct_lines'][0], + time, + binned_spike_matrix[correct_trials_mask & right_choice][:, :, neuron_idx].mean(axis=0), + sem(binned_spike_matrix[correct_trials_mask & right_choice][:, :, neuron_idx], axis=0) * 1.96 + ) + update_errorbar( + state['correct_lines'][1], + time, + binned_spike_matrix[correct_trials_mask & ~right_choice][:, :, neuron_idx].mean(axis=0), + sem(binned_spike_matrix[correct_trials_mask & ~right_choice][:, :, neuron_idx], axis=0) * 1.96 + ) + update_errorbar( + state['incorrect_lines'][0], + time, + binned_spike_matrix[~correct_trials_mask & right_choice][:, :, neuron_idx].mean(axis=0), + sem(binned_spike_matrix[~correct_trials_mask & right_choice][:, :, neuron_idx], axis=0) * 1.96 + ) + update_errorbar( + state['incorrect_lines'][1], + time, + binned_spike_matrix[~correct_trials_mask & ~right_choice][:, :, neuron_idx].mean(axis=0), + sem(binned_spike_matrix[~correct_trials_mask & ~right_choice][:, :, neuron_idx], axis=0) * 1.96 + ) state['axes'][0].relim() state['axes'][1].relim() state['axes'][0].autoscale(axis='y') state['axes'][1].autoscale(axis='y') - state['fig'].suptitle(f'Neuron #{neuron_idx}', fontsize='small') + state['fig'].suptitle(f'Neuron #{neuron_idx}') state['fig'].tight_layout() draw_figure(state['fig']) @@ -546,12 +579,12 @@ def plot_neuron_by_coherence(mat_data): choices = ['right choice', 'left choice'] correct_lines = [] for coherence in coherences: - correct_line = axes[0].plot([], [], label=f'{coherence = :.1%}')[0] + correct_line = axes[0].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=f'{coherence = :.1%} (95% CI)') correct_lines += [correct_line] incorrect_lines = [] for coherence in coherences: - incorrect_line = axes[1].plot([], [], label=f'{coherence = :.1%}')[0] + incorrect_line = axes[1].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=f'{coherence = :.1%} (95% CI)') incorrect_lines += [incorrect_line] axes[0].set( @@ -564,8 +597,8 @@ def plot_neuron_by_coherence(mat_data): title='incorrect trials', xlabel='time [ms]' ) - axes[0].legend(loc='upper right') - axes[1].legend(loc='upper right') + axes[0].legend(loc='upper right', fontsize='small') + axes[1].legend(loc='upper right', fontsize='small') return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines} @@ -577,14 +610,25 @@ def plot_neuron_by_coherence(mat_data): for i, coherence in enumerate(coherences): coherence_mask = (mat_data['dot_coh'].flatten() == coherence) - state['correct_lines'][i].set_data(time, binned_spike_matrix[correct_trials_mask & coherence_mask].mean(axis=0)[:, neuron_idx]) - state['incorrect_lines'][i].set_data(time, binned_spike_matrix[~correct_trials_mask & coherence_mask].mean(axis=0)[:, neuron_idx]) + update_errorbar( + state['correct_lines'][i], + time, + binned_spike_matrix[correct_trials_mask & coherence_mask][:, :, neuron_idx].mean(axis=0), + sem(binned_spike_matrix[correct_trials_mask & coherence_mask][:, :, neuron_idx], axis=0) * 1.96 + + ) + update_errorbar( + state['incorrect_lines'][i], + time, + binned_spike_matrix[~correct_trials_mask & coherence_mask][:, :, neuron_idx].mean(axis=0), + sem(binned_spike_matrix[~correct_trials_mask & coherence_mask][:, :, neuron_idx], axis=0) * 1.96 + ) state['axes'][0].relim() state['axes'][1].relim() state['axes'][0].autoscale(axis='y') state['axes'][1].autoscale(axis='y') - state['fig'].suptitle(f'Neuron #{neuron_idx}', fontsize='small') + state['fig'].suptitle(f'Neuron #{neuron_idx}') state['fig'].tight_layout() draw_figure(state['fig']) @@ -637,13 +681,14 @@ def plot_aggregated_neurons(mat_data): mean_spikes_left = binned_spike_matrix[~right_choice].mean(axis=0) deltas = calculate_deltas(mat_data) + print(deltas.shape) def setup(): fig, axes = plt.subplots() lines = [ - axes.plot([], [], label='right choice')[0], - axes.plot([], [], label='left choice')[0] + axes.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label='right choice'), + axes.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label='left choice') ] axes.set( @@ -660,8 +705,18 @@ def plot_aggregated_neurons(mat_data): def update_plot(delta_threshold): maybe_setup(setup, state) - state['lines'][0].set_data(time, (mean_spikes_right * np.sign(deltas))[:, np.abs(deltas) > delta_threshold].mean(axis=1)) - state['lines'][1].set_data(time, (mean_spikes_left * np.sign(deltas))[:, np.abs(deltas) > delta_threshold].mean(axis=1)) + update_errorbar( + state['lines'][0], + time, + (binned_spike_matrix[right_choice] * np.sign(deltas))[:, :, np.abs(deltas) > delta_threshold].mean(axis=(0, 2)), + sem((binned_spike_matrix[right_choice] * np.sign(deltas))[:, :, np.abs(deltas) > delta_threshold], axis=(0, 2)) * 1.96 + ) + update_errorbar( + state['lines'][1], + time, + (binned_spike_matrix[~right_choice] * np.sign(deltas))[:, :, np.abs(deltas) > delta_threshold].mean(axis=(0, 2)), + sem((binned_spike_matrix[~right_choice] * np.sign(deltas))[:, :, np.abs(deltas) > delta_threshold], axis=(0, 2)) * 1.96 + ) state['axes'].relim() state['axes'].autoscale(axis='y') @@ -716,11 +771,14 @@ def simulate_conditions(mat_data, alpha, sigma_a, sigma_s, lambda_): unq_dir = np.unique(d) means_a = [] + sems_a = [] for dir_ in unq_dir: mean_a = np.mean(a_Cor[d_Cor == dir_, :], axis=0) + sem_a = sem(a_Cor[d_Cor == dir_, :], axis=0) means_a += [mean_a] + sems_a += [sem_a] - return means_a + return means_a, sems_a def plot_sims_conditions(mat_data): @@ -729,10 +787,9 @@ def plot_sims_conditions(mat_data): def setup(): fig, axes = plt.subplots(figsize=(6.5, 5)) - evidence_line = axes.plot([], [], color='C2', alpha=1)[0] sim_lines = [] for choice in ['right choice', 'left choice']: - sim_line = axes.plot([], [], label=choice)[0] + sim_line = axes.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=choice) sim_lines += [sim_line] axes.set( @@ -757,10 +814,15 @@ def plot_sims_conditions(mat_data): state['random_seed'] = np.random.randint(0, 2**32) np.random.seed(state['random_seed']) - means_a = simulate_conditions(mat_data, alpha, sigma_a, sigma_s, lambda_) - - for mean_a, line in zip(means_a[::-1], state['sim_lines'], strict=True): - line.set_data(np.arange(len(mean_a)) * 50, mean_a) + means_a, sems_a = simulate_conditions(mat_data, alpha, sigma_a, sigma_s, lambda_) + + for mean_a, sem_a, line in zip(means_a[::-1], sems_a[::-1], state['sim_lines'], strict=True): + update_errorbar( + line, + np.arange(len(mean_a)) * 50, + mean_a, + yerr=sem_a + ) state['axes'].relim() state['axes'].autoscale(axis='y')