import torch
from torch import nn
import numpy as np
from config import config 
import logging 
from sklearn.model_selection import train_test_split
from torch_models.utils.dataloader import create_dataloader
from torch_models.utils.training import train_loop, test_loop

class BaseNet(nn.Module):
    """
    Base class for ConvNet and EEGNet to inherit common functionality.

    Subclasses are expected to override the abstract methods `forward`,
    `_split_model` and `_build_model`; `fit` provides a shared training
    loop driven by the global `config` dict.
    """

    def __init__(self, epochs=50, verbose=True, model_number=0):
        """
        Args:
            epochs (int): number of training epochs stored on the instance.
                NOTE(review): `fit` currently trains for config['epochs'],
                not this value — confirm which one is authoritative.
            verbose (bool): if True, print the module structure on creation.
            model_number (int): index distinguishing models in an ensemble.
        """
        super().__init__()
        self.epochs = epochs
        self.verbose = verbose
        self.model_number = model_number

        if self.verbose:
            print(self)  # nn.Module.__repr__ shows the layer structure

        # Select GPU when available, otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logging.info("Using {} device".format(self.device))

    # abstract method: subclasses implement the forward pass
    def forward(self, x):
        pass

    def get_model(self):
        """Return the underlying torch module (the instance itself)."""
        return self

    # abstract method: subclasses may split the model into parts
    def _split_model(self):
        pass

    # abstract method: subclasses construct their layers here
    def _build_model(self):
        pass

    def fit(self, x, y, subjectID=None):
        """
        Train the model on (x, y) and save a checkpoint afterwards.

        Args:
            x: input samples (indexable; x[0].shape is logged).
            y: targets aligned with x.
            subjectID: unused; kept for interface compatibility with callers.

        Raises:
            ValueError: if config['task'] is not a recognized task name.
        """
        # BUGFIX: instances have no __name__ attribute; use the class name.
        logging.info(f"Fitting model {type(self).__name__}, model number {self.model_number}")
        # Hold out 20% of the data for validation (fixed seed for reproducibility).
        X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
        # Create dataloaders for the two splits.
        train_dataloader = create_dataloader(X_train, y_train, batch_size=config['batch_size'])
        test_dataloader = create_dataloader(X_val, y_val, batch_size=config['batch_size'])
        logging.info(f"Created train dataloader with x shape{x[0].shape} and y shape {y[0].shape}")

        # Choose the loss function depending on the task.
        if config['task'] == 'prosaccade-clf':
            loss_fn = nn.BCELoss()
        elif config['task'] == 'gaze-reg':
            loss_fn = nn.MSELoss()
        elif config['task'] == 'angle-reg':
            from torch_models.utils.custom_losses import angle_loss
            # BUGFIX: original imported angle_loss but never assigned loss_fn,
            # causing a NameError in the training loop below.
            loss_fn = angle_loss
        else:
            # BUGFIX: an unknown task previously fell through and crashed
            # later with an opaque NameError on loss_fn.
            raise ValueError(f"Unknown task in config: {config['task']}")

        # Create the optimizer (parameters() is already an iterable).
        optimizer = torch.optim.Adam(self.parameters(), lr=config['learning_rate'])

        # Train the model.
        epochs = config['epochs']
        for t in range(epochs):
            logging.info(f"Epoch {t+1}\n-------------------------------")
            train_loop(train_dataloader, self, loss_fn, optimizer)
            test_loop(test_dataloader, self, loss_fn)
        logging.info(f"Finished model {type(self).__name__}, model number {self.model_number}")

        # BUGFIX: self.__str__ is a bound method object; concatenating it with
        # a str raised TypeError. Use the class name for the checkpoint path.
        ckpt_dir = config['model_dir'] + '/best_models/' + type(self).__name__ + '_nb_{}_'.format(self.model_number) + 'best_model.h5'
        torch.save(self, ckpt_dir)