To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

Commit e950002d authored by Lukas Wolf's avatar Lukas Wolf
Browse files

debug early stopping

parent dbdb07bd
......@@ -117,7 +117,7 @@ class BaseNet(nn.Module):
prediction_ensemble = Prediction_history(dataloader=test_dataloader, model=self)
# Train the model
epochs = config['epochs']
metrics = {'train_loss':[], 'val_loss':[], 'train_acc':[], 'val_acc':[]}
metrics = {'train_loss':[], 'val_loss':[], 'train_acc':[], 'val_acc':[]} if config['task'] == 'prosaccade-clf' else {'train_loss':[], 'val_loss':[]}
curr_val_loss = sys.maxsize # For early stopping
patience = 0
for t in range(epochs):
......@@ -149,6 +149,7 @@ class BaseNet(nn.Module):
if val_loss_epoch > curr_val_loss:
patience +=1
else:
curr_val_loss = val_acc_epoch
patience = 0
# Plot and save metrics
plot_metrics(metrics['train_loss'], metrics['val_loss'], output_dir=config['model_dir'], metric='loss', model_number=self.model_number)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment