# BaseNetTorch.py
import torch
from torch import nn
import numpy as np
from config import config
import logging

from torch_models.torch_utils.training import train_loop, test_loop
class Prediction_history:
    """
    Collect predictions of the given validation set after each epoch.

    predhis is a list of lists (one for each epoch) of tensors (one for each batch).
    """

    def __init__(self, dataloader) -> None:
        """
        Args:
            dataloader: iterable yielding (x, y) batches of the validation set
        """
        self.dataloader = dataloader
        self.predhis = []

    def on_epoch_end(self, model):
        """
        Run the model over the validation dataloader and append this epoch's
        per-batch predictions to predhis.

        Inference runs under torch.no_grad(): the original kept each
        prediction's autograd graph alive inside predhis, growing memory
        every epoch for no benefit.
        """
        with torch.no_grad():
            y_pred = [model(x.float()) for x, _ in self.dataloader]
        self.predhis.append(y_pred)
class BaseNet(nn.Module):
    """
    BaseNet class for ConvNet and EEGnet to inherit common functionality.

    Subclasses must implement forward() and get_nb_features_output_layer();
    the latter is called here in __init__ to size the task-specific output
    layer.
    """

    def __init__(self, input_shape, epochs=50, verbose=True, model_number=0, batch_size=64):
        """
        Initialize common variables of models based on BaseNet and create the
        common output layer dependent on the task to run.

        Args:
            input_shape: (timesamples, channels) of the input data
            epochs: number of epochs (NOTE(review): fit() currently reads
                config['epochs'] instead of this value — confirm intent)
            verbose: verbosity flag stored for subclasses
            model_number: index of this model within the ensemble
            batch_size: training batch size
        """
        super().__init__()
        self.input_shape = input_shape
        self.epochs = epochs
        self.verbose = verbose
        self.model_number = model_number
        self.batch_size = batch_size
        self.nb_channels = self.input_shape[1]
        self.timesamples = self.input_shape[0]

        # Create loss and output layer depending on the configured task
        if config['task'] == 'prosaccade_clf':
            # Binary classification: single sigmoid unit + BCE loss
            self.loss_fn = nn.BCELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=1),
                nn.Sigmoid()
            )
        elif config['task'] == 'gaze-reg':
            # 2D gaze-position regression: MSE loss
            self.loss_fn = nn.MSELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=2)
            )
        else:  # config['task'] == 'angle-reg'
            # Angle regression uses a custom periodic loss
            from torch_models.torch_utils.custom_losses import angle_loss
            self.loss_fn = angle_loss
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=1)
            )

    # abstract method
    def forward(self, x):
        """
        Implements a forward pass of the network.

        Raises instead of silently returning None (the original `pass` made a
        missing override surface only as a confusing downstream error).
        """
        raise NotImplementedError("Subclasses of BaseNet must implement forward()")

    # abstract method
    def get_nb_features_output_layer(self):
        """
        Return the number of features that the output layer should take as input.

        Raises instead of returning None, which the original did and which
        would otherwise crash nn.Linear(in_features=None) in __init__ with an
        opaque TypeError.
        """
        raise NotImplementedError("Subclasses of BaseNet must implement get_nb_features_output_layer()")

    # abstract method
    def _split_model(self):
        # NOTE(review): contract unknown from this file; left returning None
        # in case callers tolerate the default.
        pass

    def fit(self, train_dataloader, test_dataloader, subjectID=None):
        """
        Fit the model on the dataset defined by data x and labels y.

        Trains for config['epochs'] epochs, evaluating on test_dataloader and
        collecting its predictions after every epoch, then saves the final
        state dict under config['model_dir'].

        Args:
            train_dataloader: dataloader yielding training (x, y) batches
            test_dataloader: dataloader yielding test (x, y) batches
            subjectID: unused here; kept for interface compatibility

        Returns:
            Prediction_history with one prediction list per epoch.
        """
        print("------------------------------------------------------------------------------------")
        print(f"Fitting model number {self.model_number}")

        # Create the optimizer
        optimizer = torch.optim.Adam(list(self.parameters()), lr=config['learning_rate'])
        # Collect test-set predictions after each epoch
        prediction_ensemble = Prediction_history(dataloader=test_dataloader)
        # NOTE(review): config['epochs'] overrides the constructor's epochs
        # argument — confirm this is intended.
        epochs = config['epochs']
        # Cast parameters to float32 once, instead of calling self.float()
        # twice per epoch inside the loop (it mutates in place and returns self).
        model = self.float()
        for t in range(epochs):
            print(f"Epoch {t+1}\n-------------------------------")
            train_loop(train_dataloader, model, self.loss_fn, optimizer)
            test_loop(test_dataloader, model, self.loss_fn)
            prediction_ensemble.on_epoch_end(model=self)
        print(f"Finished model number {self.model_number}")
        # Save model
        ckpt_dir = config['model_dir'] + '/best_models/' + config['model'] + '_nb_{}_'.format(self.model_number) + 'best_model.pth'
        torch.save(self.state_dict(), ckpt_dir)
        return prediction_ensemble