import torch
from torch import nn
import numpy as np
from config import config
import logging

from torch_models.torch_utils.training import train_loop, test_loop

class Prediction_history:
    """
    Collect predictions of the given validation set after each epoch.

    Attributes:
        dataloader: iterable yielding (x, y) batches of the validation set
        predhis: list of per-epoch prediction lists, i.e.
                 [ypred_epoch_1, ypred_epoch_2, ...]
    """

    def __init__(self, dataloader) -> None:
        self.dataloader = dataloader
        self.predhis = []

    def on_epoch_end(self, model):
        """
        Predict the validation set with `model` and append the batch-wise
        predictions to predhis.

        Args:
            model: a callable (nn.Module) mapping a float batch to predictions
        """
        y_pred = []
        # Inference only: without no_grad() every stored prediction keeps its
        # autograd graph alive, so memory grows with each epoch.
        with torch.no_grad():
            for x, _ in self.dataloader:
                y_pred.append(model(x.float()))
        self.predhis.append(y_pred)

class BaseNet(nn.Module):
    """
    BaseNet class for ConvNet and EEGnet to inherit common functionality.

    Subclasses must implement forward(), get_nb_features_output_layer()
    and _split_model(); the abstract stubs here raise NotImplementedError
    so a missing override fails loudly instead of returning None.
    """

    def __init__(self, input_shape, epochs=50, verbose=True, model_number=0, batch_size=64):
        """
        Initialize common variables of models based on BaseNet and create
        the common output layer dependent on the task to run.

        Args:
            input_shape: (timesamples, channels) of one input sample
            epochs: number of training epochs
                (NOTE(review): fit() currently reads config['epochs'] instead,
                so this argument is effectively ignored there — kept as-is for
                backward compatibility)
            verbose: verbosity flag
            model_number: index of this model within an ensemble
            batch_size: training batch size
        """
        super().__init__()
        self.input_shape = input_shape
        self.epochs = epochs
        self.verbose = verbose
        self.model_number = model_number
        self.batch_size = batch_size
        self.nb_channels = self.input_shape[1]
        self.timesamples = self.input_shape[0]

        # Create loss and output layer depending on the configured task
        if config['task'] == 'prosaccade_clf':
            # Binary classification: single sigmoid unit + BCE loss
            self.loss_fn = nn.BCELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=1),
                nn.Sigmoid()
            )
        elif config['task'] == 'gaze-reg':
            # 2D gaze-position regression with MSE loss
            self.loss_fn = nn.MSELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=2)
            )
        else:  # config['task'] == 'angle-reg'
            # Angle regression with a custom periodic loss; imported lazily so
            # the other tasks do not require the custom_losses module.
            from torch_models.torch_utils.custom_losses import angle_loss
            self.loss_fn = angle_loss
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=1)
            )

    # abstract method
    def forward(self, x):
        """
        Implements a forward pass of the network.
        Must be overridden by subclasses.
        """
        raise NotImplementedError("Subclasses of BaseNet must implement forward()")

    # abstract method
    def get_nb_features_output_layer(self):
        """
        Return the number of features that the output layer should take as input.
        Must be overridden by subclasses (called from __init__).
        """
        raise NotImplementedError("Subclasses of BaseNet must implement get_nb_features_output_layer()")

    # abstract method
    def _split_model(self):
        """Split the model into parts; must be overridden by subclasses."""
        raise NotImplementedError("Subclasses of BaseNet must implement _split_model()")

    def fit(self, train_dataloader, test_dataloader, subjectID=None):
        """
        Fit the model on the training set and evaluate on the test set after
        each epoch, collecting the per-epoch test predictions.

        Args:
            train_dataloader: dataloader yielding training (x, y) batches
            test_dataloader: dataloader yielding validation (x, y) batches
            subjectID: unused; kept for interface compatibility

        Returns:
            Prediction_history holding the test-set predictions of every epoch.
        """
        print("------------------------------------------------------------------------------------")
        print(f"Fitting model number {self.model_number}")

        # Create the optimizer
        optimizer = torch.optim.Adam(list(self.parameters()), lr=config['learning_rate'])
        # Collect test-set predictions after each epoch
        prediction_ensemble = Prediction_history(dataloader=test_dataloader)
        # Train the model
        # NOTE(review): config['epochs'] is used rather than self.epochs — the
        # constructor argument is ignored here; confirm which one is intended.
        epochs = config['epochs']
        for t in range(epochs):
            logging.info(f"Epoch {t+1}\n-------------------------------")
            train_loop(train_dataloader, self.float(), self.loss_fn, optimizer)
            test_loop(test_dataloader, self.float(), self.loss_fn)
            prediction_ensemble.on_epoch_end(model=self)
        print(f"Finished model number {self.model_number}")
        # Save model weights
        ckpt_dir = config['model_dir'] + '/best_models/' + config['model'] + '_nb_{}_'.format(self.model_number) + 'best_model.pth'
        torch.save(self.state_dict(), ckpt_dir)
        return prediction_ensemble