Commit 006f064a authored by Lukas Wolf's avatar Lukas Wolf
Browse files

functionality for calibration task

parent 5f819931
......@@ -52,8 +52,8 @@ config['data-fraction'] = 1.0 # Set to 1.0 if you want to use the whole dataset,
# Hyper-parameters and training configuration.
config['learning_rate'] = 1e-4 # fix only: 1e-2, sac only: 1e-3, sac_fix: 1e-3 , fix_sac_fix: 1e-4
config['regularization'] = 5 # fix only: 1e-3, sac only: 1e-2, sac_fix: 1, fix_sac_fix: 5
config['epochs'] = 1
config['regularization'] = 20 # fix only: 1e-3, sac only: 1e-2, sac_fix: 1, fix_sac_fix: 5
config['epochs'] = 50
config['batch_size'] = 64
# Choose which dataset to run the gaze regression on
......@@ -64,8 +64,8 @@ config['batch_size'] = 64
config['data_mode'] = 'calib_task_fix_sacc_fix'
# Choose model
#config['model'] = 'cnn'
config['model'] = 'inception'
config['model'] = 'cnn'
#config['model'] = 'inception'
#config['model'] = 'eegnet'
#config['model'] = 'deepeye'
#config['model'] = 'xception'
......@@ -82,7 +82,7 @@ config['sanity_check'] = False
config['plot_model'] = True
# Set loss automatically depending on the dataset/task to run
if config['data_mode'] == 'fix_sacc_fix' and config['gaze-reg']:
if (config['data_mode'] == 'fix_sacc_fix' or config['data_mode'] == 'calib_task_fix_sacc_fix') and config['gaze-reg']:
from utils.losses import angle_loss
config['loss'] = angle_loss
else:
......@@ -139,24 +139,35 @@ if config['gaze-reg']:
config['inception']['input_shape'] = (int(config['max_fixation']), 129)
config['deepeye']['input_shape'] = (int(config['max_fixation']), 129)
config['xception']['input_shape'] = (int(config['max_fixation']), 129)
elif config['data_mode'] == 'sacc_only':
config['cnn']['input_shape'] = (config['max_saccade'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'], 129)
config['inception']['input_shape'] = (config['max_saccade'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'], 129)
config['xception']['input_shape'] = (config['max_saccade'], 129)
elif config['data_mode'] == 'sacc_fix':
config['cnn']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['inception']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
config['xception']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
else: # data mode is fix_sacc_fix
elif config['data_mode'] == 'fix_sacc_fix':
config['cnn']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['inception']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['xception']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
elif config['data_mode'] == 'calib_task_fix_sacc_fix':
config['cnn']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['inception']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['deepeye']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
config['xception']['input_shape'] = (config['max_saccade'] + 2 * config['max_fixation'], 129)
#TODO: EEGnet not yet implemented for regression
#config['deepeye-rnn']['input_shape'] = (int(config['max_duration']), 129)
#config['eegnet']['channels'] = 129
......
......@@ -4,7 +4,7 @@
#SBATCH --output=log/%j.out # where to store the output (%j is the JOBID), subdirectory must exist
#SBATCH --error=log/%j.err # where to store error messages
#SBATCH --gres=gpu:1
#SBATCH --mem=40G
#SBATCH --mem=80G
echo "Running on host: $(hostname)"
echo "In directory: $(pwd)"
......
......@@ -499,8 +499,8 @@ def get_calibration_task_fix_sacc_fix_data(verbose=True):
non_stimulus_strings = fixations + saccades + blinks
non_stimulus_ints = [blockOff, blockOn, stimulusOff] + dots_start
print(non_stimulus_ints)
print(non_stimulus_strings)
#print(non_stimulus_ints)
#print(non_stimulus_strings)
#print(dots_start)
# Threshold for the distance between fixation_avg_pos and the expected label
......@@ -545,7 +545,7 @@ def get_calibration_task_fix_sacc_fix_data(verbose=True):
event_1_name = event_1[0][0] # dereference the event name, e.g. 'L_saccade'
#event_1_name = hashlib.md5(str(event_1_name).encode()).hexdigest()
print("{}".format(i) + ": " + event_1_name)
#print("{}".format(i) + ": " + event_1_name)
try:
num = int(event_1_name)
......@@ -682,7 +682,7 @@ def get_calibration_task_fix_sacc_fix_data(verbose=True):
fixation1_datapoint = np.array(data[fixation1_start_time:fixation1_end_time])
x_len_fix1, y_len_fix1 = fixation1_datapoint.shape
if x_len_fix1 < config['min_fixation']: # in this task there is no upper limit on fixation length
if x_len_fix1 < config['min_fixation'] or x_len_fix1 > config['max_fixation']: # in this task there is no upper limit on fixation length
#print("fixation1 not sufficient: {}".format(x_len_fix1))
continue
......@@ -720,7 +720,7 @@ def get_calibration_task_fix_sacc_fix_data(verbose=True):
fixation2_datapoint = np.array(data[fixation2_start_time:fixation2_end_time])
x_len_fix2, y_len_fix2 = fixation2_datapoint.shape
if x_len_fix2 < config['min_fixation']: # no upper bound on fixation length
if x_len_fix2 < config['min_fixation'] or x_len_fix2 > config['max_fixation']: # no upper bound on fixation length
#print("fixation2 not sufficient: {}".format(x_len_fix2))
continue
......@@ -736,7 +736,7 @@ def get_calibration_task_fix_sacc_fix_data(verbose=True):
x_datapoint = np.concatenate((fixation1_datapoint, saccade_datapoint), axis=0)
x_datapoint = np.concatenate((x_datapoint, fixation2_datapoint), axis=0)
print(x_datapoint.shape)
#print(x_datapoint.shape)
# Compute difference of the fixation coordinates
dx = fix2_avg_x - fix1_avg_x
......@@ -757,7 +757,7 @@ def get_calibration_task_fix_sacc_fix_data(verbose=True):
X = np.asarray(x_list)
#print(X.shape)
#X = X[:,:,:129] # Cut off the last 4 columns (time, x, y, pupil size)
X = X[:,:,:129] # Cut off the last 4 columns (time, x, y, pupil size)
# Normalize the data
#norm = np.linalg.norm(X)
#X = X / norm
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment