BaseNetTorch.py 4.46 KB
Newer Older
1
2
3
4
5
import torch
from torch import nn
import numpy as np
from config import config 
import logging 
Lukas Wolf's avatar
Lukas Wolf committed
6
from sklearn.model_selection import train_test_split
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from torch_models.torch_utils.dataloader import create_dataloader
from torch_models.torch_utils.training import train_loop, test_loop


class Prediction_history():
    """
    Prediction history for pytorch model ensembles.

    Holds the validation set as tensors and records the model's predictions
    on it after every epoch, so the ensemble logic can aggregate them later.
    """
    def __init__(self, X_val, y_val) -> None:
        # Convert once to tensors.
        self.X_val = torch.tensor(X_val)
        self.y_val = torch.tensor(y_val)
        # BUGFIX: Tensor.cuda() returns a NEW tensor instead of mutating in
        # place, so the original bare `self.X_val.cuda()` calls were no-ops
        # (and raised on CPU-only machines). Reassign, and only when a GPU
        # actually exists.
        if torch.cuda.is_available():
            self.X_val = self.X_val.cuda()
            self.y_val = self.y_val.cuda()
        # One entry per epoch, each a numpy array of predictions.
        self.predhis = []

    def on_epoch_end(self, model):
        """Run `model` on the stored validation set and append the result."""
        # Inference only: no_grad avoids building the autograd graph.
        with torch.no_grad():
            y_pred = model(self.X_val.float())
        # BUGFIX: .numpy() fails on CUDA tensors and on tensors that require
        # grad; detach().cpu() makes the conversion safe in both cases.
        # Transform back to numpy array because the ensemble handles it that way.
        self.predhis.append(y_pred.detach().cpu().numpy())

28
29

class BaseNet(nn.Module):
    """
    BaseNet class for ConvNet and EEGnet to inherit common functionality.

    Builds the task-specific loss function and output layer in __init__ and
    provides a generic fit() loop; subclasses are expected to override
    forward(), _split_model() and _build_model().
    """
    def __init__(self, input_shape, epochs=50, verbose=True, model_number=0, batch_size=64):
        # input_shape is indexed as (timesamples, channels) below — the
        # constructor assumes that layout. TODO confirm against callers.
        super().__init__()
        self.input_shape = input_shape
        # NOTE(review): epochs and batch_size are stored here but fit() reads
        # config['epochs'] / config['batch_size'] instead — the constructor
        # arguments are currently dead. Consider unifying.
        self.epochs = epochs
        self.verbose = verbose
        self.model_number = model_number
        self.batch_size = batch_size
        self.nb_channels = self.input_shape[1]
        self.timesamples = self.input_shape[0]

        # Set the number of features that are passed throught the internal network (except input layer)
        if config['model'] == 'cnn':
            self.num_features = 16
        elif config['model'] == 'deepeye':
            self.num_features = 164 
        else: # all other current models have tensors of width 64
            self.num_features = 64
        
        # Compute the number of features for the output layer
        # 4*2*7 is the flattened feature size EEGNet produces; presumably tied
        # to its fixed pooling schedule — verify if the architecture changes.
        eegNet_out = 4*2*7
        convNet_out = self.num_features * self.timesamples

        # Create output layer depending on task and
        # pick the matching loss: BCE for binary classification, MSE for 2-D
        # gaze regression, custom angle_loss otherwise (angle regression).
        if config['task'] == 'prosaccade_clf':
            self.loss_fn = nn.BCELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=eegNet_out if config['model'] == 'eegnet' else convNet_out, out_features=1) 
            )
        elif config['task'] == 'gaze-reg':
            self.loss_fn = nn.MSELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=eegNet_out if config['model'] == 'eegnet' else convNet_out, out_features=2) 
            )
        else: #elif config['task'] == 'angle-reg':
            # Imported lazily to avoid a hard dependency for the other tasks.
            from torch_models.torch_utils.custom_losses import angle_loss
            self.loss_fn = angle_loss
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=eegNet_out if config['model'] == 'eegnet' else convNet_out, out_features=1) 
            )

    # abstract method 
    def forward(self, x):
        pass

    def get_model(self):
        """Return the model itself (kept for API parity with the keras side)."""
        return self 

    # abstract method
    def _split_model(self):
        pass
        
    # abstract method
    def _build_model(self):
        pass

    def fit(self, x, y, subjectID=None):
        """
        Train this model on (x, y) and return a Prediction_history with the
        per-epoch validation predictions.

        x is expected as (batch, samples, channels) and is transposed to the
        (batch, channels, samples) layout torch conv layers want. subjectID is
        currently unused.
        """
        logging.info("------------------------------------------------------------------------------------")
        logging.info(f"Fitting model number {self.model_number}")
        # Create a split
        x = np.transpose(x, (0, 2, 1)) # (batch_size, samples, channels) to (bs, ch, samples) as torch conv layers want it 
        # Fixed random_state keeps the split reproducible across ensemble members.
        X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
        # Create dataloaders
        train_dataloader = create_dataloader(X_train, y_train, batch_size=config['batch_size'])
        test_dataloader = create_dataloader(X_val, y_val, batch_size=config['batch_size'])
        # Create the optimizer
        optimizer = torch.optim.Adam(list(self.parameters()), lr=config['learning_rate'])
        # Create history and log 
        prediction_ensemble = Prediction_history(X_val, y_val)
        # Train the model 
        epochs = config['epochs']             
        for t in range(epochs):
            logging.info(f"Epoch {t+1}\n-------------------------------")
            # self.float() casts parameters to float32 in place and returns self.
            train_loop(train_dataloader, self.float(), self.loss_fn, optimizer)
            test_loop(test_dataloader, self.float(), self.loss_fn)
            # Record this epoch's validation predictions for the ensemble.
            prediction_ensemble.on_epoch_end(model=self)
        logging.info(f"Finished model number {self.model_number}")
        # Save model 
        # NOTE(review): this pickles the entire module via torch.save into a
        # file named *.h5 — the extension is misleading (not HDF5/keras), and
        # saving state_dict() is the usually-recommended form. Confirm what
        # the loading side expects before changing.
        ckpt_dir = config['model_dir'] + '/best_models/' + config['model'] + '_nb_{}_'.format(self.model_number) + 'best_model.h5'
        torch.save(self, ckpt_dir)

        return prediction_ensemble