
Commit d1f5b879 authored by Lukas Wolf

Log which model is trained on which data and preprocessing

parent 4214a425
@@ -31,11 +31,13 @@ def main():
     logging.info(trainX.shape)
     """
     # Create trainer that runs ensemble of models
-    #benchmark_task('prosaccade-clf')
-    benchmark_task('gaze-reg')
-    #benchmark_task('angle-reg')
-    #benchmark_task('amplitude-reg')
+    num_benchmarks = 1
+    for i in range(num_benchmarks):
+        benchmark_task('prosaccade-clf')
+        benchmark_task('gaze-reg')
+        #benchmark_task('angle-reg')
+        #benchmark_task('amplitude-reg')
     # select_best_model()
     # comparison_plot_loss()
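For readers skimming the diff, this is how the benchmarking entry point reads after the change; a minimal runnable sketch in which `benchmark_task` is a labeled stand-in for the real function patched further down in this commit (presumably `num_benchmarks` exists so each enabled benchmark can be repeated several times):

```python
def benchmark_task(task):
    # Stand-in for the real benchmark_task() from this commit;
    # here it only records which task would be run.
    print(f"would benchmark: {task}")

def main():
    num_benchmarks = 1
    for i in range(num_benchmarks):
        benchmark_task('prosaccade-clf')   # saccade classification
        benchmark_task('gaze-reg')         # gaze-position regression
        #benchmark_task('angle-reg')       # still disabled
        #benchmark_task('amplitude-reg')   # still disabled

main()
```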
@@ -68,6 +70,7 @@ def benchmark_task(task):
     log_config()
     start_time = time.time()
     trainer = Trainer(config)
+    logging.info(f"Training {config['model']} on {task} with {config['preprocessing']} preprocessing")
     trainer.train()
     logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
     start_time = time.time()
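The new log line assumes a global `config` dict that carries at least `'model'` and `'preprocessing'` keys. A minimal sketch of what it would emit; the key names come from the hunk above, while the concrete values are hypothetical:

```python
import logging
logging.basicConfig(level=logging.INFO)

# Hypothetical config values; only the key names appear in the diff.
config = {'model': 'CNN', 'preprocessing': 'min'}
task = 'gaze-reg'

logging.info(f"Training {config['model']} on {task} with {config['preprocessing']} preprocessing")
# -> INFO:root:Training CNN on gaze-reg with min preprocessing
```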
@@ -19,6 +19,7 @@ def train_loop(dataloader, model, loss_fn, optimizer):
     #get_gpu_memory()
     #print(torch.cuda.memory_summary())
     size = len(dataloader)
+    num_datapoints = len(dataloader.dataset)
     training_loss, correct = 0, 0
     for batch, (X, y) in enumerate(dataloader):
         # Move tensors to GPU
@@ -47,7 +48,7 @@ def train_loop(dataloader, model, loss_fn, optimizer):
     loss = training_loss / size
     logging.info(f"Avg training loss: {loss:>7f}")
     if config['task'] == 'prosaccade-clf':
-        accuracy = correct / size
+        accuracy = correct / num_datapoints
         logging.info(f"Avg training accuracy {accuracy:>8f}")
         return float(loss), float(accuracy)
     return float(loss), -1
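The accuracy fix matters because `len(dataloader)` counts batches while `len(dataloader.dataset)` counts samples, and `correct` accumulates per-sample hits. A minimal standalone sketch with made-up tensor sizes, showing the two denominators differ by a factor of the batch size:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy data: 100 samples, batch size 10.
ds = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
dl = DataLoader(ds, batch_size=10)

print(len(dl))          # 10  -> number of batches (the old denominator)
print(len(dl.dataset))  # 100 -> number of samples (the new denominator)
# Dividing a per-sample hit count by len(dl) would inflate the
# reported accuracy by roughly the batch size.
```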
@@ -54,8 +54,9 @@ def compute_loss(loss_fn, dataloader, pred_list, nb_models):
         pred = pred_list[batch]
         pred = torch.div(pred, nb_models).float()  # is already on gpu
-        loss += loss_fn(y, pred)
+        loss += loss_fn(pred, y)
+        print(f"acc loss: {loss}, size: {size}, epoch loss: {loss/size}")
     return loss / size

 #@profile
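The argument swap follows PyTorch's `(input, target)` convention for loss modules. For a symmetric loss such as MSE the order is cosmetic, but for classification losses it changes the result or raises outright. A minimal sketch with hypothetical shapes:

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
pred = torch.randn(4, 3)        # logits for 4 samples, 3 classes
y = torch.tensor([0, 2, 1, 0])  # integer class targets

print(loss_fn(pred, y))  # correct: prediction first, target second
# loss_fn(y, pred) would fail here: integer targets are not logits,
# and PyTorch loss modules expect the prediction as the first argument.
```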