import logging

import torch
from torch import nn

from config import config
from torch_models.torch_utils.utils import get_gpu_memory, timing_decorator
from memory_profiler import profile
#import torch.profiler


#@timing_decorator
#@profile
def train_loop(dataloader, model, loss_fn, optimizer):
    """
    Performs one epoch of training the model on the dataset stored in dataloader,
    using the given loss_fn and optimizer.
    Returns the training loss of the epoch (and the accuracy for classification tasks)
    so that the caller can track them.
    """
    #get_gpu_memory()
    #print(torch.cuda.memory_summary())
    size = len(dataloader.dataset)
    training_loss, correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        # Move tensors to GPU
        if torch.cuda.is_available():
            X = X.cuda()
            y = y.cuda()
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate metrics
        training_loss += loss.item()
        if config['task'] == 'prosaccade-clf':
            pred = (pred > 0.5).float()
            correct += (pred == y).float().sum()
        # Remove batch from GPU
        del X
        del y
        torch.cuda.empty_cache()

    loss = training_loss / size
    logging.info(f"Avg training loss: {loss:>7f}")
    if config['task'] == 'prosaccade-clf':
        accuracy = correct / size
        logging.info(f"Avg training accuracy {accuracy:>8f}")
        return float(loss), float(accuracy)
    return float(loss), -1


#@timing_decorator
#@profile
def validation_loop(dataloader, model, loss_fn):
    """
    Performs one prediction run through the validation set stored in dataloader.
    Logs the loss computed from the predictions pred and the labels y, and returns it
    (together with the accuracy for classification tasks).
    """
    #print("Enter validation:")
    #get_gpu_memory()
    #print(torch.cuda.memory_summary())
    size = len(dataloader.dataset)
    val_loss, correct = 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            # Move tensors to GPU
            if torch.cuda.is_available():
                X = X.cuda()
                y = y.cuda()
            # Predict
            pred = model(X)
            # Compute metrics
            val_loss += loss_fn(pred, y).item()
            if config['task'] == 'prosaccade-clf':
                pred = (pred > 0.5).float()
                correct += (pred == y).float().sum()
            # Remove batch from GPU
            del X
            del y
            torch.cuda.empty_cache()

    loss = val_loss / size
    logging.info(f"Avg validation loss: {loss:>8f}")
    if config['task'] == 'prosaccade-clf':
        accuracy = correct / size
        logging.info(f"Avg validation accuracy {accuracy:>8f}")
        return float(loss), float(accuracy)
    return float(loss), -1  # Can be used for early stopping
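

# The block below is an added usage sketch, not part of the original module: it
# smoke-tests both loops on random synthetic data so the file can be exercised in
# isolation. The tiny linear model, the synthetic tensors, and the Adam
# hyperparameters are illustrative assumptions, not the project's real
# architecture or configuration; the accuracy branch only runs when the imported
# config sets config['task'] == 'prosaccade-clf'.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # 64 random samples with 10 features each and binary labels
    X_demo = torch.randn(64, 10)
    y_demo = (torch.rand(64, 1) > 0.5).float()
    demo_loader = DataLoader(TensorDataset(X_demo, y_demo), batch_size=8)

    # Minimal stand-in model producing probabilities in [0, 1]
    demo_model = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())
    if torch.cuda.is_available():
        demo_model = demo_model.cuda()

    demo_loss_fn = nn.BCELoss()
    demo_optimizer = torch.optim.Adam(demo_model.parameters(), lr=1e-3)

    train_loss, train_acc = train_loop(demo_loader, demo_model, demo_loss_fn, demo_optimizer)
    val_loss, val_acc = validation_loop(demo_loader, demo_model, demo_loss_fn)
    logging.info(f"Demo run finished: train loss {train_loss}, val loss {val_loss}")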