BaseNetTorch.py
import torch
from torch import nn
import numpy as np
from config import config 
import logging
from torch_models.torch_utils.training import train_loop, test_loop
from torch_models.torch_utils.utils import get_gpu_memory, timing_decorator
import psutil
from memory_profiler import profile
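
# Note: this module reads the following keys from the shared `config` dict,
# all referenced further below: 'task', 'learning_rate', 'epochs',
# 'save_models', 'model_dir' and 'model'.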


class Prediction_history:
    """
    Collect predictions of the given validation set after each epoch 
    predhis is a list of lists (one for each epoch) of tensors (one for each batch)
    """
    def __init__(self, dataloader, device, model) -> None:
        self.dataloader = dataloader
        self.predhis = []
        self.device = device
        self.model = model

    #@timing_decorator
    #@profile
    def on_epoch_end(self):
        with torch.no_grad():
            y_pred = []
            for x, y in self.dataloader:
                # Move batch to GPU
                if torch.cuda.is_available():
                    x = x.cuda()
                    y = y.cuda()
                y_pred.append(self.model(x)) 
                # Remove batch from GPU 
                del x
                del y 
                #torch.cuda.empty_cache()
            self.predhis.append(y_pred)
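
# Usage sketch (illustrative only): after each call to on_epoch_end(),
# predhis[e][b] holds the prediction tensor of batch b after epoch e, e.g.
#
#   history = Prediction_history(dataloader=val_loader, device=device, model=model)
#   history.on_epoch_end()
#   epoch_preds = torch.cat(history.predhis[-1], dim=0)  # all predictions of the last epoch
#
# where val_loader, device and model are assumed to be provided by the caller.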

class BaseNet(nn.Module):
    """
    BaseNet class for ConvNet and EEGnet to inherit common functionality 
    """
    def __init__(self, input_shape, epochs=50, verbose=True, model_number=0, batch_size=64):
        """
        Initialize common variables of models based on BaseNet and
        create the common output layer depending on the task to run.
        """
        super().__init__()
        self.input_shape = input_shape
        self.epochs = epochs
        self.verbose = verbose
        self.model_number = model_number
        self.batch_size = batch_size
        self.nb_channels = self.input_shape[1]
        self.timesamples = self.input_shape[0]
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Create the loss function and the output layer depending on the task
        if config['task'] == 'prosaccade_clf':
            self.loss_fn = nn.BCELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=1),
                nn.Sigmoid()
            )
        elif config['task'] == 'gaze-reg':
            self.loss_fn = nn.MSELoss()
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=2) 
            )
        else:  # config['task'] == 'angle-reg'
            from torch_models.torch_utils.custom_losses import angle_loss
            self.loss_fn = angle_loss
            self.output_layer = nn.Sequential(
                nn.Linear(in_features=self.get_nb_features_output_layer(), out_features=1) 
            )

    # abstract method 
    def forward(self, x):
        """
        Implements a forward pass of the network 
        """
        raise NotImplementedError
    
    # abstract method 
    def get_nb_features_output_layer(self):
        """
        Return the number of features that the output layer should take as input
        """
        raise NotImplementedError

    # abstract method
    def _split_model(self):
        pass    
    
    @profile
    @timing_decorator
    def fit(self, train_dataloader, test_dataloader, subjectID=None):
        """
        Fit the model on the dataset defined by data x and labels y 
        """
        logging.info("------------------------------------------------------------------------------------")
        logging.info(f"Fitting model number {self.model_number}")
        # Move the model to GPU
        if torch.cuda.is_available():
            self.cuda()
            logging.info(f"Model moved to cuda")
        # Create the optimizer
        optimizer = torch.optim.Adam(list(self.parameters()), lr=config['learning_rate'])
        # Create a history to track ensemble performance 
        prediction_ensemble = Prediction_history(dataloader=test_dataloader, device=self.device, model=self)
        # Train the model 
        epochs = config['epochs']             
        for t in range(epochs):
            logging.info(f"Epoch {t+1}\n-------------------------------")
            print(f"Start EPOCH: Free GPU memory: {get_gpu_memory()}")
            print(f"memory {psutil.virtual_memory()}") 
            # Run through training and test set 
            print(self.device)
            train_loop(train_dataloader, self.float(), self.loss_fn, optimizer, self.device)
            print(f"Free GPU mem after train loop: {get_gpu_memory()}")
            print(f"memory {psutil.virtual_memory()}") 
            test_loop(test_dataloader, self.float(), self.loss_fn, self.device)
            print("Free GPU mem after test loop:")
            print(f"memory {psutil.virtual_memory()}") 
            # Add the predictions on the validation set
            prediction_ensemble.on_epoch_end()
            print("Free GPU mem after prediction hist:")
            print(f"memory {psutil.virtual_memory()}") 
        # Done with training this model 
        logging.info(f"Finished model number {self.model_number}")
        if config['save_models'] and self.model_number==0:
            ckpt_dir = config['model_dir'] + '/best_models/' + config['model'] + '_nb_{}_'.format(self.model_number) + 'best_model.pth'
            torch.save(self.state_dict(), ckpt_dir)
        return prediction_ensemble
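

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline): how a
# concrete model such as ConvNet or EEGNet is expected to subclass BaseNet.
# `DummyNet`, the layer sizes, the (timesamples, channels) input shape, and
# `train_loader`/`val_loader` are assumptions made for this example.
#
# class DummyNet(BaseNet):
#     def __init__(self, input_shape, **kwargs):
#         super().__init__(input_shape, **kwargs)  # builds loss_fn and output_layer
#         self.flatten = nn.Flatten()
#         self.hidden = nn.Linear(self.timesamples * self.nb_channels, 64)
#
#     def get_nb_features_output_layer(self):
#         # Number of features fed into the common output layer
#         return 64
#
#     def forward(self, x):
#         x = torch.relu(self.hidden(self.flatten(x)))
#         return self.output_layer(x)
#
# prediction_history = DummyNet(input_shape=(500, 129)).fit(train_loader, val_loader)
# ---------------------------------------------------------------------------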