BaseNetTorch.py 4.45 KB
Newer Older
1
2
3
4
5
import torch
from torch import nn
import numpy as np
from config import config 
import logging 
6
#from torch.utils.tensorboard import SummaryWriter
Lukas Wolf's avatar
Lukas Wolf committed
7
from torch_models.torch_utils.training import train_loop, test_loop
8

Lukas Wolf's avatar
Lukas Wolf committed
9
class Prediction_history:
    """
    Collect predictions of the given validation set after each epoch.

    predhis is a list of lists (one inner list per epoch) of tensors
    (one tensor per batch of the dataloader).
    """

    def __init__(self, dataloader) -> None:
        # Dataloader yielding (x, y) batches; only x is used for prediction.
        self.dataloader = dataloader
        self.predhis = []

    def on_epoch_end(self, model):
        """
        Run the model over the full dataloader and append the per-batch
        predictions for this epoch to self.predhis.

        Inference runs under torch.no_grad(): the original kept the autograd
        graph attached to every stored prediction, so memory grew with each
        epoch's stored tensors.
        """
        y_pred = []
        with torch.no_grad():
            for x, _ in self.dataloader:
                y_pred.append(model(x.float()))
        self.predhis.append(y_pred)
23
24

class BaseNet(nn.Module):
    """
    BaseNet class for ConvNet and EEGnet to inherit common functionality.

    Builds the task-dependent loss function and output layer and provides
    the shared fit() training loop. Subclasses must implement forward()
    and get_nb_features_output_layer().
    """

    def __init__(self, input_shape, epochs=50, verbose=True, model_number=0, batch_size=64):
        """
        Initialize common variables of models based on BaseNet and create
        the common output layer dependent on the task to run.

        Args:
            input_shape: tuple (timesamples, channels) of one input sample.
            epochs: number of training epochs. NOTE(review): fit() currently
                reads config['epochs'] and ignores this value — confirm which
                is authoritative.
            verbose: verbosity flag.
            model_number: index of this model within an ensemble.
            batch_size: training batch size.
        """
        super().__init__()
        self.input_shape = input_shape
        self.epochs = epochs
        self.verbose = verbose
        self.model_number = model_number
        self.batch_size = batch_size
        self.nb_channels = self.input_shape[1]
        self.timesamples = self.input_shape[0]

        # Create loss and output layer depending on the task
        if config['task'] == 'prosaccade_clf':
            # Binary classification: single sigmoid unit trained with BCE
            self.loss_fn = nn.BCELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=1),
                nn.Sigmoid()
            )
        elif config['task'] == 'gaze-reg':
            # Regression of a 2D gaze position with MSE loss
            self.loss_fn = nn.MSELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=2)
            )
        else:  # config['task'] == 'angle-reg'
            # Angle regression with a custom periodic loss
            from torch_models.torch_utils.custom_losses import angle_loss
            self.loss_fn = angle_loss
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=1)
            )

    def forward(self, x):
        """
        Implement a forward pass of the network.

        Abstract: must be overridden by subclasses. Raising here (instead of
        silently returning None via `pass`) surfaces a missing override
        immediately.
        """
        raise NotImplementedError("Subclasses of BaseNet must implement forward()")

    def get_nb_features_output_layer(self):
        """
        Return the number of features that the output layer takes as input.

        Abstract: must be overridden by subclasses. It is called from
        __init__ to size the output layer, so a missing override previously
        crashed inside nn.Linear with in_features=None; raising gives a
        clear error instead.
        """
        raise NotImplementedError(
            "Subclasses of BaseNet must implement get_nb_features_output_layer()")

    def _split_model(self):
        """Abstract hook for subclasses that split the model into parts."""
        pass

    def fit(self, train_dataloader, test_dataloader, subjectID=None):
        """
        Fit the model on the dataset defined by the training dataloader and
        collect per-epoch predictions on the test set.

        Args:
            train_dataloader: dataloader yielding training batches.
            test_dataloader: dataloader used for validation predictions.
            subjectID: unused here; kept for interface compatibility.

        Returns:
            Prediction_history holding the test-set predictions of every epoch.
        """
        logging.info("------------------------------------------------------------------------------------")
        logging.info(f"Fitting model number {self.model_number}")
        # Create the optimizer (parameters() already yields an iterable;
        # wrapping it in list() was redundant)
        optimizer = torch.optim.Adam(self.parameters(), lr=config['learning_rate'])
        # Create a history to track ensemble performance
        prediction_ensemble = Prediction_history(dataloader=test_dataloader)
        # Create a summary writer for logging metrics
        #writer = SummaryWriter(log_dir=config['model_dir']+'/summary_writer')
        # Train the model
        epochs = config['epochs']  # NOTE(review): self.epochs is ignored here — confirm intent
        for t in range(epochs):
            logging.info(f"Epoch {t+1}\n-------------------------------")
            # Run through training and test set
            train_loss = train_loop(train_dataloader, self.float(), self.loss_fn, optimizer)
            test_loss, test_acc = test_loop(test_dataloader, self.float(), self.loss_fn)
            # Add the predictions on the validation set
            prediction_ensemble.on_epoch_end(model=self)
            logging.info("end epoch")
            # Log metrics to the writer
            #writer.add_scalar('Loss/train', train_loss, t)
            #writer.add_scalar('Loss/test', test_loss, t)
            #if config['task'] == 'prosaccade-clf':
                #writer.add_scalar('Accuracy/test', test_acc, t)
        logging.info(f"Finished model number {self.model_number}")
        # Only persist the first ensemble member to avoid redundant checkpoints
        if config['save_models'] and self.model_number == 0:
            ckpt_dir = config['model_dir'] + '/best_models/' + config['model'] + '_nb_{}_'.format(self.model_number) + 'best_model.pth'
            torch.save(self.state_dict(), ckpt_dir)
        return prediction_ensemble