
Commit 1dffa282 authored by Lukas

try load and store models

parent 5432df99
@@ -2,6 +2,8 @@
 Common interface to fit and predict with torch and tf ensembles
 """
+from torch_models.CNN.CNN import CNN
 class Ensemble:
     def __init__(self, config, **params):
@@ -12,7 +14,7 @@ class Ensemble:
             self.ensemble = Ensemble_tf('mse', nb_models=2)
         elif self.config['framework'] == 'pytorch':
             from torch_models.Ensemble_torch import Ensemble_torch
-            self.ensemble = Ensemble_torch()
+            self.ensemble = Ensemble_torch(model=CNN, model_type='cnn', loss_fn='mse')
         else:
             raise ValueError("Choose a valid deep learning framework")
......
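The wrapper above now hands a concrete model class to the torch backend instead of constructing Ensemble_torch without arguments. A minimal usage sketch of that dispatch, assuming the wrapper module is importable as shown and exposes fit/predict (the import path and those calls are assumptions, not shown in this diff):

# Sketch only: exercising the framework dispatch from the hunk above.
# The import path and the fit/predict calls are assumptions for illustration.
from Ensemble import Ensemble

config = {'framework': 'pytorch'}   # 'tensorflow' would select Ensemble_tf instead
ensemble = Ensemble(config)         # builds Ensemble_torch(model=CNN, model_type='cnn', loss_fn='mse')
# ensemble.fit(trainX, trainY)      # training data prepared elsewhere
# pred = ensemble.predict(testX)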
@@ -58,8 +58,8 @@ config['input_shape'] = (500,129)
 # Choose framework
 ##################################################################
-#config['framework'] = 'pytorch'
-config['framework'] = 'tensorflow'
+config['framework'] = 'pytorch'
+#config['framework'] = 'tensorflow'
 ##################################################################
 # Choose models
......
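config.py appears to be a plain module-level dictionary, and this commit flips the active framework from TensorFlow to PyTorch. For orientation, a sketch of the config keys the rest of this diff reads; every value below except 'input_shape' and 'framework' is a placeholder, not taken from the repository:

# Sketch only: config keys referenced elsewhere in this diff.
config = dict()
config['input_shape'] = (500, 129)       # shown in the hunk context above
config['framework'] = 'pytorch'          # switched in this commit
config['task'] = 'prosaccade-clf'        # or 'gaze-reg', 'angle-reg', 'amplitude-reg'
config['dataset'] = 'calibration_task'   # read by load_data() for the direction task
config['preprocessing'] = 'max'          # placeholder; selects the prepared .npz file
config['model_dir'] = './runs/current'   # placeholder; predict() scans <model_dir>/best_models/
config['ensemble'] = 5                   # placeholder; divisor used when averaging predictions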
@@ -39,24 +39,24 @@ def main():
     logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
-def load_data(self):
+def load_data():
     """
     Load the data depending on preprocessing, task and dataset chosen in config.py
     Data has to be prepared with the preparator and stored in data/prepared
     """
     try:
-        if self.config['task'] == 'prosaccade-clf':
+        if config['task'] == 'prosaccade-clf':
             logging.info("Loading LR task data")
-            data = np.load('./data/prepared/LR_task_with_antisaccade_synchronised_' + self.config['preprocessing'] + '.npz')
-        elif self.config['task'] == 'gaze-reg':
+            data = np.load('./data/prepared/LR_task_with_antisaccade_synchronised_' + config['preprocessing'] + '.npz')
+        elif config['task'] == 'gaze-reg':
             logging.info("Loading coordinate task data")
-            data = np.load('./data/prepared/Position_task_with_dots_synchronised_' + self.config['preprocessing'] + '.npz')
-        elif self.config['task'] == 'angle-reg' or self.config['task'] == 'amplitude-reg':
-            logging.info(f"Loading {self.config['task']} regression data")
-            if self.config['dataset'] == 'calibration_task':
-                data = np.load('./data/prepared/Direction_task_with_dots_synchronised_' + self.config['preprocessing'] + '.npz')
+            data = np.load('./data/prepared/Position_task_with_dots_synchronised_' + config['preprocessing'] + '.npz')
+        elif config['task'] == 'angle-reg' or config['task'] == 'amplitude-reg':
+            logging.info(f"Loading {config['task']} regression data")
+            if config['dataset'] == 'calibration_task':
+                data = np.load('./data/prepared/Direction_task_with_dots_synchronised_' + config['preprocessing'] + '.npz')
             else:
-                data = np.load('./data/prepared/Direction_task_with_processing_speed_synchronised_' + self.config['preprocessing'] + '.npz')
+                data = np.load('./data/prepared/Direction_task_with_processing_speed_synchronised_' + config['preprocessing'] + '.npz')
         else:
             raise ValueError("Choose a valid task in config.py")
     except:
......
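load_data() is turned from a method that read self.config into a module-level function that reads the imported config, so callers no longer need an instance. A minimal calling sketch, assuming main.py imports config at module level and that the prepared .npz archives expose their arrays under names such as 'EEG' and 'labels' (hypothetical keys, defined by the preparator rather than by this diff):

# Sketch only: reading one of the prepared archives directly.
# The array names 'EEG' and 'labels' are assumptions for illustration.
import numpy as np
from config import config

archive = np.load('./data/prepared/LR_task_with_antisaccade_synchronised_'
                  + config['preprocessing'] + '.npz')
trainX, trainY = archive['EEG'], archive['labels']   # hypothetical keys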
@@ -46,7 +46,7 @@ class BaseNet(nn.Module):
     """
     BaseNet class for ConvNet and EEGnet to inherit common functionality
     """
-    def __init__(self, input_shape, output_shape, loss, epochs=50, verbose=True, model_number=0, batch_size=64):
+    def __init__(self, input_shape, output_shape=2, loss='mse', epochs=50, verbose=True, model_number=0, batch_size=64):
         """
         Initialize common variables of models based on BaseNet, e.g. ConvNet or EEGNET
         Create the common output layer dependent on the task to run
......
@@ -15,7 +15,7 @@ class ConvNet(ABC, BaseNet):
     Inherit from this class and only implement _module() and _get_nb_features_output_layer() methods
     Modules are then stacked in the forward() pass of the model
     """
-    def __init__(self, input_shape, kernel_size=32, nb_filters=32, verbose=True, batch_size=64,
+    def __init__(self, input_shape, output_shape=2, loss='mse', kernel_size=32, nb_filters=32, verbose=True, batch_size=64,
                  use_residual=False, depth=6, epochs=2, preprocessing = False, model_number=0):
         """
         We define the layers of the network in the __init__ function
......
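Both constructors above gain defaults (output_shape=2, loss='mse'), which lets Ensemble_torch instantiate a model class such as CNN with little more than the input shape. A sketch of the shortened construction this enables, assuming CNN subclasses ConvNet and forwards its keyword arguments (the exact call inside Ensemble_torch is not visible in this diff):

# Sketch only: the defaults added above make this short call possible.
# Assumes CNN subclasses ConvNet and passes keyword arguments through.
from torch_models.CNN.CNN import CNN

model = CNN(input_shape=(500, 129))   # output_shape=2 and loss='mse' are now defaults

Keeping the models constructible from a bare class reference is what the new Ensemble_torch(model=CNN, ...) call relies on.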
from sklearn.utils import validation
from torch import tensor
from torch.utils import data
from config import config
import logging
@@ -51,6 +52,9 @@ class Ensemble_torch:
             logging.info('Finished fitting model number {}/{} ...'.format(i+1, self.nb_models))
     def predict(self, testX):
+        tensor_X = torch.tensor(testX)
+        if torch.cuda.is_available():
+            tensor_X.cuda()
         path = config['model_dir'] + '/best_models/'
         for i, file in enumerate(os.listdir(path)):
             # These 3 lines are needed for torch to load and predict
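The comment above refers to three lines, collapsed out of this hunk, that load a stored model and run it. A generic sketch of that pattern, assuming the files under best_models/ were written with torch.save(model, path) (whole-model serialization); if only state_dicts were saved, the model would first have to be constructed and then filled via load_state_dict. Note also that Tensor.cuda() returns a copy rather than moving the tensor in place, so the added tensor_X.cuda() call has no effect unless its result is assigned.

# Sketch only: a typical load-and-predict sequence for a stored torch model.
# 'path' and 'file' refer to the loop variables above; whole-model
# serialization with torch.save(model, ...) is an assumption.
import torch

model = torch.load(path + file)   # reload the stored model object
model.eval()                      # switch off dropout / batch-norm updates
with torch.no_grad():             # no autograd graph needed for inference
    pred = model(tensor_X)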
@@ -62,6 +66,9 @@
                 pred = model(testX)
             else:
                 pred += model(testX)
+        pred = pred.detach().numpy()
+        print(pred )
         return pred / config['ensemble'] # TODO: this might have to be rounded for majority decision in LR task
......
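predict() sums the outputs of the stored models, converts the result to a NumPy array, and divides by config['ensemble']; the TODO notes that this average may need rounding to a majority decision for the LR classification task. One way that could look, assuming the LR task yields one probability-like score per sample (an assumption, not shown in this diff):

# Sketch only: turning the averaged ensemble output into a majority decision
# for the LR classification task; threshold and shapes are assumptions.
import numpy as np

avg = ensemble.predict(testX)            # averaged scores as returned above
if config['task'] == 'prosaccade-clf':
    decision = (avg > 0.5).astype(int)   # hypothetical majority vote
else:
    decision = avg                       # regression tasks keep the raw average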