Commit d2ec40d2 authored by Lukas Wolf's avatar Lukas Wolf
Browse files

New custom loss and accuracy computation for prediction history across ensembles in torch.

parent af594851
......@@ -6,18 +6,22 @@ import torch.nn as nn
import os
import numpy as np
from torch_models.torch_utils.custom_losses import angle_loss, cross_entropy, mse
from torch_models.torch_utils.custom_losses import angle_loss
from torch_models.torch_utils.utils import compute_loss, compute_accuracy
from sklearn.model_selection import train_test_split
from torch_models.torch_utils.dataloader import create_dataloader
from torch_models.CNN.CNN import CNN
from torch_models.InceptionTime.InceptionTime import Inception
#from torch_models.EEGNet.eegNet import EEGNet
class Ensemble_torch:
"""
The Ensemble is a model itself, which contains a number of models that are averaged on prediction.
The Ensemble is a model itself, which contains a number of models whose prediction is averaged.
Default: nb_models of the model_type model
Optional: Initialize with a model list, create a more versatile Ensemble model
"""
def __init__(self, nb_models=1, model_type='cnn', model_list=[]):
"""
nb_models: Number of models to run in the ensemble
......@@ -29,55 +33,65 @@ class Ensemble_torch:
self.model_list = model_list
self.models = []
self._build_ensemble_model()
def load_models(self, path_to_models):
    """Load pre-trained models from *path_to_models* into the ensemble.

    NOTE(review): stub — the loop walks the directory but performs no work yet.
    The inline comments below describe the intended behavior: populate
    self.models, adjust self.nb_models, and rebuild self.model_list.

    path_to_models: directory containing saved model files (one per entry).
    """
    for model_name in os.listdir(path_to_models):
        # load models and save them in self.models
        # Adapt nb_models
        # Create model_list as string list for completeness
        pass
# Set the loss function
if config['task'] == 'prosaccade-clf':
self.loss_fn = nn.BCELoss()
elif config['task'] == 'angle-reg':
self.loss_fn = angle_loss
elif config['task'] == 'gaze-reg':
self.loss_fn = nn.MSELoss
def run(self, X, y):
def run(self, x, y):
"""
Fit all the models in the ensemble and save them to the run directory
"""
# Metrics to save over the ensemble
# Metrics to save across the ensemble
loss=[]
accuracy=[]
bce = cross_entropy
# Create a split
x = np.transpose(x, (0, 2, 1)) # (batch_size, samples, channels) to (bs, ch, samples) as torch conv layers want it
X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
# Create dataloaders
train_dataloader = create_dataloader(X_train, y_train, batch_size=config['batch_size'])
test_dataloader = create_dataloader(X_val, y_val, batch_size=config['batch_size'])
# Fit the models
for i in range(len(self.models)):
print("------------------------------------------------------------------------------------")
print('Start training model number {}/{} ...'.format(i+1, self.nb_models))
logging.info("------------------------------------------------------------------------------------")
logging.info('Start training model number {}/{} ...'.format(i+1, self.nb_models))
model = self.models[i]
pred_ensemble = model.fit(X,y)
pred_ensemble = model.fit(train_dataloader, test_dataloader)
# Collect the predictions on the validation sets
if i == 0:
pred = pred_ensemble.predhis
else:
for j, pred_epoch in enumerate(pred_ensemble.predhis):
pred[j] = (np.array(pred[j]) + np.array(pred_epoch))
print('Finished training model number {}/{} ...'.format(i+1, self.nb_models))
for batch, predictions in enumerate(pred_epoch):
pred[j][batch] = pred[j][batch] + predictions
#pred[j] = pred[j] + pred_epoch
logging.info('Finished training model number {}/{} ...'.format(i+1, self.nb_models))
logging.info("------------------------------------------------------------------------------------")
logging.info("Finished ensemble training")
print(f"pred size: {len(pred)}")
print(f"pred type: {type(pred)}")
print(f"pred[0] size: {len(pred[0])}")
print(f"pred[0] type: {type(pred[0])}")
print(f"pred[0]0 size: {len(pred[0][0])}")
print(f"pred[0]0 type: {type(pred[0][0])}")
for j, pred_epoch in enumerate(pred_ensemble.predhis):
print(f"predensemble predhis size: {len(pred_ensemble.predhis)}")
for j, pred_epoch in enumerate(pred):
print(f"epoch number {j}")
pred_epoch = (pred_epoch/config['ensemble']).tolist()
# Compute the ensemble loss dependent on the task
if config['task'] == 'prosaccade-clf':
print("append prosaccade loss")
loss.append(bce(pred_ensemble.targets,pred_epoch).numpy())
elif config['task'] == 'angle-reg':
print("append angle loss")
loss.append(angle_loss(pred_ensemble.targets,pred_epoch).numpy())
elif config['task'] == 'gaze-reg':
print("append mse loss")
loss.append(mse(pred_ensemble.targets,pred_epoch).numpy())
else:
raise Exception("Choose valid task in config.py")
print(f"compute loss epoch number {j}")
#pred_epoch = (pred_epoch/config['ensemble']).tolist()
loss.append(compute_loss(loss_fn=self.loss_fn, dataloader=test_dataloader, pred_list=pred_epoch, nb_models=config['ensemble']))
pred_epoch = np.round(pred_epoch,0)
if config['task'] == 'prosaccade-clf':
accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets).reshape(-1)-1)**2))
#accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets).reshape(-1)-1)**2))
accuracy.append(compute_accuracy(dataloader=test_dataloader, pred_list=pred_epoch, nb_models=config['ensemble']))
if config['ensemble']>1:
config['model']+='_ensemble'
......@@ -93,16 +107,8 @@ class Ensemble_torch:
#save_logs(hist, config['model_dir'], config['model'], pytorch = False)
def predict(self, X):
    """
    Predict with all models on the dataset X
    Return the average prediction of all models in self.models

    NOTE(review): this span is a diff artifact — it contains BOTH the removed
    Keras-style implementation (the loop below) and the new stub (the trailing
    TODO + pass). In the committed file only the TODO and pass remain; the
    averaging loop was deleted pending a torch-dataloader reimplementation.
    """
    for i in range(len(self.models)):
        if i == 0:
            # First model seeds the accumulator.
            pred = self.models[i].model.predict(X)
        else:
            # NOTE(review): accessor inconsistency — `.model.predict` above vs
            # `.get_model().predict` here; presumably equivalent, but confirm.
            pred += self.models[i].get_model().predict(X)
    # Average the summed predictions over the ensemble size.
    return pred / len(self.models)
    #TODO: implement for torch dataloader
    pass
def _build_ensemble_model(self):
"""
......@@ -132,6 +138,10 @@ class Ensemble_torch:
pred += self.models[i].model.predict(X)
return pred / len(self.models)
# optional
def load_models(self, path_to_models):
    """Load saved models from *path_to_models* (not implemented).

    NOTE(review): a second, earlier `load_models` stub also appears in this
    diff view — the duplication is a rendering artifact of the scraped commit
    page; only one definition should exist in the committed file. Confirm
    which stub the commit actually keeps.
    """
    pass
def create_model(model_type, model_number):
"""
......@@ -141,6 +151,9 @@ def create_model(model_type, model_number):
model = CNN(input_shape=config['cnn']['input_shape'], kernel_size=64, epochs = config['epochs'],
nb_filters=16, verbose=True, batch_size=64, use_residual=True, depth=12,
model_number=model_number)
elif model_type == 'inception':
model = Inception(input_shape=config['inception']['input_shape'], use_residual=True, model_number=model_number,
kernel_size=64, nb_filters=16, depth=12, bottleneck_size=16, epochs=config['epochs'])
#elif model_type == 'eegnet':
# model = EEGNet(input_shape=(config['eegnet']['samples'], config['eegnet']['channels']),
# model_number=model_number, epochs=config['epochs'])
......@@ -152,9 +165,6 @@ def create_model(model_type, model_number):
elif model_type == 'pyramidal_cnn':
model = PyramidalCNN(input_shape=config['cnn']['input_shape'], epochs=config['epochs'],
model_number=model_number)
elif model_type == 'inception':
model = INCEPTION(input_shape=config['inception']['input_shape'], use_residual=True, model_number=model_number,
kernel_size=64, nb_filters=16, depth=12, bottleneck_size=16, epochs=config['epochs'])
elif model_type == 'xception':
model = XCEPTION(input_shape=config['inception']['input_shape'], use_residual=True, model_number=model_number,
kernel_size=40, nb_filters=64, depth=18, epochs=config['epochs'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.