
Commit b86e145c authored by Lukas Wolf

eegnet in torch now equivalent to our tf implementation

parent 5ede05a3
@@ -14,6 +14,8 @@ class Trainer:
    def __init__(self, X, y):
        self.X = X
        self.y = y
        logging.info(f"Data X shape: {X.shape}")
        logging.info(f"Data y shape: {y.shape}")

    def train(self):
        # Check if we want to run tf or torch
@@ -23,6 +25,8 @@ class Trainer:
        logging.info("------------------------------------------------------------------------------------")
        logging.info("Trainer: created a {} trainer".format(config['framework']))
        #TODO: load prepared data from data/prepared
        if config['framework'] == 'tensorflow':
            from tf_models.Ensemble.Ensemble_tf import Ensemble_tf
            ensemble = Ensemble_tf(nb_models=config['ensemble'], model_type=config['model'])
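Note: only the tensorflow branch of the framework switch is visible in the hunk above. Below is a minimal sketch of how the dispatch in Trainer.train() presumably continues; the torch-side module path and class name (Ensemble_torch) are assumptions for illustration, not taken from this commit.

from config import config

def build_ensemble():
    # Only the tensorflow branch appears in the diff above; the torch branch
    # below is a hypothetical mirror of it (module path and class name assumed).
    if config['framework'] == 'tensorflow':
        from tf_models.Ensemble.Ensemble_tf import Ensemble_tf
        return Ensemble_tf(nb_models=config['ensemble'], model_type=config['model'])
    elif config['framework'] == 'torch':
        from torch_models.Ensemble.Ensemble_torch import Ensemble_torch  # assumed path
        return Ensemble_torch(nb_models=config['ensemble'], model_type=config['model'])
    raise ValueError("Unknown framework: {}".format(config['framework']))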
@@ -53,7 +53,7 @@ if config['task'] != 'prosaccade-clf':
##################################################################
config['framework'] = 'torch'
#config['framework'] ='tensorflow'
#config['framework'] = 'tensorflow'
##################################################################
# Choose model
@@ -62,10 +62,11 @@ config['framework'] = 'torch'
config['ensemble'] = 2 #number of models in the ensemble
config['pretrained'] = False # We can use a model pretrained on processing speed task
config['model'] = 'cnn'
#config['model'] = 'cnn'
#config['model'] = 'inception'
#config['model'] = 'eegnet'
config['model'] = 'eegnet'
#config['model'] = 'xception'
#config['model'] = 'gazenet'
#config['model'] = 'pyramidal_cnn'
#config['model'] = 'deepeye'
#config['model'] = 'deepeye-rnn'
@@ -74,9 +75,9 @@ config['model'] = 'cnn'
##################################################################
# Hyper-parameters and training configuration.
##################################################################
config['learning_rate'] = 1e-3 # fix only: 1e-2, sac only: 1e-3, sac_fix: 1e-3 , fix_sac_fix: 1e-4, for inception on angle 1e-5
config['learning_rate'] = 1e-2 # fix only: 1e-2, sac only: 1e-3, sac_fix: 1e-3 , fix_sac_fix: 1e-4, for inception on angle 1e-5
config['regularization'] = 0 # fix only: 1e-3, sac only: 1e-2, sac_fix: 1, fix_sac_fix: 5, for inception on angle 0
config['epochs'] = 5
config['epochs'] = 2
config['batch_size'] = 64
##################################################################
@@ -132,6 +133,8 @@ config['cnn'] = {}
config['pyramidal_cnn'] = {}
# InceptionTime
config['inception'] = {}
# GazeNet (for event detection task)
config['gazenet'] = {}
# DeepEye
config['deepeye'] = {}
# Xception
@@ -165,6 +168,7 @@ if config['task'] != 'prosaccade-clf':
config['cnn']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
config['inception']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
config['gazenet']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
config['deepeye-rnn']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
config['xception']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
@@ -174,6 +178,7 @@ if config['task'] != 'prosaccade-clf':
config['cnn']['input_shape'] = (config['max_saccade'] + 2 * config['fixation_padlength'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + 2 * config['fixation_padlength'], 129)
config['inception']['input_shape'] = (config['max_saccade'] + 2 * config['fixation_padlength'], 129)
config['gazenet']['input_shape'] = (config['max_saccade'] + 2 * config['fixation_padlength'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'] + 2 * config['fixation_padlength'], 129)
config['xception']['input_shape'] = (config['max_saccade'] + 2 * config['fixation_padlength'], 129)
config['deepeye-rnn']['input_shape'] = (config['max_saccade'] + 2 * config['fixation_padlength'], 129)
@@ -184,6 +189,7 @@ else:
config['cnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['pyramidal_cnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['inception']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['gazenet']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['deepeye']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['deepeye-rnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['xception']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
load('all_EEGprocuesan.mat');
% Reorder all_EEGprocuesan (trials x samples x channels) into final (trials x channels x samples)
for i = 1: size(all_EEGprocuesan,2)
for ii = 1:size(all_EEGprocuesan,3)
final(:,ii,i) = all_EEGprocuesan(:,i,ii);
end
end
cd('\\130.60.169.45\methlab\Neurometric\Antisaccades\code\eeglab14_1_2b')
eeglab;
close all
X.srate = 500
X.nbchan = 129
X.pnts = 500
X.trials = 1
X.xmin = 0
X.event = []
X.setname = []
% Downsample each trial from 500 Hz to 125 Hz with EEGLAB's pop_resample
for i = 1:size(final,1)
X.data = final(i,:,:);
downsamplEEG(i) = pop_resample(X,125);
end
save('downsamplEEG', 'downsamplEEG', '-v7.3')
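For reference, the reordering and 500 Hz to 125 Hz downsampling done by the MATLAB script above can be sketched in Python as follows. This is an illustration only; the array shapes are taken from the comments in these scripts (trials x samples x channels), and scipy is not part of this repository's pipeline.

import numpy as np
from scipy.signal import resample

# Placeholder data standing in for all_EEGprocuesan: trials x samples x channels at 500 Hz
all_EEGprocuesan = np.zeros((73, 500, 129))
# Reorder to trials x channels x samples, as the MATLAB double loop does
final = np.transpose(all_EEGprocuesan, (0, 2, 1))
# Resample the time axis from 500 samples (500 Hz) to 125 samples (125 Hz)
downsampled = resample(final, 125, axis=2)
print(downsampled.shape)  # (73, 129, 125)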
% EEGprocuesan = 129X500X73
% fixEEGprocuesan = 73 X 500X129
clc
clear
x = dir('\\130.60.169.45\methlab\ETH_AS\preprocessed2')
subjects = {x.name};
subjects = {subjects{4:end-3}}';
clear x
cd('\\130.60.169.45\methlab\ETH_AS')
%%
all_EEGprocuesan = []
for subj = 1:100 % 186 - BA5 didn't work, 346 - BY2
datapath = strcat('\\130.60.169.45\methlab\ETH_AS\preprocessed2\',subjects{subj});
cd (datapath)
if exist(strcat('EEGprocuesan.mat')) > 0
datafile= strcat('EEGprocuesan.mat');
load (datafile);
end
% Reorder EEGprocuesan.data (channels x samples x trials) into trials x samples x channels
final_EEGprocuesan = [];
for i = 1: size(EEGprocuesan.data,1)
for ii = 1:size(EEGprocuesan.data,3)
final_EEGprocuesan(ii,:,i) = EEGprocuesan.data(i,:,ii);
end
end
all_EEGprocuesan = vertcat(all_EEGprocuesan ,final_EEGprocuesan);
size(all_EEGprocuesan,1);
end
save('all_EEGprocuesan', 'all_EEGprocuesan', '-v7.3')
\ No newline at end of file
% EEGprocuesan = 129X500X73
% fixEEGprocuesan = 73 X 500X129
clc
clear
x = dir('\\130.60.169.45\methlab\ETH_AS\preprocessed2')
subjects = {x.name};
subjects = {subjects{4:end-3}}';
clear x
cd('\\130.60.169.45\methlab\ETH_AS')
%%
all_trialinfoprosan = []
for subj = 1:100 % 186 - BA5 didn't work, 346 - BY2
datapath = strcat('\\130.60.169.45\methlab\ETH_AS\preprocessed2\',subjects{subj});
cd (datapath)
if exist(strcat('trialinfoprosan.mat')) > 0
datafile= strcat('trialinfoprosan.mat');
load (datafile);
end
B = trialinfoprosan.cues;
A = all_trialinfoprosan;
all_trialinfoprosan = vertcat(A,B);
size(all_trialinfoprosan,1)
end
save('all_trialinfoprosan', 'all_trialinfoprosan', '-v7.3')
\ No newline at end of file
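Because the aggregated files are written with the '-v7.3' flag, they are HDF5 containers and can be read back in Python with h5py, e.g. for the data preparation step. A minimal sketch (file names follow the save calls above; note that MATLAB arrays come back with reversed dimension order):

import h5py
import numpy as np

with h5py.File('all_EEGprocuesan.mat', 'r') as f:
    X = np.array(f['all_EEGprocuesan'])      # dimensions reversed relative to MATLAB
with h5py.File('all_trialinfoprosan.mat', 'r') as f:
    y = np.array(f['all_trialinfoprosan'])   # cue labels: 0 = left ('10'), 1 = right ('11')

print(X.shape, y.shape)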
@@ -11,8 +11,8 @@ preparation_config = dict()
# 'Direction_task' (dataset: 'dots' or 'processing_speed'):
# 'Position_task' (dataset: 'dots'):
# 'Segmentation_task' (dataset: 'antisaccade', 'dots', or 'processing_speed'):
preparation_config['task'] = 'Position_task'
preparation_config['dataset'] = 'dots'
preparation_config['task'] = 'LR_task'
preparation_config['dataset'] = 'antisaccade'
# We provide two types of preprocessing on the dataset (minimal preprocessing and maximal preprocessing). Choices are
# 'max'
clc
clear
cd('\\130.60.169.45\methlab\Neurometric\Antisaccades\code\eeglab14_1_2b')
eeglab;
close all
x = dir('\\130.60.169.45\methlab\ETH_AS\preprocessed2')
subjects = {x.name};
subjects = {subjects{4:end-3}}';
clear x
cd('\\130.60.169.45\methlab\ETH_AS')
%%
for subj = 1:length(subjects) % 186 - BA5 didn't work, 346 - BY2
datapath = strcat('\\130.60.169.45\methlab\ETH_AS\preprocessed2\',subjects{subj});
cd (datapath)
if exist(strcat('gip_',subjects{subj},'_AS_EEG.mat')) > 0
datafile= strcat('gip_',subjects{subj},'_AS_EEG.mat');
load (datafile)
elseif exist(strcat('oip_',subjects{subj},'_AS_EEG.mat')) > 0
datafile= strcat('oip_',subjects{subj},'_AS_EEG.mat');
load (datafile)
end
%% Re-reference to average reference
EEG = pop_reref(EEG,[]);
%% triggers renaming
countblocks = 1;
for e = 1:length(EEG.event)
if strcmp(EEG.event(e).type, 'boundary')
countblocks = countblocks + 1;
continue;
end
if countblocks == 2 || countblocks == 3 || countblocks == 4 % antisaccade blocks
if strcmp(EEG.event(e).type,'10 ') % change 10 to 12 for AS
EEG.event(e).type = '12 ';
elseif strcmp(EEG.event(e).type,'11 ')
EEG.event(e).type = '13 '; % change 11 to 13 for AS
end
if strcmp(EEG.event(e).type,'40 ')
EEG.event(e).type = '41 ';
end
end
end
EEG.event(strcmp('boundary',{EEG.event.type})) = [];
rmEventsIx = strcmp('L_fixation',{EEG.event.type});
rmEv = EEG.event(rmEventsIx);
EEG.event(rmEventsIx) = [];
EEG.event(1).dir = []; %left or right
EEG.event(1).cond = [];%pro or anti
%% rename EEG.event.type
previous = '';
for e = 1:length(EEG.event)
if strcmp(EEG.event(e).type, 'L_saccade')
if strcmp(previous, '10 ')
EEG.event(e).type = 'saccade_pro_left'
EEG.event(e).cond = 'pro';
EEG.event(e).dir = 'left';
%pro left
elseif strcmp(previous, '11 ')
EEG.event(e).type = 'saccade_pro_right'
EEG.event(e).cond = 'pro';
EEG.event(e).dir = 'right';
elseif strcmp(previous, '12 ')
EEG.event(e).type = 'saccade_anti_left'
EEG.event(e).cond = 'anti';
EEG.event(e).dir = 'left';
elseif strcmp(previous, '13 ')
EEG.event(e).type = 'saccade_anti_right'
EEG.event(e).cond = 'anti';
EEG.event(e).dir = 'right';
else
EEG.event(e).type = 'invalid';
end
end
if ~strcmp(EEG.event(e).type, 'L_fixation') ...
&& ~strcmp(EEG.event(e).type, 'L_blink')
previous = EEG.event(e).type;
end
end
%% remove everything from EEG.event which is not saccade or trigger
tmpinv=find(strcmp({EEG.event.type}, 'invalid') | strcmp({EEG.event.type}, 'L_blink'))
EEG.event(tmpinv)=[]
%% removing errors
% if 10 and the sub didn't look left then error
% pro left sac_start_x > sac_endpos_x --> correct condition
tmperrsacc1=find(strcmp({EEG.event.type}, 'saccade_pro_left') & [EEG.event.sac_startpos_x]< [EEG.event.sac_endpos_x]);
tmperr1=[tmperrsacc1 (tmperrsacc1-1)];
EEG.event(tmperr1)=[];
tmperrsacc2=find(strcmp({EEG.event.type}, 'saccade_anti_left') & [EEG.event.sac_startpos_x]> [EEG.event.sac_endpos_x]);
tmperr2=[tmperrsacc2 (tmperrsacc2-1)];
EEG.event(tmperr2)=[];
tmperrsacc3=find(strcmp({EEG.event.type}, 'saccade_pro_right') & [EEG.event.sac_startpos_x]> [EEG.event.sac_endpos_x]);
tmperr3=[tmperrsacc3 (tmperrsacc3-1)]
EEG.event(tmperr3)=[];
tmperrsacc4=find(strcmp({EEG.event.type}, 'saccade_anti_right') & [EEG.event.sac_startpos_x]< [EEG.event.sac_endpos_x]);
tmperr4=[tmperrsacc4 (tmperrsacc4-1)];
EEG.event(tmperr4)=[];
%% amplitude too small
tmperrsacc6=find(strcmp({EEG.event.type}, 'saccade_pro_right') ...
& [EEG.event.sac_amplitude]<1.5)
tmperrsacc7=find(strcmp({EEG.event.type}, 'saccade_pro_left') ...
& [EEG.event.sac_amplitude]<1.5)
tmperrsacc8=find(strcmp({EEG.event.type}, 'saccade_anti_left') ...
& [EEG.event.sac_amplitude]<1.5)
tmperrsacc9=find(strcmp({EEG.event.type}, 'saccade_anti_right') ...
& [EEG.event.sac_amplitude]<1.5)
tmperr69=[tmperrsacc6 (tmperrsacc6-1) tmperrsacc7 (tmperrsacc7-1) tmperrsacc8 (tmperrsacc8-1) tmperrsacc9 (tmperrsacc9-1)]
EEG.event(tmperr69)=[];
clear tmperrsacc1 tmperrsacc2 tmperrsacc3 tmperrsacc4 tmperrsacc6 tmperrsacc7 tmperrsacc8 tmperrsacc9
%% delete cues where there was no saccade afterwards
% tmperrcue10 = []
% tmperrcue11 = []
%
%start with pro left cue 10
tmperrcue10= find(strcmp({EEG.event.type}, '10 ')) ;
for iii=1:length(tmperrcue10)
pos = tmperrcue10(iii)
if ~ (strcmp(EEG.event(pos+1).type , 'saccade_pro_left'))
EEG.event(pos).type='missingsacc'; %cue
end
end
%%11
tmperrcue11 = find(strcmp({EEG.event.type}, '11 ')) ;
for iii=1:length(tmperrcue11)
pos = tmperrcue11(iii)
if ~ (strcmp(EEG.event(pos+1).type , 'saccade_pro_right'))
EEG.event(pos).type='missingsacc'; %cue
end
end
tmpinv=find(strcmp({EEG.event.type}, 'missingsacc')) ;
EEG.event(tmpinv)=[];
%% delete saccades and cues when the saccade comes faster than 100ms after cue
tmpevent=length(EEG.event)
saccpro=find(strcmp({EEG.event.type},'saccade_pro_right')==1 | strcmp({EEG.event.type},'saccade_pro_left')==1)% find rows where there is a saccade
saccanti=find(strcmp({EEG.event.type},'saccade_anti_right')==1 | strcmp({EEG.event.type},'saccade_anti_left')==1);%find rows where there is a saccade
for b=1:size(saccpro,2)
if (EEG.event(saccpro(1,b)).latency-EEG.event(saccpro(1,b)-1).latency)<50 % 50 samples = 100 ms at 500 Hz
EEG.event(saccpro(b)).type='micro'; %saccade
EEG.event(saccpro(b)-1).type = 'micro'; %cue
end
end
for b=1:size(saccanti,2)
if (EEG.event(saccanti(b)).latency-EEG.event(saccanti(1,b)-1).latency)<50;
EEG.event(saccanti(b)-1).type ='micro';
EEG.event(saccanti(b)).type ='micro';
end
end
tmpinv=find(strcmp({EEG.event.type}, 'micro')) ;
EEG.event(tmpinv)=[];
%% epoching
EEGprocuesan= pop_epoch(EEG, {'10','11'}, [0, 1]);
%how many epochs
trialinfoprosan.epochs=size(EEGprocuesan.data, 3);
%% important
tmp=find(strcmp({EEGprocuesan.event.type}, '11 ') | strcmp({EEGprocuesan.event.type}, '10 '))
right= find(strcmp({EEGprocuesan.event(tmp).type},'11 ')==1);
left= find(strcmp({EEGprocuesan.event(tmp).type},'10 ')==1);
trialinfoprosan.cues = nan(length(tmp),1);
trialinfoprosan.cues(left)= 0;
trialinfoprosan.cues(right)= 1;
%% save epoched data
if size(EEGprocuesan.data,3) ~= size(trialinfoprosan.cues,1)
error('Number of epochs does not match number of cue labels')
end
save EEGprocuesan EEGprocuesan
save trialinfoprosan trialinfoprosan
end
@@ -88,6 +88,8 @@ class ConvNet(ABC, BaseNet):
            output_layer = tf.keras.layers.Dense(2, activation='linear')(gap_layer)
        else: #elif config['task'] == 'angle-reg':
            output_layer = tf.keras.layers.Dense(1, activation='linear')(gap_layer)
        #else:
        #    pass #TODO: implement for event detection task
        model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
        return model
@@ -116,9 +116,20 @@ class EEGNet(BaseNet):
        flatten = Flatten()(block2)
        if config['split']:
            return flatten
        else:
            dense = Dense(self.nb_classes, name='dense',
                          kernel_constraint=max_norm(self.norm_rate))(flatten)
            softmax = Activation('sigmoid', name='sigmoid')(dense)
            return Model(inputs=input1, outputs=softmax)
        #else:
        #    dense = Dense(self.nb_classes, name='dense',
        #                  kernel_constraint=max_norm(self.norm_rate))(flatten)
        #    softmax = Activation('sigmoid', name='sigmoid')(dense)
        if config['task'] == 'prosaccade-clf':
            output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(flatten)
        elif config['task'] == 'gaze-reg':
            output_layer = tf.keras.layers.Dense(2, activation='linear')(flatten)
        else: #elif config['task'] == 'angle-reg':
            output_layer = tf.keras.layers.Dense(1, activation='linear')(flatten)
        #else:
        #    #TODO: implement for event detection task
        #    pass
        return Model(inputs=input1, outputs=output_layer)
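The commit message states that the torch EEGNet is now equivalent to this tf implementation. Below is a hedged sketch of the corresponding task-dependent output head on the PyTorch side; the layer sizes mirror the tf branch above, but the function and its use inside the torch EEGNet are illustrative assumptions, not code from this commit.

import torch.nn as nn
from config import config

def build_output_head(nb_features):
    # Mirrors the tf head selection above: 1 sigmoid unit for the binary
    # prosaccade classification, 2 linear units for gaze regression,
    # 1 linear unit for angle regression.
    if config['task'] == 'prosaccade-clf':
        return nn.Sequential(nn.Linear(nb_features, 1), nn.Sigmoid())
    elif config['task'] == 'gaze-reg':
        return nn.Linear(nb_features, 2)
    else:  # angle-reg
        return nn.Linear(nb_features, 1)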
@@ -29,15 +29,7 @@ class Ensemble_tf:
        self.model_type = model_type
        self.model_list = model_list
        self.models = []
        self._build_ensemble_model()

    def __str__(self):
        return self.__class__.__name__

    def load_models(self, path_to_models):
        #TODO: implement
        # load all models into the model_list to predict with them
        pass
        #self._build_ensemble_model()

    def _build_ensemble_model(self):
        """
@@ -65,11 +57,13 @@ class Ensemble_tf:
        mse = tf.keras.losses.MeanSquaredError()
        # Fit the models
        for i in range(len(self.models)):
        for i in range(config['ensemble']):
            print("------------------------------------------------------------------------------------")
            print('Start training model number {}/{} ...'.format(i+1, self.nb_models))
            model = self.models[i]
            if config['plot_model'] and i == 0:
            model = create_model(self.model_type, i)
            print("created model")
            print(f"config plotmodel {config['plot_model']}")
            if config['plot_model']:
                plot_model(model.get_model())
            hist, pred_ensemble = model.fit(X,y)
            # Collect the predictions on the validation sets
@@ -3,8 +3,7 @@ from torch import nn
import numpy as np
from config import config
import logging
from torch.utils.tensorboard import SummaryWriter
#from torch.utils.tensorboard import SummaryWriter
from torch_models.torch_utils.training import train_loop, test_loop
class Prediction_history:
@@ -81,31 +80,32 @@ class BaseNet(nn.Module):
        """
        Fit the model on the dataset defined by data x and labels y
        """
        print("------------------------------------------------------------------------------------")
        print(f"Fitting model number {self.model_number}")
        logging.info("------------------------------------------------------------------------------------")
        logging.info(f"Fitting model number {self.model_number}")
        # Create the optimizer
        optimizer = torch.optim.Adam(list(self.parameters()), lr=config['learning_rate'])
        # Create a history to track ensemble performance
        prediction_ensemble = Prediction_history(dataloader=test_dataloader)
        # Create a summary writer for logging metrics
        writer = SummaryWriter(log_dir=config['model_dir']+'/summary_writer')
        #writer = SummaryWriter(log_dir=config['model_dir']+'/summary_writer')
        # Train the model
        epochs = config['epochs']
        for t in range(epochs):
            print(f"Epoch {t+1}\n-------------------------------")
            logging.info(f"Epoch {t+1}\n-------------------------------")
            # Run through training and test set
            train_loss = train_loop(train_dataloader, self.float(), self.loss_fn, optimizer)
            test_loss, test_acc = test_loop(test_dataloader, self.float(), self.loss_fn)
            # Add the predictions on the validation set
            prediction_ensemble.on_epoch_end(model=self)
            logging.info("end epoch")
            # Log metrics to the writer
            writer.add_scalar('Loss/train', train_loss, t)
            writer.add_scalar('Loss/test', test_loss, t)
            if config['task'] == 'prosaccade-clf':
                writer.add_scalar('Accuracy/test', test_acc, t)
        print(f"Finished model number {self.model_number}")
        if config['save_model']:
            #writer.add_scalar('Loss/train', train_loss, t)
            #writer.add_scalar('Loss/test', test_loss, t)
            #if config['task'] == 'prosaccade-clf':
            #writer.add_scalar('Accuracy/test', test_acc, t)
        logging.info(f"Finished model number {self.model_number}")
        if config['save_models'] and self.model_number==0:
            ckpt_dir = config['model_dir'] + '/best_models/' + config['model'] + '_nb_{}_'.format(self.model_number) + 'best_model.pth'
            torch.save(self.state_dict(), ckpt_dir)
        return prediction_ensemble
\ No newline at end of file
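train_loop and test_loop are imported from torch_models.torch_utils.training, but their bodies are not part of this diff. The sketch below is inferred only from how they are called in fit() above (arguments and return values); the actual implementations, in particular the accuracy computation, may differ.

import torch

def train_loop(dataloader, model, loss_fn, optimizer):
    # One pass over the training set; returns the mean batch loss.
    model.train()
    total_loss = 0.0
    for X, y in dataloader:
        pred = model(X.float())
        loss = loss_fn(pred, y.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def test_loop(dataloader, model, loss_fn):
    # One pass over the validation set; returns mean loss and (for the binary
    # prosaccade task) accuracy from thresholding the sigmoid output at 0.5.
    model.eval()
    total_loss, correct, n = 0.0, 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X.float())
            total_loss += loss_fn(pred, y.float()).item()
            correct += ((pred > 0.5).float() == y.float()).sum().item()
            n += y.numel()
    return total_loss / len(dataloader), correct / max(n, 1)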
@@ -11,9 +11,10 @@ from torch_models.Modules import Pad_Conv, Pad_Pool

class ConvNet(ABC, BaseNet):
    """
    This class defines all the common functionality for more complex convolutional nets
    This class defines all the common functionality for convolutional nets
    Inherit from this class and only implement _module() and _get_nb_features_output_layer() methods
    Modules are then stacked in the forward() method
    """
    def __init__(self, input_shape, kernel_size=32, nb_filters=32, verbose=True, batch_size=64,
                 use_residual=False, depth=6, epochs=2, preprocessing = False, model_number=0):
        """
@@ -3,60 +3,147 @@ import torch.nn as nn
import torch.nn.functional as F
from torch_models.BaseNetTorch import BaseNet
from config import config
import logging
from torch_models.Modules import Pad_Conv2d, Pad_Pool2d

class EEGNet(BaseNet):
    """
    The EEGNet architecture used as baseline. This is the architecture explained in the paper
    def __init__(self, input_shape, epochs=50, model_number=0):
        super().__init__(input_shape=input_shape, epochs=epochs, model_number=model_number)
    'EEGNet: A Compact Convolutional Neural Network for EEG-based Brain-Computer Interfaces' with authors
    Vernon J. Lawhern, Amelia J. Solon, Nicholas R. Waytowich, Stephen M. Gordon, Chou P. Hung, Brent J. Lance
    """
    def __init__(self, input_shape, epochs=50, model_number=0,
                 F1=16, F2=256, verbose=True, D=4, kernel_size=256,
                 dropout_rate=0.5):
        # NOTE: This dimension will depend on the number of timestamps per sample in your data.
        # I have 120 timepoints.
        self.T = self.input_shape[0] # input_shape like (500 timepoints, 129 channels)
        # Layer 1
        self.conv1 = nn.Conv2d(1, 16, (1, 64), padding = 0)
        self.batchnorm1 = nn.BatchNorm2d(16, False)
        self.kernel_size = kernel_size
        self.timesamples = input_shape[0]
        self.channels = input_shape[1]
        self.F1 = F1
        self.D = D
        self.F2 = F2