Commit 1e86f6cc authored by Lukas Wolf

improvements for models

parent 9140093a
@@ -7,14 +7,13 @@ from config import config
from tensorflow.keras.callbacks import CSVLogger
import logging
#TODO: build similar class for regression
class prediction_history(tf.keras.callbacks.Callback):
    """Records the model's validation-set predictions after every epoch."""
    def __init__(self, validation_data):
        self.validation_data = validation_data
        self.predhis = []                    # one prediction array per epoch
        self.targets = validation_data[1]

    def on_epoch_end(self, epoch, logs=None):
        y_pred = self.model.predict(self.validation_data[0])
        self.predhis.append(y_pred)
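The TODO above asks for a regression counterpart. As a hedged sketch (not part of this commit; the class name and the val_mse attribute are hypothetical), a regression variant could additionally track the per-epoch validation MSE:

import numpy as np

class regression_prediction_history(tf.keras.callbacks.Callback):
    # Hypothetical sketch: record predictions plus the per-epoch validation MSE
    def __init__(self, validation_data):
        self.validation_data = validation_data
        self.predhis = []
        self.targets = validation_data[1]
        self.val_mse = []

    def on_epoch_end(self, epoch, logs=None):
        y_pred = self.model.predict(self.validation_data[0])
        self.predhis.append(y_pred)
        targets = np.asarray(self.targets).reshape(-1)
        preds = np.asarray(y_pred).reshape(-1)
        self.val_mse.append(float(np.mean((targets - preds) ** 2)))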
@@ -111,7 +110,7 @@ class Regression_ConvNet(ABC):
return gap_layer
# Optional: add some more dense layers here
gap_layer = tf.keras.layers.Dense(300)(gap_layer)
#gap_layer = tf.keras.layers.Dense(300)(gap_layer)
gap_layer = tf.keras.layers.Dense(50)(gap_layer)
if config['data_mode'] == "fix_sacc_fix":
@@ -126,12 +125,17 @@ class Regression_ConvNet(ABC):
return self.model
def fit(self, x, y, verbose=2):
# Split data
X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
# Define callbacks
csv_logger = CSVLogger(config['batches_log'], append=True, separator=';')
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20)
prediction_ensemble = prediction_history((X_val,y_val))
ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
prediction_ensemble = prediction_history((X_val,y_val))
# Fit model
hist = self.model.fit(X_train, y_train, verbose=verbose, batch_size=self.batch_size, validation_data=(X_val,y_val),
epochs=self.epochs, callbacks=[csv_logger, ckpt])
epochs=self.epochs, callbacks=[csv_logger, ckpt, prediction_ensemble])
return hist , prediction_ensemble
\ No newline at end of file
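For reference, a minimal usage sketch of the fit() method above; net stands for any built Regression_ConvNet subclass instance and is hypothetical:

hist, pred_ensemble = net.fit(x, y, verbose=2)
print(len(pred_ensemble.predhis))      # one stored prediction array per trained epoch
print(min(hist.history['val_loss']))   # best validation loss, also checkpointed to best_model.h5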
@@ -169,10 +169,12 @@ class Siamese_ConvNet(ABC):
return self.model
def fit(self, x, y, verbose=2):
# Define callbacks
csv_logger = CSVLogger(config['batches_log'], append=True, separator=';')
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20)
ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
prediction_ensemble = prediction_history(((X_val_sac, X_val_fix), y_val))
# Prepare the data as tuples for the siamese input
X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
@@ -182,10 +184,8 @@ class Siamese_ConvNet(ABC):
logging.info("Saccade siamese input shape: {}".format(X_train_sac[0].shape))
logging.info("Fixation siamese input shape: {}".format(X_train_fix[0].shape))
prediction_ensemble = prediction_history(((X_val_sac, X_val_fix), y_val))
hist = self.model.fit((X_train_sac, X_train_fix), y_train, verbose=verbose, batch_size=self.batch_size,
validation_data=((X_val_sac, X_val_fix), y_val),
epochs=self.epochs, callbacks=[csv_logger, ckpt])
epochs=self.epochs, callbacks=[csv_logger, ckpt, prediction_ensemble])
return hist , prediction_ensemble
@@ -23,12 +23,6 @@ config['root_dir'] = '.'
##################################################################
# You can modify the rest or add new fields as you need.
# Hyper-parameters and training configuration.
config['learning_rate'] = 1e-3 # fix only: 1e-2, sac only: 1e-3, sac_fix: 1e-3 , fix_sac_fix: 1e-5
config['regularization'] = 1e-2 # fix only: 0, sac only: 1e-2, sac_fix: 1e-1, fix_sac_fix: 0.5
config['epochs'] = 250
config['batch_size'] = 64
"""
Parameters that can be chosen:
@@ -49,15 +43,36 @@ finally used for classification.
Cluster can be set to clustering(), clustering2() or clustering3(), which use different electrode clusters taken from the literature.
"""
# Hyper-parameters and training configuration.
config['learning_rate'] = 1e-3 # fix only: 1e-2, sac only: 1e-3, sac_fix: 1e-3 , fix_sac_fix: 1e-4
config['regularization'] = 1e-1 # fix only: 1e-3, sac only: 1e-2, sac_fix: 1e-1, fix_sac_fix: 5
config['epochs'] = 100
config['batch_size'] = 64
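The comments above track the hand-tuned values per dataset; a small convenience mapping (not in this commit, names hypothetical) could select them from data_mode automatically:

_lr_by_mode = {'fix_only': 1e-2, 'sacc_only': 1e-3, 'sacc_fix': 1e-3, 'fix_sacc_fix': 1e-4}
_reg_by_mode = {'fix_only': 1e-3, 'sacc_only': 1e-2, 'sacc_fix': 1e-1, 'fix_sacc_fix': 5}
# would have to run after config['data_mode'] is set below, e.g.:
# config['learning_rate'] = _lr_by_mode[config['data_mode']]
# config['regularization'] = _reg_by_mode[config['data_mode']]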
# Choose experiment
config['gaze-reg'] = True # Set to False if you want to run the saccade classification task
config['data-fraction'] = 1.0 # Set to 1.0 to use the whole dataset; experimental feature, only for the regression task
# Choose which dataset to run the gaze regression on
#config['data_mode'] = 'fix_only'
#config['data_mode'] = 'sacc_only'
config['data_mode'] = 'sacc_fix'
#config['data_mode'] = 'fix_sacc_fix'
# Choose model
config['model'] = 'cnn'
#config['model'] = 'inception'
#config['model'] = 'eegnet'
#config['model'] = 'deepeye'
#config['model'] = 'xception'
#config['model'] = 'pyramidal_cnn'
#config['model'] = 'siamese' # Note that you have to set data_mode to sacc_fix for this model
# Choose the kerastuner or an ensemble of models
#config['run'] = 'kerastuner'
config['run'] = 'ensemble'
config['ensemble'] = 5 #number of models in the ensemble method
# Set loss automatically depending on the dataset/task to run
if config['data_mode'] == 'fix_sacc_fix':
from utils.losses import angle_loss
@@ -65,20 +80,6 @@ if config['data_mode'] == 'fix_sacc_fix':
else:
config['loss'] = 'mean_squared_error'
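angle_loss itself lives in utils/losses.py and is not shown in this diff; as an assumption about its general shape rather than the repository's actual implementation, a periodic angle loss typically wraps the error into [-pi, pi] before squaring:

import tensorflow as tf

def angle_loss_sketch(y_true, y_pred):
    # hypothetical: shortest angular difference, wrapped into [-pi, pi]
    diff = tf.atan2(tf.sin(y_true - y_pred), tf.cos(y_true - y_pred))
    return tf.reduce_mean(tf.square(diff))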
# Choose to either run the kerastuner on the model or
#config['run'] = 'kerastuner'
config['run'] = 'ensemble'
config['ensemble'] = 1 #number of models in the ensemble method
# Choosing model
#config['model'] = 'cnn'
config['model'] = 'inception'
#config['model'] = 'eegnet'
#config['model'] = 'deepeye'
#config['model'] = 'xception'
#config['model'] = 'pyramidal_cnn'
#config['model'] = 'siamese' # Note that you have to set data_mode to sacc_fix for this
# Options for classification task, currently not used for regression
config['downsampled'] = False
config['split'] = False
@@ -89,15 +90,11 @@ if config['gaze-reg']:
config['trainY_file'] = 'label.mat'
config['trainX_variable'] = 'EEGdata'
config['trainY_variable'] = 'label'
config['padding'] = 'repeat' # options: zero, repeat #TODO: find more options for clever padding
config['min_fixation'] = 50 # choose a minimum length for the gaze fixation
config['max_fixation'] = 150 # choose a maximum length for the gaze fixation
config['min_saccade'] = 10 # minimum number of samples for a saccade that we want to use
config['max_saccade'] = 30 # maximum number of samples for a saccade that we want to use
config['x_screen'] = 600
config['y_screen'] = 800 #TODO: Kick out measurements where people look somewhere off the screen
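padding = 'repeat' suggests that events shorter than the configured maximum are tiled up to length; a minimal sketch of such padding (an assumption, the real implementation sits in the data-loading code):

import numpy as np

def pad_repeat_sketch(x, target_len):
    # x has shape (timesteps, channels); repeat it until target_len is reached
    reps = int(np.ceil(target_len / x.shape[0]))
    return np.tile(x, (reps, 1))[:target_len]

Zero padding would instead be np.pad(x, ((0, target_len - x.shape[0]), (0, 0))).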
@@ -111,7 +108,7 @@ else:
# Define parameter for each model
# CNN
config['cnn'] = {}
# CNN
# PyrCNN
config['pyramidal_cnn'] = {}
# InceptionTime
config['inception'] = {}
@@ -124,41 +121,36 @@ config['eegnet'] = {}
# DeepEye-RNN
config['deepeye-rnn'] = {}
# Set the input shapes dependent on task and dataset
if config['gaze-reg']:
if config['data_mode'] == 'fix_only':
config['cnn']['input_shape'] = (int(config['max_fixation']), 129) # e.g. max_fixation = 150 gives shape (150, 129)
config['pyramidal_cnn']['input_shape'] = (int(config['max_fixation']), 129)
config['inception']['input_shape'] = (int(config['max_fixation']), 129)
config['deepeye']['input_shape'] = (int(config['max_fixation']), 129)
config['xception']['input_shape'] = (int(config['max_fixation']), 129)
elif config['data_mode'] == 'sacc_only':
config['cnn']['input_shape'] = (config['max_saccade'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'], 129)
config['inception']['input_shape'] = (config['max_saccade'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'], 129)
config['xception']['input_shape'] = (config['max_saccade'], 129)
elif config['data_mode'] == 'sacc_fix':
config['cnn']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['inception']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['xception']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
else: # data mode is fix_sacc_fix
config['cnn']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['inception']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['xception']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
#TODO: EEGnet not yet implemented for regression
#config['deepeye-rnn']['input_shape'] = (int(config['max_duration']), 129)
#config['eegnet']['channels'] = 129
#config['eegnet']['samples'] = config['max_duration'] = 150
else:
# Left-right classification task
config['cnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
@@ -167,7 +159,6 @@ else:
config['deepeye']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['deepeye-rnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['xception']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['eegnet']['channels'] = 129
config['eegnet']['samples'] = 125 if config['downsampled'] else 500
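With the durations configured above (max_saccade = 30, max_fixation = 150), the regression input shapes work out as follows:

# fix_only:     (150, 129)
# sacc_only:    (30, 129)
# sacc_fix:     (30 + 150, 129)      -> (180, 129)
# fix_sacc_fix: (30 + 2 * 150, 129)  -> (330, 129)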
@@ -2,6 +2,7 @@ import tensorflow as tf
from config import config
from utils.utils import *
from utils.losses import angle_loss
import logging
from CNN.Regression_CNN import Regression_CNN
@@ -43,7 +44,7 @@ def run(trainX, trainY):
elif config['model'] == 'inception':
reg = Regression_INCEPTION(input_shape=config['inception']['input_shape'], use_residual=True, batch_size=config['batch_size'],
kernel_size=64, nb_filters=32, depth=6, bottleneck_size=32, epochs=config['epochs'],
kernel_size=64, nb_filters=64, depth=12, bottleneck_size=32, epochs=config['epochs'],
learning_rate=config['learning_rate'], regularization=config['regularization'])
elif config['model'] == 'xception':
@@ -88,26 +89,35 @@ def run(trainX, trainY):
for j, pred_epoch in enumerate(pred):
pred_epoch = (pred_epoch / config['ensemble']).tolist()
loss.append(mse(pred_ensemble.targets, pred_epoch).numpy())
if config['data_mode'] == 'fix_sacc_fix':
loss.append(angle_loss(pred_ensemble.targets, pred_epoch).numpy())
else:
loss.append(mse(pred_ensemble.targets, pred_epoch).numpy())
pred_epoch = np.round(pred_epoch, 0)
accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets).reshape(-1)-1)**2))
#accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets).reshape(-1)-1)**2))
# save the ensemble loss to the model directory
loss_fname = config['model_dir'] + "/" + "ensemble_loss.txt"
np.savetxt(loss_fname, loss, delimiter=',')
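The division by config['ensemble'] above only yields a mean if pred already holds the epoch-wise sum of predictions over all ensemble members; that accumulation happens outside this hunk, roughly along these lines (a sketch, helper name hypothetical):

pred = None
for i in range(config['ensemble']):
    hist, pred_ensemble = build_and_fit_model(trainX, trainY)  # hypothetical helper
    if pred is None:
        pred = [np.array(p) for p in pred_ensemble.predhis]
    else:
        pred = [p + np.array(q) for p, q in zip(pred, pred_ensemble.predhis)]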
if(config['ensemble'] == 1):
# Only one model, just plot the loss
plot_loss(hist, config['model_dir'], config['model'], val = True)
elif config['ensemble'] > 1:
# Ensemble, plot the ensemble loss
config['model']+='_ensemble'
config['model'] += '_ensemble'
hist.history['val_loss'] = np.array(loss)
# plot ensemble loss
plot_loss(hist, config['model_dir'], config['model'], val = True)
if not config['gaze-reg']:
# Plot also accuracy for the saccade classification task
hist.history['val_accuracy'] = accuracy
plot_acc(hist, config['model_dir'], config['model'], val = True)
if config['split']:
config['model'] = config['model'] + '_cluster'
#if config['split']:
#config['model'] = config['model'] + '_cluster'
logging.info("Done with training and plotting.")
save_logs(hist, config['model_dir'], config['model'], pytorch = False)
#TODO: rewrite the function below to properly store stats and results
#save_logs(hist, config['model_dir'], config['model'], pytorch = False)
@@ -15,7 +15,6 @@ else:
from ensemble import run # (anti-) saccade task
from kerasTuner import tune
def main():
logging.basicConfig(filename=config['info_log'], level=logging.INFO)
logging.info('Started the Logging')
@@ -61,7 +60,6 @@ def main():
else:
raise Exception("Please choose a valid run scheme in config.py")
# Select model and plot results
# select_best_model() TODO: review this for the regression task
# comparison_plot_loss()
@@ -49,13 +49,18 @@ def plot_loss(hist, output_directory, model, val=False, savefig=True):
epochs = len(hist.history[metric])
epochs = np.arange(epochs)
logging.info("Length of hist.history[loss]: {}".format(len(hist.history['loss'])))
logging.info("Length of hist.history[val_loss]: {}".format(len(hist.history['val_loss'])))
logging.info("Length of epochs: {}".format(len(epochs)))
plt.figure()
plt.title(model + ' loss')
plt.plot(np.array(epochs), np.array(hist.history['loss']), 'b-', label='training')
# plot the training curve
plt.plot(epochs, np.array(hist.history['loss']), 'b-', label='training')
# plot the validation curve
if val:
plt.plot(np.array(epochs), np.array(hist.history['val_loss']),'g-',label='validation')
plt.plot(epochs, np.array(hist.history['val_loss']),'g-',label='validation')
plt.legend()
plt.xlabel('epochs')
@@ -66,7 +71,7 @@ def plot_loss(hist, output_directory, model, val=False, savefig=True):
if savefig:
plt.savefig(output_directory + '/' + model + '_loss.png')
plt.show()
#plt.show()
def plot_loss_torch(loss, output_directory, model):
@@ -193,7 +198,7 @@ def save_logs(hist, output_directory, model, pytorch=False):
df_best_model.to_csv(output_directory + '/' + model + '_' + 'df_best_model.csv', index=False)
except:
return
return
# Save the model parameters (newly added without debugging)