training.py 2.94 KB
Newer Older
Lukas Wolf's avatar
Lukas Wolf committed
1
2
3
import logging
from config import config
import torch
Lukas Wolf's avatar
Lukas Wolf committed
4
from torch import nn 
Lukas Wolf's avatar
Lukas Wolf committed
5
6
7
8
from torch_models.torch_utils.utils import get_gpu_memory
from torch_models.torch_utils.utils import timing_decorator
from memory_profiler import profile
#import torch.profiler
9

Lukas Wolf's avatar
Lukas Wolf committed
10
#@timing_decorator
Lukas Wolf's avatar
Lukas Wolf committed
11
#@profile
Lukas Wolf's avatar
Lukas Wolf committed
12
def train_loop(dataloader, model, loss_fn, optimizer):
    """
    Performs one epoch of training the model through the dataset stored in dataloader,
    using the given loss_fn and optimizer.

    Returns a tuple (avg_loss, accuracy) for the epoch, to be tracked by the caller:
    avg_loss is the per-sample average of loss_fn over the epoch; accuracy is only
    computed when config['task'] == 'prosaccade-clf' and is -1 otherwise.

    Raises ValueError if the dataloader's dataset is empty.
    """
    size = len(dataloader.dataset)
    if size == 0:
        raise ValueError("train_loop received an empty dataset")
    # Loop invariants: evaluate once instead of on every batch
    use_cuda = torch.cuda.is_available()
    is_clf = config['task'] == 'prosaccade-clf'
    # Make sure dropout / batch-norm layers are in training mode; a preceding
    # validation pass may have left the model in eval mode.
    model.train()
    training_loss, correct = 0.0, 0
    for X, y in dataloader:
        # Move tensors to GPU
        if use_cuda:
            X = X.cuda()
            y = y.cuda()
        batch_size = X.size(0)
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate the batch *sum* of the loss so that dividing by the dataset
        # size below gives a true per-sample average (also correct for a smaller
        # last batch). Assumes loss_fn uses reduction='mean' (the PyTorch
        # default) -- TODO confirm for any custom loss.
        training_loss += loss.item() * batch_size
        if is_clf:
            # NOTE(review): assumes pred and y have the same shape; e.g. (N, 1)
            # vs (N,) would broadcast to (N, N) and inflate the count -- confirm.
            pred = (pred > 0.5).float()
            correct += (pred == y).float().sum().item()
        # Remove batch from gpu
        del X
        del y
        torch.cuda.empty_cache()

    loss = training_loss / size
    logging.info(f"Avg training loss: {loss:>7f}")
    if is_clf:
        accuracy = correct / size
        logging.info(f"Avg training accuracy {accuracy:>8f}")
        return float(loss), float(accuracy)
    return float(loss), -1
51

Lukas Wolf's avatar
Lukas Wolf committed
52
#@timing_decorator
Lukas Wolf's avatar
Lukas Wolf committed
53
#@profile
54
def validation_loop(dataloader, model, loss_fn):
    """
    Performs one prediction run through the dataset stored in dataloader
    (without gradient tracking) and logs the resulting metrics.

    Returns a tuple (avg_loss, accuracy): avg_loss is the per-sample average of
    loss_fn over the dataset (can be used for early stopping); accuracy is only
    computed when config['task'] == 'prosaccade-clf' and is -1 otherwise.

    The model is run in eval mode; its previous train/eval mode is restored
    before returning, even if an exception is raised.

    Raises ValueError if the dataloader's dataset is empty.
    """
    size = len(dataloader.dataset)
    if size == 0:
        raise ValueError("validation_loop received an empty dataset")
    # Loop invariants: evaluate once instead of on every batch
    use_cuda = torch.cuda.is_available()
    is_clf = config['task'] == 'prosaccade-clf'
    val_loss, correct = 0.0, 0
    # Disable dropout / use running batch-norm statistics during evaluation,
    # then restore whatever mode the caller had the model in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, y in dataloader:
                # Move tensors to GPU
                if use_cuda:
                    X = X.cuda()
                    y = y.cuda()
                batch_size = X.size(0)
                # Predict
                pred = model(X)
                # Accumulate the batch *sum* of the loss so dividing by the
                # dataset size gives a per-sample average. Assumes loss_fn uses
                # reduction='mean' (the PyTorch default) -- TODO confirm.
                val_loss += loss_fn(pred, y).item() * batch_size
                if is_clf:
                    # NOTE(review): assumes pred and y have the same shape -- confirm.
                    pred = (pred > 0.5).float()
                    correct += (pred == y).float().sum().item()
                # Remove batch from gpu
                del X
                del y
                torch.cuda.empty_cache()
    finally:
        model.train(was_training)

    loss = val_loss / size
    logging.info(f"Avg validation loss: {loss:>8f}")
    if is_clf:
        accuracy = correct / size
        logging.info(f"Avg validation accuracy {accuracy:>8f}")
        return float(loss), float(accuracy)
    return float(loss), -1 # Can be used for early stopping
Lukas Wolf's avatar
Lukas Wolf committed
90