Commit df4e6e4c authored by Lukas Wolf's avatar Lukas Wolf
Browse files

cleaned the config file

parent 4c0fc04b
......@@ -6,8 +6,6 @@ from Clusters.cluster import clustering as clustering
from Clusters.cluster2 import clustering as clustering2
from Clusters.cluster3 import clustering as clustering3
from keras.utils.generic_utils import get_custom_objects
config = dict()
##################################################################
......@@ -27,11 +25,33 @@ config['root_dir'] = '.'
"""
Parameters that can be chosen:
TODO: write a proper description how to set the fields in the config
"""
# Choose which task to run
#config['task'] = 'prosaccade-clf'
#config['task'] = 'gaze-reg'
config['task'] = 'angle-reg'
# Choose from which experiment the dataset to load. Can only be chosen for angle-pred and gaze-reg
# TODO: also make calibration task data available for gaze-reg
if config['task'] != 'prosaccade-clf':
config['dataset'] = 'processing_speed_task'
#config['dataset'] = 'calibration_task'
# Choose which data to use for gaze-reg
if config['task'] == 'gaze-reg':
#config['data_mode'] = 'fix_only'
#config['data_mode'] = 'sacc_only'
config['data_mode'] = 'sacc_fix'
elif config['task'] == 'angle-reg':
config['data_mode'] = 'fix_sacc_fix'
gaze-reg: set this to true if you want to work on the
regression task (Lukas thesis)
The corresponding data is EEGdata-002.mat and label.mat
# Choose how much data to use on gaze-reg
config['data-fraction'] = 1.0 # Set to 1.0 if you want to use the whole dataset, experimental feature only for regression task \
"""
Parameters that can be chosen:
cnn: The simple CNN architecture
inception: The InceptionTime architecture
eegnet: The EEGNet architecture
......@@ -44,27 +64,6 @@ If split set to true, the data will be clustered and fed each to a separate arch
finally used for classification.
Cluster can be set to clustering(), clustering2() or clustering3(), where different clusters based on literature are used.
"""
# Choose experiment TODO: create config['experiment'] which can be chosen like the model via commenting out
config['gaze-reg'] = True # Set to False if you want to run the saccade classification task
config['prosaccade'] = False
config['calibration-task'] = False
config['angle-pred'] = False
# Choose how much data to use on gaze-reg
config['data-fraction'] = 1.0 # Set to 1.0 if you want to use the whole dataset, experimental feature only for regression task \
# Hyper-parameters and training configuration.
config['learning_rate'] = 1e-4 # fix only: 1e-2, sac only: 1e-3, sac_fix: 1e-3 , fix_sac_fix: 1e-4
config['regularization'] = 20 # fix only: 1e-3, sac only: 1e-2, sac_fix: 1, fix_sac_fix: 5
config['epochs'] = 100
config['batch_size'] = 64
# Choose which dataset to run the gaze regression on
#config['data_mode'] = 'fix_only'
#onfig['data_mode'] = 'sacc_only'
#config['data_mode'] = 'sacc_fix'
#config['data_mode'] = 'fix_sacc_fix'
config['data_mode'] = 'calib_task_fix_sacc_fix'
# Choose model
config['model'] = 'cnn'
......@@ -75,6 +74,12 @@ config['model'] = 'cnn'
#config['model'] = 'pyramidal_cnn'
#config['model'] = 'siamese' # Note that you have to set data_mode to sacc_fix for this model
# Hyper-parameters and training configuration.
config['learning_rate'] = 1e-4 # fix only: 1e-2, sac only: 1e-3, sac_fix: 1e-3 , fix_sac_fix: 1e-4, calib_task
config['regularization'] = 1 # fix only: 1e-3, sac only: 1e-2, sac_fix: 1, fix_sac_fix: 5
config['epochs'] = 150
config['batch_size'] = 64
# Choose the kerastuner or an ensemble of models
#config['run'] = 'kerastuner'
config['run'] = 'ensemble'
......@@ -82,42 +87,30 @@ config['ensemble'] = 1 #number of models in the ensemble method
# Other functions that can be chosen optionally
config['sanity_check'] = False
config['plot_model'] = True
config['plot_model'] = False
# Set loss automatically depending on the dataset/task to run
if (config['data_mode'] == 'fix_sacc_fix' or config['data_mode'] == 'calib_task_fix_sacc_fix') and config['gaze-reg']:
if config['task'] == 'angle-reg':
from utils.losses import angle_loss
config['loss'] = angle_loss
get_custom_objects().update({"angle_loss": angle_loss})
else:
config['loss'] = 'mean_squared_error'
# Options for classification task, currently not used for regression
# Options for prosaccade task, currently not used for regression
config['downsampled'] = False
config['split'] = False
#config['cluster'] = clustering()
if config['gaze-reg']:
config['trainX_file'] = 'EEGdata-002.mat'
config['trainY_file'] = 'label.mat'
config['trainX_variable'] = 'EEGdata'
config['trainY_variable'] = 'label'
config['padding'] = 'repeat' # options: zero, repeat #TODO: find more options for clever padding
config['min_fixation'] = 50 # choose a minimum length for the gaze fixation
config['max_fixation'] = 150 # choose a maximum length for the gaze fixation
if config['data_mode'] == 'calib_task_fix_sacc_fix':
if config['task'] != 'prosaccade-clf':
config['padding'] = 'repeat' # options: zero, repeat
config['min_fixation'] = 50 # min number of samples for the gaze fixation
config['max_fixation'] = 150 # max number of samples for the gaze fixation
if config['dataset'] == 'calibration_task' or config['dataset'] == 'processing_speed_task':
config['max_fixation'] = 1000
config['min_saccade'] = 10 # minimum number of samples for a saccade that we want to use
config['max_saccade'] = 30 # maximum number of samples for a saccade that we want to use
config['x_screen'] = 600
config['y_screen'] = 800 #TODO: Kick out measurements where people look somewhere off the screen
else:
# Left right classification task
config['trainX_file'] = 'noweEEG.mat' if config['downsampled'] else 'all_EEGprocuesan.mat'
config['trainY_file'] = 'all_trialinfoprosan.mat'
config['trainX_variable'] = 'noweEEG' if config['downsampled'] else 'all_EEGprocuesan'
config['trainY_variable'] = 'all_trialinfoprosan'
config['y_screen'] = 800
# Define parameter for each model
# CNN
......@@ -136,48 +129,45 @@ config['eegnet'] = {}
config['deepeye-rnn'] = {}
# Set the input shapes dependent on task and dataset
if config['gaze-reg']:
if config['task'] != 'prosaccade-clf':
if config['data_mode'] == 'fix_only':
config['cnn']['input_shape'] = (int(config['max_fixation']), 129) # e.g. for max_duration 300 we have shape (150,129)
config['pyramidal_cnn']['input_shape'] = (int(config['max_fixation']), 129)
config['inception']['input_shape'] = (int(config['max_fixation']), 129)
config['deepeye']['input_shape'] = (int(config['max_fixation']), 129)
config['xception']['input_shape'] = (int(config['max_fixation']), 129)
elif config['data_mode'] == 'sacc_only':
config['cnn']['input_shape'] = (config['max_saccade'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'], 129)
config['inception']['input_shape'] = (config['max_saccade'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'], 129)
config['xception']['input_shape'] = (config['max_saccade'], 129)
elif config['data_mode'] == 'sacc_fix':
config['cnn']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['inception']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['xception']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
elif config['data_mode'] == 'fix_sacc_fix':
# Choose the shapes for angle pred depending on dataset
elif config['data_mode'] == 'fix_sacc_fix' and config['dataset'] == 'processing_speed_task':
config['cnn']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['inception']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['xception']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
elif config['data_mode'] == 'calib_task_fix_sacc_fix':
elif config['data_mode'] == 'fix_sacc_fix' and config['dataset'] == 'calibration_task':
config['cnn']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['inception']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['xception']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
#TODO: EEGnet not yet implemented for regression
# These models are not yet implemented for regression
#config['deepeye-rnn']['input_shape'] = (int(config['max_duration']), 129)
#config['eegnet']['channels'] = 129
#config['eegnet']['samples'] = config['max_duration'] = 150
else:
# Left-right classification task
# Left-right classification (prosaccade) task
config['cnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['pyramidal_cnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['inception']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
......@@ -191,11 +181,10 @@ else:
timestamp = str(int(time.time()))
model_folder_name = timestamp if config['model'] == '' else timestamp + "_" + config['model']
# Modify the model folder name depending on which task tuns
if config['gaze-reg']:
model_folder_name += '_gaze-reg'
model_folder_name += "_" + config['task']
if config['task'] != 'prosaccade-clf':
model_folder_name += '_' + config['data_mode']
else:
model_folder_name += '_left-right-pred'
model_folder_name += '_' + config['dataset']
if config['split']:
model_folder_name += '_cluster'
......
......@@ -8,34 +8,43 @@ import h5py
import logging
import time
if config['gaze-reg']:
if config['task'] == 'gaze-reg' or config['task'] == 'angle-reg':
from ensemble_regression import run # gaze regression task
from kerasTuner_regression import tune
else:
elif config['task'] == 'prosaccade-clf':
from ensemble import run # (anti-) saccade task
from kerasTuner import tune
else:
raise Exception("Choose valid task in config.py")
def main():
logging.basicConfig(filename=config['info_log'], level=logging.INFO)
logging.info('Started the Logging')
start_time = time.time()
if config['gaze-reg']:
if config['task'] == 'gaze-reg':
logging.info("Running the gaze regression task")
logging.info("Using data from {}".format(config['dataset']))
logging.info("Using {} padding".format(config['padding']))
if config["data_mode"] != "sacc_only":
logging.info("Using fixations between {} ms and {} ms, 1 sample equals 2ms".format((2 * config['min_fixation']), (2 * config['max_fixation'])))
if config['data_mode'] != "fix_only":
logging.info("Using saccades between {} ms and {} ms, 1 sample equals 2ms".format((2 * config['min_saccade']), (2 * config['max_saccade'])))
elif config['task'] == 'angle-reg':
logging.info("Running the angle regression task")
logging.info("Using data from {}".format(config['dataset']))
logging.info("Using {} padding".format(config['padding']))
logging.info("Using fixations between {} ms and {} ms, 1 sample equals 2ms".format((2 * config['min_fixation']), (2 * config['max_fixation'])))
logging.info("Using saccades between {} ms and {} ms, 1 sample equals 2ms".format((2 * config['min_saccade']), (2 * config['max_saccade'])))
else:
logging.info("Running the saccade classification task")
logging.info("Running the saccade classification (prosaccade) task")
# Log some parameters to better distinguish between tasks
logging.info("Loss: {}".format(config['loss']))
logging.info("Learning rate: {}".format(config['learning_rate']))
logging.info("Regularization: {}".format(config['regularization']))
logging.info("Batch size: {}".format(config['batch_size']))
logging.info("Maximal number of epochs: {}".format(config['epochs']))
logging.info("Maximal number of epochs per model: {}".format(config['epochs']))
if config['run'] == "kerastuner":
logging.info("Running the keras-tuner")
......
......@@ -4,7 +4,7 @@
#SBATCH --output=log/%j.out # where to store the output (%j is the JOBID), subdirectory must exist
#SBATCH --error=log/%j.err # where to store error messages
#SBATCH --gres=gpu:1
#SBATCH --mem=80G
#SBATCH --mem=120G
echo "Running on host: $(hostname)"
echo "In directory: $(pwd)"
......
......@@ -9,9 +9,15 @@ from utils import regression_preprocessing
def get_mat_data(data_dir, verbose=True):
if config['gaze-reg']:
# call the regression task data loader
return get_regression_data(verbose=verbose)
if config['task'] != 'prosaccade-clf':
# call the regression task data loader which handles the further processing
return regression_preprocessing.load_regression_data(verbose=verbose)
# Load the data for the prosaccade task
config['trainX_file'] = 'noweEEG.mat' if config['downsampled'] else 'all_EEGprocuesan.mat'
config['trainY_file'] = 'all_trialinfoprosan.mat'
config['trainX_variable'] = 'noweEEG' if config['downsampled'] else 'all_EEGprocuesan'
config['trainY_variable'] = 'all_trialinfoprosan'
with h5py.File(data_dir + config['trainX_file'], 'r') as f:
X = f[config['trainX_variable']][:]
......@@ -38,17 +44,6 @@ def get_mat_data(data_dir, verbose=True):
return X, y
def get_regression_data(verbose=True):
"""
Load the data for the regression task
Returns X, y: data and labels for the task
"""
X, y = regression_preprocessing.load_regression_data(verbose=verbose)
return X, y
def get_pickle_data(data_dir, verbose=True):
pkl_file_x = open(data_dir + 'x.pkl', 'rb')
x = pickle.load(pkl_file_x)
......
......@@ -46,30 +46,37 @@ def load_regression_data(verbose=True):
X, y, the data matrix and the labels
"""
if config['data_mode'] == 'sacc_only':
logging.info("Using saccade only dataset")
return get_sacc_data(verbose=verbose)
elif config['data_mode'] == 'sacc_fix':
logging.info("Using saccade-fixation dataset")
return get_sacc_fix_data(verbose=verbose)
elif config['data_mode'] == 'fix_sacc_fix':
if config['task'] == 'gaze-reg':
if config['data_mode'] == 'sacc_only':
logging.info("Using saccade only dataset")
return get_sacc_data(verbose=verbose)
elif config['data_mode'] == 'fix_only':
logging.info("Using fixation only dataset")
return get_fix_data(verbose=verbose)
elif config['data_mode'] == 'sacc_fix':
logging.info("Using saccade-fixation dataset")
return get_sacc_fix_data(verbose=verbose)
else:
raise Exception("Choose valid task and data_mode in config.py")
elif config['task'] == 'angle-reg':
logging.info("Using fixation-saccade-fixation dataset")
return get_fix_sacc_fix_data(verbose=verbose)
elif config['data_mode'] == 'fix_only':
logging.info("Using fixation only dataset")
return get_fix_data(verbose=verbose)
elif config['data_mode'] == 'calib_task_fix_sacc_fix':
logging.info("Using fixation-saccade-fixation from calibration task dataset")
return get_fix_sacc_fix_data(verbose=verbose)
# Choose the following to extract only data with precise fixations
#return get_calibration_task_fix_sacc_fix_data(verbose=verbose)
else:
raise Exception("Choose a valid data_mode in config.py")
raise Exception("Choose a valid task in config.py")
def get_fix_data(verbose=True):
"""
Returns X, y for the gaze regression task with EEG data X only from fixations
"""
# Define these variables that are needed to load the fixation data (old dataset, processing speed?)
config['trainX_file'] = 'EEGdata-002.mat'
config['trainY_file'] = 'label.mat'
config['trainX_variable'] = 'EEGdata'
config['trainY_variable'] = 'label'
# Load the labels
y = scipy.io.loadmat(config['data_dir'] + config['trainY_variable'])
labels = y['label'] # shape (85413, 1) for label.mat
......
......@@ -42,14 +42,14 @@ def plot_loss(hist, output_directory, model, val=False, savefig=True):
"""
Works for both classification and regression, set config.py accordingly
"""
if config['gaze-reg']:
if config['task'] != 'prosaccade-clf':
metric = "loss"
else:
metric = 'accuracy'
epochs = len(hist.history[metric])
epochs = np.arange(epochs)
plt.figure()
plt.title(model + ' loss')
# plot the training curve
......@@ -60,8 +60,10 @@ def plot_loss(hist, output_directory, model, val=False, savefig=True):
plt.legend()
plt.xlabel('epochs')
if config['gaze-reg']:
if config['task'] == 'gaze-reg':
plt.ylabel("MSE")
elif config['task'] == 'angle-reg':
plt.ylabel("Mean abs angle err")
else:
plt.ylabel('Binary Cross Entropy')
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment