Commit f7a97ea0 authored by Lukas Wolf's avatar Lukas Wolf
Browse files

prepare for restructuring

parent 6b5eefdd
"""
Class for ensemble training
TODO: implement
"""
\ No newline at end of file
......@@ -143,11 +143,9 @@ class Regression_ConvNet(ABC):
prediction_ensemble = prediction_history((X_val,y_val))
ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
# Create a callback for tensorboard
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=config['tensorboard_log_dir'], histogram_freq=1)
# Fit model
hist = self.model.fit(X_train, y_train, verbose=verbose, batch_size=self.batch_size, validation_data=(X_val,y_val),
epochs=self.epochs, callbacks=[csv_logger, ckpt, prediction_ensemble, tensorboard_callback])
epochs=self.epochs, callbacks=[csv_logger, ckpt, prediction_ensemble])
# Log how good predictions in x and y directions are
if config['sanity_check'] and not config['data_mode'] == 'fix_sacc_fix':
......
......@@ -29,9 +29,9 @@ TODO: write a proper description how to set the fields in the config
"""
# Choose which task to run
config['task'] = 'prosaccade-clf'
#config['task'] = 'prosaccade-clf'
#config['task'] = 'gaze-reg'
#config['task'] = 'angle-reg'
config['task'] = 'angle-reg'
# Choose from which experiment the dataset to load. Can only be chosen for angle-pred and gaze-reg
if config['task'] != 'prosaccade-clf':
......
......@@ -7,7 +7,6 @@ from scipy import io
import h5py
import logging
import time
import os
# Import the correct functions depending on the task
if config['task'] == 'gaze-reg' or config['task'] == 'angle-reg':
......@@ -57,11 +56,9 @@ def main():
except:
raise Exception("Could not load mat data")
"""
if config['model'] == 'eegnet' or config['model'] == 'eegnet_cluster':
trainX = np.transpose(trainX, (0, 2, 1))
logging.info(trainX.shape)
"""
if config['run'] == 'kerastuner':
tune(trainX,trainY)
......@@ -71,7 +68,7 @@ def main():
raise Exception("Please choose a valid run scheme in config.py")
# Select model and plot results
# select_best_model() TODO: review this for the regression task
# select_best_model()
# comparison_plot_loss()
logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment