Commit aa260e4a authored by Lukas Wolf's avatar Lukas Wolf
Browse files

tf repair

parent fd57ab66
......@@ -12,7 +12,6 @@ class Trainer:
"""
def __init__(self, X, y):
    """Keep references to the training inputs X and targets y on this trainer.

    Args:
        X: feature data passed on to the ensemble at training time.
        y: labels/targets passed on to the ensemble at training time.
    """
    logging.info("Created a Trainer instance")
    # Store the data so train() can forward it to the ensemble later.
    self.X, self.y = X, y
......@@ -21,8 +20,8 @@ class Trainer:
# Create the ensemble according to the specifications
# Run the training
# For now I will just create an Ensemble and run it
logging.info("Trainer: started training")
logging.info("------------------------------------------------------------------------------------")
logging.info("Trainer: created a {} trainer".format(config['framework']))
if config['framework'] == 'tensorflow':
from tf_models.Ensemble.Ensemble_tf import Ensemble_tf
......@@ -33,6 +32,6 @@ class Trainer:
ensemble = Ensemble_torch(nb_models=config['ensemble'], model_type=config['model'])
ensemble.run(self.X, self.y)
else:
raise Exception("Choose a valid DL framework")
raise Exception("Choose a valid deep learning framework")
logging.info("Trainer: finished training")
\ No newline at end of file
......@@ -20,8 +20,8 @@ def main():
# Load the data
try:
trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
#trainX = np.load("./data/precomputed/processing_speed_task/sacc_fix_X.npy")
#trainY = np.load("./data/precomputed/processing_speed_task/sacc_fix_y.npy")
#trainX = np.load("./data/precomputed/calibration_task/all_fix_sacc_fix_X.npy")
#trainY = np.load("./data/precomputed/calibration_task/all_fix_sacc_fix_y.npy")
except:
raise Exception("Could not load mat data")
......
......@@ -4,7 +4,7 @@
#SBATCH --output=log/%j.out # where to store the output (%j is the JOBID), subdirectory must exist
#SBATCH --error=log/%j.err # where to store error messages
#SBATCH --gres=gpu:1
#SBATCH --mem=40G
#SBATCH --mem=80G
echo "Running on host: $(hostname)"
echo "In directory: $(pwd)"
......
......@@ -61,7 +61,6 @@ class BaseNet:
#TODO: choose this option if y has a second column
# X_train, X_val, y_train, y_val = train_val_split(x, y, 0.2, subjectID)
X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
prediction_ensemble = prediction_history((X_val, y_val))
if config['model'] == 'eegnet':
# early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=20)
......
import tensorflow as tf
from config import config
from utils.utils import *
from utils.losses import angle_loss
from tf_models.utils.losses import angle_loss
import logging
from tf_models.CNN.CNN import CNN
......@@ -18,8 +18,7 @@ class Ensemble_tf:
Default: nb_models of the model_type model
Optional: Initialize with a model list, create a more versatile Ensemble model
"""
def __init__(self, nb_models=1, model_type='cnn', model_list=[]):
def __init__(self, nb_models=5, model_type='cnn', model_list=[]):
"""
nb_models: Number of models to run in the ensemble
model_type: Default model to run
......@@ -29,7 +28,7 @@ class Ensemble_tf:
self.model_type = model_type
self.model_list = model_list
self.models = []
self._build_ensemble_model(model_list)
self._build_ensemble_model()
def __str__(self):
return self.__class__.__name__
......@@ -39,34 +38,35 @@ class Ensemble_tf:
# load all models into the model_list to predict with them
pass
def _build_ensemble_model(self, model_list):
def _build_ensemble_model(self):
"""
Create a list of compiled models
Default: create a list of self.nb_models many models of type self.model_type
Optional: create a list of the models described in self.model_list
"""
if len(model_list) > 0:
#TODO: implement loading a list of different models
if len(self.model_list) > 0:
logging.info("Built ensemble model with model(s): {}".format(self.model_list))
pass
for i in range(len(self.model_list)):
self.models.append(create_model(model_type=self.model_list[i], model_number=i))
else:
logging.info("Built ensemble model with {} {} model(s)".format(self.nb_models, self.model_type))
for i in range(self.nb_models):
self.models.append(create_model(self.model_type, model_number=i))
logging.info("Built ensemble model with {} {} model(s)".format(self.nb_models, self.model_type))
def run(self, X, y):
"""
Fit all the models in the ensemble and save them to the run directory
"""
logging.info("Started the training")
# Metrics
mse = tf.keras.losses.MeanSquaredError()
bce = tf.keras.losses.BinaryCrossentropy()
# Metrics to save over the ensemble
loss=[]
accuracy=[]
bce = tf.keras.losses.BinaryCrossentropy()
mse = tf.keras.losses.MeanSquaredError()
# Fit the models
for i in range(len(self.models)):
logging.info('Start training model number {}/{} ...'.format(i+1, self.nb_models))
print("------------------------------------------------------------------------------------")
print('Start training model number {}/{} ...'.format(i+1, self.nb_models))
model = self.models[i]
hist, pred_ensemble = model.fit(X,y)
# Collect the predictions on the validation sets
......@@ -75,19 +75,48 @@ class Ensemble_tf:
else:
for j, pred_epoch in enumerate(pred_ensemble.predhis):
pred[j] = (np.array(pred[j]) + np.array(pred_epoch))
logging.info('Finished training model number {}/{} ...'.format(i+1, self.nb_models))
print('Finished training model number {}/{} ...'.format(i+1, self.nb_models))
# Compute averaged metrics on the validation set
for j, pred_epoch in enumerate(pred):
pred_epoch = (pred_epoch / config['ensemble']).tolist() # divide by number of ensembles to get mean prediction
if config['task'] == 'prosaccade_clf':
print(f"epoch number {j}")
pred_epoch = (pred_epoch/config['ensemble']).tolist()
# Compute the ensemble loss dependent on the task
if config['task'] == 'prosaccade-clf':
print("append prosaccade loss")
loss.append(bce(pred_ensemble.targets,pred_epoch).numpy())
accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets).reshape(-1)-1)**2))
elif config['task'] == 'angle_reg':
loss.append(angle_loss(pred_ensemble.targets, pred_epoch).numpy())
elif config['task'] == 'angle-reg':
print("append angle loss")
loss.append(angle_loss(pred_ensemble.targets,pred_epoch).numpy())
elif config['task'] == 'gaze-reg':
loss.append(mse(pred_ensemble.targets, pred_epoch).numpy())
print("append mse loss")
loss.append(mse(pred_ensemble.targets,pred_epoch).numpy())
else:
raise Exception("Choose valid task in config.py")
pred_epoch = np.round(pred_epoch,0)
if config['task'] == 'prosaccade-clf':
accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets).reshape(-1)-1)**2))
if config['ensemble']>1:
config['model']+='_ensemble'
if config['split']:
config['model'] = config['model'] + '_cluster'
hist.history['val_loss'] = loss
plot_loss(hist, config['model_dir'], config['model'], val = True)
if config['task'] == 'prosaccade-clf':
hist.history['val_accuracy'] = accuracy
plot_acc(hist, config['model_dir'], config['model'], val = True)
save_logs(hist, config['model_dir'], config['model'], pytorch = False)
"""
# Compute averaged metrics on the validation set
for j, pred_epoch in enumerate(pred):
print(f"pred epoch {j} with loss {self.loss_fn(pred_ensemble.targets,pred_epoch).numpy()} and accuracy {np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets).reshape(-1)-1)**2)}")
pred_epoch = (pred_epoch/config['ensemble']).tolist()
loss.append(self.loss_fn(pred_ensemble.targets,pred_epoch).numpy())
pred_epoch = np.round(pred_epoch,0)
if config['task'] == 'prosaccade_clf':
accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets).reshape(-1)-1)**2))
# save the ensemble loss to the model directory
loss_fname = config['model_dir'] + "/" + "ensemble_loss.txt"
......@@ -106,10 +135,9 @@ class Ensemble_tf:
if config['task'] == 'prosaccade_clf':
hist.history['val_accuracy'] = accuracy
plot_acc(hist, config['model_dir'], config['model'], val = True)
save_logs(hist, config['model_dir'], config['model'], pytorch = False)
save_logs(hist, config['model_dir'], config['model'], pytorch = False)
logging.info("Done with training and plotting.")
return
"""
def predict(self, X):
"""
......
......@@ -5,7 +5,7 @@ def log_config():
if config['run'] == "kerastuner":
logging.info("Running the keras-tuner")
else:
logging.info("Running the ensemble with {} {} models".format(config['ensemble'], config['model']))
logging.info("Running an ensemble with {} {} models".format(config['ensemble'], config['model']))
if config['task'] == 'gaze-reg':
logging.info("Training on the gaze regression task")
logging.info("Using data from {}".format(config['dataset']))
......
......@@ -46,6 +46,8 @@ def load_regression_data(verbose=True):
X, y, the data matrix and the labels
"""
logging.info(f"Loading regression data")
if config['task'] == 'gaze-reg':
if config['data_mode'] == 'sacc_only':
#logging.info("Using saccade only dataset")
......@@ -118,8 +120,8 @@ def get_fix_data(verbose=True):
logging.info("X training loaded.")
logging.info(X.shape)
# Save the precomputed data for further usage
#np.save("./data/precomputed/fix_only_X", X)
#np.save("./data/precomputed/fix_only_y", y)
np.save("./data/precomputed/fix_only_X" + '_' + config['task'], X)
np.save("./data/precomputed/fix_only_y" + '_' + config['task'], y)
return X, y
......@@ -185,8 +187,8 @@ def get_sacc_data(task='processing_speed_task', verbose=True):
logging.info("X training loaded.")
logging.info(X.shape)
# Save the precomputed data for future usage
#np.save("./data/precomputed/sacc_only_X", X)
#np.save("./data/precomputed/sacc_only_y", y)
np.save("./data/precomputed/sacc_only_X" + '_' + config['task'], X)
np.save("./data/precomputed/sacc_only_y" + '_' + config['task'], y)
return X, y
def get_sacc_fix_data(task='processing_speed_task', verbose=True):
......@@ -299,8 +301,8 @@ def get_sacc_fix_data(task='processing_speed_task', verbose=True):
logging.info(X.shape)
# Save the precomputed data for future usage
#np.save("./data/precomputed/sacc_fix_X", X)
#np.save("./data/precomputed/sacc_fix_y", y)
np.save("./data/precomputed/sacc_fix_X" + '_' + config['task'], X)
np.save("./data/precomputed/sacc_fix_y" + '_' + config['task'], y)
return X, y
......@@ -460,8 +462,8 @@ def get_fix_sacc_fix_data(task='processing_speed_task', verbose=True):
logging.info(X.shape)
# Save the precomputed data for future usage
#np.save("./data/precomputed/calibration_task/all_fix_sacc_fix_X", X)
#np.save("./data/precomputed/calibration_task/all_fix_sacc_fix_y", y)
np.save("./data/precomputed/calibration_task/fix_sacc_fix_X" + '_' + config['task'], X)
np.save("./data/precomputed/calibration_task/fix_sacc_fix_y" + '_' + config['task'], y)
return X, y
......@@ -767,8 +769,8 @@ def get_calibration_task_fix_sacc_fix_data(verbose=True):
logging.info(X.shape)
# Save the precomputed data for future usage
np.save("./data/precomputed/calibration_task/fix_sacc_fix_X", X)
np.save("./data/precomputed/calibration_task/fix_sacc_fix_y", y)
np.save("./data/precomputed/calibration_task/fix_sacc_fix_X" + '_' + config['task'], X)
np.save("./data/precomputed/calibration_task/fix_sacc_fix_y" + '_' + config['task'], y)
return X, y
......@@ -788,6 +790,7 @@ def load_sEEG_events(abs_dir_path):
def load_sEEG_data(abs_dir_path):
# This function is redundant, merge with above fct!!!
"""
Returns the 133 channels of a participant
129 EEG channels plus 4 (time, x, y and pupil size)
......@@ -800,14 +803,12 @@ def load_sEEG_data(abs_dir_path):
#print("EEG data shape: {}".format(data.shape))
return data # access the i-th recorded sample via data[i], recordings at 2ms intervals
def euclidian_dist(x, y):
    """
    Return the Euclidean (L2) distance between the points x and y.

    x and y may be any array-likes of equal shape (e.g. lists or arrays).
    """
    difference = np.asarray(x) - np.asarray(y)
    return np.linalg.norm(difference)
def cart2pol(x, y):
"""
Transform cartesian to polar coordinates
......@@ -816,7 +817,6 @@ def cart2pol(x, y):
phi = np.arctan2(y, x)
return rho, phi
def pol2cart(rho, phi):
"""
Transform polar to cartesian coordinates
......
......@@ -45,13 +45,16 @@ def plot_loss(hist, output_directory, model, val=False, savefig=True):
"""
Plot loss function of the trained model over the epochs
Works for both classification and regression, set config.py accordingly
"""
"""
if config['task'] != 'prosaccade-clf':
metric = "loss"
else:
metric = 'accuracy'
"""
epochs = len(hist.history[metric])
epochs = len(hist.history['loss'])
epochs = np.arange(epochs)
plt.figure()
......@@ -70,9 +73,10 @@ def plot_loss(hist, output_directory, model, val=False, savefig=True):
if config['task'] == 'gaze-reg':
plt.ylabel("MSE")
elif config['task'] == 'angle-reg':
plt.ylabel("Mean abs angle err")
plt.ylabel("Mean Absolute Angle Error")
else:
plt.ylabel('Binary Cross Entropy')
plt.ylabel('Binary Cross Entropy Loss')
if savefig:
plt.savefig(output_directory + '/' + model + '_loss.png')
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment