In this notebook, we reproduce the three-panel figure produced by the original MATLAB script quoted at the end of the notebook: stay behavior (p(stay) as a function of the previous reward) for the different models, and p(correct) on early and late trials as a function of the learning rate, for several inverse temperatures.
%load_ext autoreload
%autoreload 2
import simulation_workshop.simulation as sim
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
T = 100 # number of trials
mu = [.2,.8] # reward probabilities
alpha = .2 # learning rate
beta = 4 # inverse temperature
actions, rewards, Cks = sim.simulate_M4ChoiceKernel_v1(T, mu, alpha, beta)
df = pd.DataFrame({'action':actions, 'reward':rewards})
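For intuition, below is a minimal sketch of what a choice-kernel (M4) simulator of this kind typically does, assuming the standard update CK_k <- CK_k + alpha * (1[a == k] - CK_k) and softmax choice with inverse temperature beta; the actual simulate_M4ChoiceKernel_v1 may differ in details such as its exact return values.
import numpy as np

def simulate_choice_kernel_sketch(T, mu, alpha, beta, rng=None):
    """Illustrative choice-kernel simulator (not the package implementation)."""
    rng = np.random.default_rng() if rng is None else rng
    ck = np.zeros(len(mu))                                # choice kernel, one entry per action
    actions = np.zeros(T, dtype=int)
    rewards = np.zeros(T, dtype=int)
    for t in range(T):
        p = np.exp(beta * ck) / np.exp(beta * ck).sum()   # softmax over the kernel
        a = rng.choice(len(mu), p=p)                      # sample an action
        r = rng.random() < mu[a]                          # Bernoulli reward
        ck += alpha * ((np.arange(len(mu)) == a) - ck)    # push the kernel toward the chosen action
        actions[t], rewards[t] = a, r
    return actions, rewards, ck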
def analysis_WSLS_v1(df):
    """Return p(stay) after a non-rewarded trial (loseStay) and after a rewarded trial (winStay).

    Expects 0/1 'action' and 'reward' columns; both reward outcomes must occur at least once.
    """
    df['last_action'] = df.action.shift(1)
    df['last_reward'] = df.reward.shift(1)
    df['stay'] = (df.action == df.last_action).astype(int)
    output = df.groupby(['last_reward']).stay.mean()
    loseStay = output.loc[0]
    winStay = output.loc[1]
    s = pd.Series([loseStay, winStay])
    return s
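A quick sanity check on a tiny hand-made sequence (purely illustrative data): the agent below always repeats its action after a reward and switches after a non-reward, so the function should return loseStay = 0 and winStay = 1.
toy = pd.DataFrame({'action': [0, 0, 1, 1, 0],
                    'reward': [1, 0, 1, 0, 1]})
analysis_WSLS_v1(toy)   # expected: loseStay = 0.0, winStay = 1.0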
alpha = .1 # learning rate
beta = 3 # inverse temperature
nrep = 110
data = []
for i in range(nrep):
    actions, rewards, Cks = sim.simulate_M4ChoiceKernel_v1(T, mu, alpha, beta)
    df = pd.DataFrame({'action': actions, 'reward': rewards})
    data.append(analysis_WSLS_v1(df))
df = pd.DataFrame(data)
df.columns = ['loseStay','winStay']
ck = df.mean()
ck
alpha = .1 # learning rate
beta = 5 # inverse temperature
nrep = 110
data = []
for i in range(nrep):
    actions, rewards, Cks, _ = sim.simulate_M3RescorlaWagner_v1(T, mu, alpha, beta)
    df = pd.DataFrame({'action': actions, 'reward': rewards})
    data.append(analysis_WSLS_v1(df))
df = pd.DataFrame(data)
df.columns = ['loseStay','winStay']
rw = df.mean()
rw
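For comparison with the choice-kernel sketch above, the Rescorla-Wagner (M3) learner presumably updates the value of the chosen option with a delta rule, Q_k <- Q_k + alpha * (r - Q_k), and again chooses via a softmax with inverse temperature beta. Two small illustrative helpers (assumptions, not the package's API):
def rw_update_sketch(q, action, reward, alpha):
    """One delta-rule step on a numpy array of option values: move the chosen option's value toward the obtained reward."""
    q = q.copy()
    q[action] += alpha * (reward - q[action])
    return q

def softmax_sketch(q, beta):
    """Softmax choice probabilities with inverse temperature beta."""
    z = np.exp(beta * (q - q.max()))   # subtract the max for numerical stability
    return z / z.sum()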
ax = ck.plot()
ax = rw.plot(ax = ax)
sns.despine()
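Panel A of the target figure (see the MATLAB code quoted below) labels the x-axis 'previous reward' and the y-axis 'p(stay)'; the same two Series can be plotted a bit more explicitly:
fig, ax = plt.subplots(figsize=(5, 4))
ax.plot([0, 1], ck.values, marker='o', label='M4: CK')
ax.plot([0, 1], rw.values, marker='o', label='M3: RW')
ax.set_xticks([0, 1])
ax.set_xticklabels(['lose', 'win'])
ax.set(xlabel='previous reward', ylabel='p(stay)', ylim=(0, 1.04))
ax.legend(loc='lower right')
sns.despine()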
import numpy as np
alphas = list(np.arange(.02, 1.02, .02))  # learning rates from 0.02 to 1.0
betas = [1, 2, 5, 10, 20]                 # inverse temperatures
len(alphas) * len(betas) * 2
%%time
T = 200
data = []
for i in range(200):
    for alpha in alphas:
        for beta in betas:
            a, r, _, _ = sim.simulate_M3RescorlaWagner_v1(T, mu, alpha, beta, starting_q_values=[0, 0])
            session_dict = {}
            session_dict['alpha'] = alpha
            session_dict['beta'] = beta
            imax = np.argmax(mu)  # index of the better option (higher reward probability)
            a = pd.Series(a)
            session_dict['correct_early'] = (a.iloc[:10] == imax).mean()
            session_dict['correct_late'] = (a.iloc[-10:] == imax).mean()
            data.append(session_dict)
df = pd.DataFrame(data)
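Before plotting, it can help to look at the averaged grid directly; the curves seaborn draws below are just these per-(alpha, beta) means over the repetitions (this uses the df built in the sweep above):
summary = (df.groupby(['alpha', 'beta'])[['correct_early', 'correct_late']]
             .mean()
             .reset_index())
summary.head()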
palette = sns.color_palette("rocket_r")[:5]  # one color per beta, shared by both panels
fig, axs = plt.subplots(1, 2, figsize=(14, 7))
sns.lineplot(x='alpha', y='correct_early', hue='beta', ci=None, data=df, ax=axs[0], legend=True, palette=palette)
sns.lineplot(x='alpha', y='correct_late', hue='beta', ci=None, data=df, ax=axs[1], palette=palette, legend=False)
axs[0].legend(loc='upper right', bbox_to_anchor=(1.1, 1))
axs[0].set_title("early trials")
axs[1].set_title("late trials")
sns.despine()
# quick single-panel view: early-trial accuracy with the default palette
sns.lineplot(x='alpha', y='correct_early', hue='beta', ci=None, data=df)
sns.despine()
For reference, here is the original MATLAB code that produces the figure (pasted from the paper's analysis script; variables such as wsls and AZred/AZblue, and helpers such as easy_gridOfEqualFigures, addABCs, and saveFigurePdf, are defined elsewhere in that codebase):
%% p(correct) analysis
alphas = [0.02:0.02:1];
betas = [1 2 5 10 20];
for n = 1:1000
    n
    for i = 1:length(alphas)
        for j = 1:length(betas)
            [a, r] = simulate_M3RescorlaWagner_v1(T, mu, alphas(i), betas(j));
            [~,imax] = max(mu);
            correct(i,j,n) = nanmean(a == imax);
            correctEarly(i,j,n) = nanmean(a(1:10) == imax);
            correctLate(i,j,n) = nanmean(a(end-9:end) == imax);
        end
    end
end
%% plot p(correct) behavior
figure(1);
E = nanmean(correctEarly,3);
L = nanmean(correctLate,3);
figure(1); clf;
set(gcf, 'Position', [284 498 750 300])
ax = easy_gridOfEqualFigures([0.2 0.1], [0.08 0.14 0.05 0.03]);
axes(ax(1)); hold on;
l = plot([0 1], wsls);
ylim([0 1])
set(l, 'marker', '.', 'markersize', 50, 'linewidth', 3)
leg1 = legend({'M1: random' 'M2: WSLS' 'M3: RW' 'M4: CK' 'M5: RW+CK'}, ...
'location', 'southeast');
xlabel('previous reward')
% ylabel('probability of staying')
ylabel('p(stay)')
title('stay behavior', 'fontweight', 'normal')
xlim([-0.1 1.1]);
ylim([0 1.04])
set(ax(1), 'xtick', [0 1])
set(leg1, 'fontsize', 12)
set(leg1, 'position', [0.19 0.2133 0.1440 0.2617])
set(ax(1), 'ytick', [0 0.5 1])
axes(ax(2)); hold on;
l1 = plot(alphas, E);
xlabel('learning rate, \alpha')
ylabel('p(correct)')
title('early trials', 'fontweight', 'normal')
for i = 1:length(betas)
    leg{i} = ['\beta = ' num2str(betas(i))];
end
leg2 = legend(l1(end:-1:1), {leg{end:-1:1}});
set([leg1 leg2], 'fontsize', 12)
set(leg2, 'position', [0.6267 0.6453 0.1007 0.2617]);
axes(ax(3)); hold on;
l2 = plot(alphas, L);
xlabel('learning rate, \alpha')
% ylabel('p(correct)')
title('late trials', 'fontweight', 'normal')
for i = 1:length(l1)
    f = (i-1)/(length(l1)-1);
    set([l1(i) l2(i)], 'color', AZred*f + AZblue*(1-f));
end
set([l1 l2], 'linewidth', 3)
set(ax(3), 'yticklabel', [])
set(ax(2:3), 'ylim', [0.5 1.02])
set(ax, 'fontsize', 18, 'tickdir', 'out')
addABCs(ax(1:2), [-0.06 0.09], 32)
%% save resulting figure
saveFigurePdf(gcf, './Figures/Figure2')