Commit 5b62a48f authored by Lukas Wolf

load prepared data in trainer instead of main

parent b106159e
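
With this commit the Trainer loads the prepared data itself instead of receiving the arrays from main. A minimal usage sketch of the new call pattern (the import path is an assumption, not taken from the repository layout):

    # Hypothetical usage -- main only constructs the Trainer and calls train()
    from Trainer import Trainer   # assumed import path

    trainer = Trainer()   # loads ./data/prepared/... according to config['task'] and config['preprocessing']
    trainer.train()       # dispatches to the tensorflow or pytorch ensemble via config['framework']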
+from torch._C import Value
 from config import config
 import logging
+import numpy as np
 class Trainer:
     """
@@ -11,11 +13,12 @@ class Trainer:
     Can also choose between tensorflow and pytorch implementation of the models
     """
-    def __init__(self, X, y):
-        self.X = X
-        self.y = y
-        logging.info(f"Data X shape: {X.shape}")
-        logging.info(f"Data y shape: {y.shape}")
+    def __init__(self):
+        # self.X = X
+        # self.y = y
+        self.X, self.y = self.load_data()
+        logging.info(f"Data X shape: {self.X.shape}")
+        logging.info(f"Data y shape: {self.y.shape}")
     def train(self):
         # Check if we want to run tf or torch
@@ -25,18 +28,43 @@ class Trainer:
         logging.info("------------------------------------------------------------------------------------")
         logging.info("Trainer: created a {} trainer".format(config['framework']))
-        #TODO: load prepared data from data/prepared
         if config['framework'] == 'tensorflow':
             from tf_models.Ensemble.Ensemble_tf import Ensemble_tf
             ensemble = Ensemble_tf(nb_models=config['ensemble'], model_type=config['model'])
             ensemble.run(self.X, self.y)
-        elif config['framework'] == 'torch':
+        elif config['framework'] == 'pytorch':
             from torch_models.Ensemble.Ensemble_torch import Ensemble_torch
             ensemble = Ensemble_torch(nb_models=config['ensemble'], model_type=config['model'])
             ensemble.run(self.X, self.y)
         else:
-            raise Exception("Choose a valid deep learning framework")
+            raise ValueError("Choose a valid deep learning framework")
         logging.info("Trainer: finished training")
         logging.info("------------------------------------------------------------------------------------")
+    def load_data(self):
+        """
+        Load the data depending on the preprocessing, task and dataset chosen in config.py.
+        The data has to be prepared with the preparator and stored in data/prepared.
+        """
+        try:
+            if config['task'] == 'prosaccade-clf':
+                data = np.load('./data/prepared/LR_task_with_antisaccade_synchronised_' + config['preprocessing'] + '.npz')
+            elif config['task'] == 'gaze-reg':
+                data = np.load('./data/prepared/Position_task_with_dots_synchronised_' + config['preprocessing'] + '.npz')
+            elif config['task'] == 'angle-reg':
+                if config['dataset'] == 'calibration_task':
+                    data = np.load('./data/prepared/Direction_task_with_dots_synchronised_' + config['preprocessing'] + '.npz')
+                else:
+                    data = np.load('./data/prepared/Direction_task_with_processing_speed_synchronised_' + config['preprocessing'] + '.npz')
+            else:
+                raise ValueError("Choose a valid task in config.py")
+        except:
+            raise ValueError("Did you prepare the data?")
+        X = data['EEG']  # has shape (samples, timesamples, channels)
+        y = data['labels']
+        y = y[:, 1:]  # Remove the subject counter from the labels
+        return X, y
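
The new load_data expects .npz archives in data/prepared that contain an 'EEG' array of shape (samples, timesamples, channels) and a 'labels' array whose first column is a subject counter (stripped with y[:, 1:]). A small sketch of writing a compatible file; shapes, values and the preprocessing suffix are placeholders, the real arrays come from the preparator:

    import numpy as np

    preprocessing = 'min'                                            # placeholder for config['preprocessing']
    X = np.random.rand(100, 500, 129).astype(np.float32)             # (samples, timesamples, channels)
    subject_ids = np.arange(100, dtype=np.float32).reshape(-1, 1)    # column 0: subject counter
    targets = np.random.rand(100, 1).astype(np.float32)              # remaining columns: labels used for training
    y = np.hstack([subject_ids, targets])

    # The file name must match the pattern load_data builds from config['task'] and config['preprocessing']
    np.savez('./data/prepared/LR_task_with_antisaccade_synchronised_' + preprocessing + '.npz',
             EEG=X, labels=y)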
@@ -21,6 +21,7 @@ def main():
     start_time = time.time()
     # Load the data
+    """
     try:
         trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
         #trainX = np.load("./data/precomputed/calibration_task/all_fix_sacc_fix_X.npy")
@@ -31,9 +32,10 @@ def main():
     if config['model'] == 'eegnet' or config['model'] == 'eegnet_cluster':
         trainX = np.transpose(trainX, (0, 2, 1))
         logging.info(trainX.shape)
+    """
     # Create trainer that runs ensemble of models
-    trainer = Trainer(trainX, trainY)
+    trainer = Trainer()
     trainer.train()
     # select_best_model()
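
The (now commented-out) eegnet branch in main transposed the input before training. Given the (samples, timesamples, channels) layout documented in load_data, the transpose produces a channels-first array; a quick shape check with placeholder dimensions:

    import numpy as np

    X = np.zeros((32, 500, 129))        # placeholder sizes: (samples, timesamples, channels)
    X_t = np.transpose(X, (0, 2, 1))    # -> (samples, channels, timesamples), the layout fed to the eegnet models here
    print(X_t.shape)                    # (32, 129, 500)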
@@ -21,7 +21,7 @@ def log_config():
         logging.info("Using fixations between {} ms and {} ms, 1 sample equals 2ms".format((2 * config['min_fixation']), (2 * config['max_fixation'])))
         logging.info("Using saccades between {} ms and {} ms, 1 sample equals 2ms".format((2 * config['min_saccade']), (2 * config['max_saccade'])))
     else:
-        logging.info("Running the saccade classification (prosaccade) task")
+        logging.info("Running the LR task")
     logging.info("------------------------------------------------------------------------------------")
     logging.info("Model hyperparameters chosen in config.py:")
     logging.info("Learning rate: {}".format(config['learning_rate']))
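
The keys read across this commit imply a config dictionary roughly like the one below; only the key names come from the code above, every value is an example. The fixation and saccade bounds are stored in samples and logged in milliseconds at 2 ms per sample, so e.g. min_fixation = 50 is reported as 100 ms.

    # Hypothetical config.py contents -- key names from the diff, values are placeholders only
    config = {
        'framework': 'pytorch',         # 'tensorflow' or 'pytorch'
        'model': 'eegnet',
        'ensemble': 5,                  # number of models in the ensemble
        'task': 'prosaccade-clf',       # 'prosaccade-clf', 'gaze-reg' or 'angle-reg'
        'dataset': 'calibration_task',
        'preprocessing': 'min',         # suffix of the prepared .npz files
        'data_dir': './data/',
        'learning_rate': 1e-3,
        'min_fixation': 50,             # in samples; logged as 2 * 50 = 100 ms
        'max_fixation': 150,
        'min_saccade': 10,
        'max_saccade': 30,
    }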