BaseNetTorch.py 4.35 KB
Newer Older
1
2
3
4
5
import torch
from torch import nn
import numpy as np
from config import config 
import logging 
Lukas Wolf's avatar
Lukas Wolf committed
6
from torch.utils.tensorboard import SummaryWriter
7

Lukas Wolf's avatar
Lukas Wolf committed
8
from torch_models.torch_utils.training import train_loop, test_loop
9

Lukas Wolf's avatar
Lukas Wolf committed
10
class Prediction_history:
    """
    Collect predictions of the given validation set after each epoch.

    predhis is a list of lists (one for each epoch) of tensors (one for each batch).
    """

    def __init__(self, dataloader) -> None:
        # Iterable yielding (x, y) batches; only x is used for prediction
        self.dataloader = dataloader
        # One entry per epoch; each entry is a list of per-batch prediction tensors
        self.predhis = []

    def on_epoch_end(self, model):
        """
        Run model over the stored dataloader and append this epoch's
        per-batch predictions to predhis.
        """
        # Fix: wrap inference in no_grad(). The original stored predictions
        # with live autograd graphs attached, so predhis kept every epoch's
        # computation graph alive — an unbounded memory leak over training.
        with torch.no_grad():
            y_pred = [model(x.float()) for x, _ in self.dataloader]
        self.predhis.append(y_pred)
24
25

class BaseNet(nn.Module):
    """
    BaseNet class for ConvNet and EEGnet to inherit common functionality.

    Subclasses must implement forward(), get_nb_features_output_layer()
    and _split_model(); this base class builds the task-specific output
    layer / loss and provides the training loop in fit().
    """

    def __init__(self, input_shape, epochs=50, verbose=True, model_number=0, batch_size=64):
        """
        Initialize common variables of models based on BaseNet.
        Create the common output layer dependent on the task to run.

        input_shape: (timesamples, channels) of the EEG input.
        epochs/verbose/model_number/batch_size: stored as-is for subclasses.
        """
        super().__init__()
        self.input_shape = input_shape
        self.epochs = epochs
        self.verbose = verbose
        self.model_number = model_number
        self.batch_size = batch_size
        self.nb_channels = self.input_shape[1]
        self.timesamples = self.input_shape[0]

        # Create output layer and loss depending on the configured task
        if config['task'] == 'prosaccade_clf':
            # Binary classification: single sigmoid unit + BCE loss
            self.loss_fn = nn.BCELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=1),
                nn.Sigmoid()
            )
        elif config['task'] == 'gaze-reg':
            # 2D gaze-position regression
            self.loss_fn = nn.MSELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=2)
            )
        else:  # config['task'] == 'angle-reg'
            # Angle regression with a custom periodic loss; imported lazily
            # so the other tasks do not require the module
            from torch_models.torch_utils.custom_losses import angle_loss
            self.loss_fn = angle_loss
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=1)
            )

    # abstract method
    def forward(self, x):
        """
        Implements a forward pass of the network.
        """
        pass

    # abstract method
    def get_nb_features_output_layer(self):
        """
        Return the number of features that the output layer should take as input.
        """
        pass

    # abstract method
    def _split_model(self):
        pass

    def fit(self, train_dataloader, test_dataloader, subjectID=None):
        """
        Fit the model on the dataset given by the train dataloader, evaluating
        on the test dataloader after every epoch.

        Returns the Prediction_history holding the per-epoch predictions on
        the test set. Logs train/test loss (and test accuracy for the
        classification task) to a TensorBoard SummaryWriter, and optionally
        saves the final state_dict if config['save_model'] is set.
        """
        print("------------------------------------------------------------------------------------")
        print(f"Fitting model number {self.model_number}")

        # Create the optimizer
        optimizer = torch.optim.Adam(list(self.parameters()), lr=config['learning_rate'])
        # Create a history to track ensemble performance
        prediction_ensemble = Prediction_history(dataloader=test_dataloader)
        # Create a summary writer for logging metrics
        writer = SummaryWriter(log_dir=config['model_dir'] + '/summary_writer')
        # Train the model
        # NOTE(review): uses config['epochs'] rather than self.epochs — the
        # constructor argument is effectively ignored here; confirm intended.
        epochs = config['epochs']
        for t in range(epochs):
            print(f"Epoch {t+1}\n-------------------------------")
            # Run through training and test set
            train_loss = train_loop(train_dataloader, self.float(), self.loss_fn, optimizer)
            test_loss, test_acc = test_loop(test_dataloader, self.float(), self.loss_fn)
            # Add the predictions on the validation set
            prediction_ensemble.on_epoch_end(model=self)
            # Log metrics to the writer
            writer.add_scalar('Loss/train', train_loss, t)
            writer.add_scalar('Loss/test', test_loss, t)
            # Fix: was 'prosaccade-clf' (hyphen) while __init__ checks
            # 'prosaccade_clf' (underscore), so accuracy was never logged.
            if config['task'] == 'prosaccade_clf':
                writer.add_scalar('Accuracy/test', test_acc, t)
        # Fix: close the writer so pending events are flushed to disk
        writer.close()
        print(f"Finished model number {self.model_number}")
        if config['save_model']:
            ckpt_dir = config['model_dir'] + '/best_models/' + config['model'] + '_nb_{}_'.format(self.model_number) + 'best_model.pth'
            torch.save(self.state_dict(), ckpt_dir)
        return prediction_ensemble