Commit a32fc593 authored by Lukas Wolf

pretrained models for ensemble

parent fa871578
import tensorflow as tf
from tensorflow import keras
from config import config
#from utils.utils import *
import logging

from Regression_ConvNet import Regression_ConvNet

class Pretrained_Model(Regression_ConvNet):
    """
    Load a pretrained model and refit it with new data
    """
    def __init__(self, input_shape, epochs=50, verbose=True, batch_size=64, learning_rate=0.01):
        super(Pretrained_Model, self).__init__(input_shape, epochs=epochs, verbose=verbose,
                                               batch_size=batch_size, learning_rate=learning_rate)

    # Override this method to load an existing model instead of building a new one
    def _build_model(self, X=None):
        # Define the paths to the pretrained models that are used
        if config['model'] == 'inception':
            name = 'inception'
            model_dir = "./archive_runs/proc_speed_task/angle_reg/1618902918_inception_angle-reg_630_fix_sacc_fix_processing_speed_task/"
        elif config['model'] == 'cnn':
            name = 'cnn'
            model_dir = "./archive_runs/proc_speed_task/angle_reg/1618901397_cnn_angle-reg_630_fix_sacc_fix_processing_speed_task/"
        else:
            raise Exception("No valid path to a pretrained model")
        # Load the saved best model without compiling it here
        model = keras.models.load_model(model_dir + name + "_best_model.h5", compile=False)
        # TODO: log the config of the pretrained model, not the values in config.py, since they may not be valid (see the sketch after this file)
        return model
\ No newline at end of file
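One hedged way to resolve the TODO above: this commit also starts pickling each run's full config to config['model_dir'] + "/config.p" (see the config.py diff below), so the pretrained run's directory should eventually contain a config.p that can be restored and logged instead of the possibly stale values in config.py. The helper below is a sketch, not part of the repo:

import pickle
import logging

def log_pretrained_config(model_dir):
    # Hypothetical helper; assumes model_dir ends with "/" like the archive paths above
    # and that the pretrained run dumped its config dict to config.p
    with open(model_dir + "config.p", "rb") as f:
        pretrained_config = pickle.load(f)
    for key, value in pretrained_config.items():
        logging.info("Pretrained config: {} = {}".format(key, value))
    return pretrained_config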
"""
This file can be called for retraining manually.
Functionality is now contained in Pretrained.py and in the usual ensemble_regression.py file.
Just set config['pretrained'] = True
"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras import Model
import numpy as np
from utils.utils import *
from utils.losses import angle_loss
from utils.regression_preprocessing import get_fix_sacc_fix_data
from sklearn.model_selection import train_test_split
import os
from tensorflow.keras.callbacks import CSVLogger
import logging
#from config import config
class prediction_history(tf.keras.callbacks.Callback):
def __init__(self,validation_data):
"""
Predhis is list of arrays, one for each epoch, that contains the predicted values. These can be added and divided by the number of ensemble
"""
self.validation_data = validation_data
self.predhis = []
self.targets = validation_data[1]
def on_epoch_end(self, epoch, logs={}):
"""
Each epoch add an array of predictions to predhis
"""
y_pred = self.model.predict(self.validation_data[0])
self.predhis.append(y_pred)
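# Usage sketch for prediction_history (hedged; model, X_train, y_train, X_val, y_val
# are placeholders, not names from this module):
#   pred_cb = prediction_history((X_val, y_val))
#   model.fit(X_train, y_train, validation_data=(X_val, y_val), callbacks=[pred_cb])
#   epoch_preds = pred_cb.predhis[-1]  # predictions on X_val after the final epoch
#   # averaging such arrays over all ensemble members gives the ensemble estimate
#   # described in the docstring above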
class Retrain:
    def __init__(self, X, y, model_dir, loss, learning_rate=1e-4, epochs=100, name='inception'):
        self.path = model_dir
        # TODO: This can be loaded from the config.p pickle dump (the pretrained model may have to create this)
        self.loss = loss
        self.learning_rate = learning_rate
        self.X = X
        self.y = y
        self.epochs = epochs
        self.name = name
        self.retrain_dir = self.path + 'retrain/'
        if not os.path.exists(self.retrain_dir):
            os.makedirs(self.retrain_dir)

    def retrain(self):
        # Start logging
        logging.basicConfig(filename=self.retrain_dir + 'info_log', level=logging.INFO, force=True)
        logging.info('Started the Logging')
        # Load the model
        model = keras.models.load_model(self.path + self.name + "_best_model.h5", compile=False)
        logging.info("Loaded model")
        model.compile(loss=self.loss, optimizer=keras.optimizers.Adam(learning_rate=self.learning_rate))
        logging.info("Compiled model")
        # Log the shapes of the new dataset to fit on
        logging.info("X shape: {}".format(self.X.shape))
        logging.info("y shape: {}".format(self.y.shape))
        # Split data
        X_train, X_val, y_train, y_val = train_test_split(self.X, self.y, test_size=0.2, random_state=42)
        # Define callbacks
        csv_logger = CSVLogger(self.retrain_dir + 'batches.log', append=True, separator=';')
        early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20)
        prediction_ensemble = prediction_history((X_val, y_val))
        ckpt_path = self.retrain_dir + self.name + "_retrained_best_model.h5"
        ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_path, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
        logging.info("Defined callbacks")
        # Retrain
        hist = model.fit(X_train, y_train, verbose=True, batch_size=64, validation_data=(X_val, y_val),
                         epochs=self.epochs, callbacks=[csv_logger, ckpt, early_stop, prediction_ensemble])
        logging.info("Refitted model")
        plot_loss(hist, self.retrain_dir, 'retrained_' + self.name, val=True)
        logging.info("Done with training and plotting.")
        # Predict on validation set
        y_pred = model.predict(X_val)
        loss = angle_loss(y_val, y_pred)
        logging.info("Loss: {}".format(loss))
        np.savetxt(self.retrain_dir + "calib_preds_val.txt", y_pred)
        np.savetxt(self.retrain_dir + "calib_labels_val.txt", y_val)
        logging.info("Done.")

if __name__ == "__main__":
    # Set dir and name of the model to retrain
    model_dir = "./archive_runs/proc_speed_task/angle_reg/1618902918_inception_angle-reg_630_fix_sacc_fix_processing_speed_task/"
    #model_dir = "./archive_runs/proc_speed_task/angle_reg/1618901397_cnn_angle-reg_630_fix_sacc_fix_processing_speed_task/"
    name = 'inception'
    #name = 'cnn'
    # Load the correct data to retrain
    # TODO: config must be set to the calibration task for this to run
    # It would be good to store the data once and then only load it, without calling the regression_preprocessing.py method each time (see the sketch after this file)
    X, y = get_fix_sacc_fix_data(task='calibration_task')
    # Create a trainer object and fit the pretrained model on the existing data
    trainer = Retrain(X=X, y=y, model_dir=model_dir, loss=angle_loss, name=name)
    trainer.retrain()
\ No newline at end of file
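A hedged sketch for the data-caching TODO above: save the preprocessed arrays once with numpy and reload them on later runs instead of calling get_fix_sacc_fix_data each time. The cache_dir path is an assumption, not a path from the repo:

import numpy as np
import os
from utils.regression_preprocessing import get_fix_sacc_fix_data

cache_dir = "./cache/calibration_task/"  # assumed location, not from the repo
if os.path.exists(cache_dir + "X.npy") and os.path.exists(cache_dir + "y.npy"):
    X = np.load(cache_dir + "X.npy")  # reuse the cached arrays
    y = np.load(cache_dir + "y.npy")
else:
    X, y = get_fix_sacc_fix_data(task='calibration_task')  # expensive preprocessing
    os.makedirs(cache_dir, exist_ok=True)
    np.save(cache_dir + "X.npy", X)
    np.save(cache_dir + "y.npy", y)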
@@ -65,9 +65,12 @@ finally used for classification.
Cluster can be set to clustering(), clustering2() or clustering3(), where different clusters based on literature are used.
"""
+# We can use a model pretrained on processing speed task
+config['pretrained'] = True
# Choose model
-#config['model'] = 'cnn'
-config['model'] = 'inception'
+config['model'] = 'cnn'
+#config['model'] = 'inception'
#config['model'] = 'eegnet'
#config['model'] = 'deepeye'
#config['model'] = 'xception'
@@ -77,7 +80,7 @@ config['model'] = 'inception'
# Hyper-parameters and training configuration.
config['learning_rate'] = 1e-5 # fix only: 1e-2, sac only: 1e-3, sac_fix: 1e-3, fix_sac_fix: 1e-4, for inception on angle 1e-5
config['regularization'] = 0 # fix only: 1e-3, sac only: 1e-2, sac_fix: 1, fix_sac_fix: 5, for inception on angle 0
-config['epochs'] = 100
+config['epochs'] = 150
config['batch_size'] = 64
# Choose the kerastuner or an ensemble of models
@@ -106,7 +109,7 @@ if config['task'] != 'prosaccade-clf':
    config['padding'] = 'repeat' # options: zero, repeat
    config['min_fixation'] = 50 # min number of samples for the gaze fixation
    config['max_fixation'] = 150 # max number of samples for the gaze fixation
-    config['fixation_padlength'] = config['max_fixation'] # for the proc speed task
+    config['fixation_padlength'] = 300 # for the proc speed task
    if config['dataset'] == 'calibration_task':
        config['max_fixation'] = 1000
        config['fixation_padlength'] = 300 # cut off the fixation at this length
@@ -146,11 +149,11 @@ if config['task'] != 'prosaccade-clf':
        config['deepeye']['input_shape'] = (config['max_saccade'], 129)
        config['xception']['input_shape'] = (config['max_saccade'], 129)
    elif config['data_mode'] == 'sacc_fix':
-        config['cnn']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
-        config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
-        config['inception']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
-        config['deepeye']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
-        config['xception']['input_shape'] = (config['max_saccade'] + config['max_fixation'], 129)
+        config['cnn']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
+        config['pyramidal_cnn']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
+        config['inception']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
+        config['deepeye']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
+        config['xception']['input_shape'] = (config['max_saccade'] + config['fixation_padlength'], 129)
    # Choose the shapes for angle pred depending on dataset
    elif config['data_mode'] == 'fix_sacc_fix' and config['dataset'] == 'processing_speed_task':
        config['cnn']['input_shape'] = (config['max_saccade'] + 2 * config['fixation_padlength'], 129)
@@ -181,7 +184,8 @@ else:
# Create a unique output directory for this experiment.
timestamp = str(int(time.time()))
-model_folder_name = timestamp if config['model'] == '' else timestamp + "_" + config['model']
+#model_folder_name = timestamp if config['model'] == '' else timestamp + "_" + config['model']
+model_folder_name = timestamp + "_pretrained_" + config['model'] if config['pretrained'] else timestamp + "_" + config['model']
# Modify the model folder name depending on which task runs
model_folder_name += "_" + config['task']
if config['task'] != 'prosaccade-clf':
@@ -200,4 +204,9 @@ if not os.path.exists(config['model_dir']):
    os.makedirs(config['model_dir'])
config['info_log'] = config['model_dir'] + '/' + 'info.log'
-config['batches_log'] = config['model_dir'] + '/' + 'batches.log'
\ No newline at end of file
+config['batches_log'] = config['model_dir'] + '/' + 'batches.log'
+# Save config to model dir
+import pickle
+config_path = config['model_dir'] + "/config.p"
+pickle.dump(config, open(config_path, "wb"))
\ No newline at end of file
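The dump above makes the full hyperparameter set of a run recoverable; a retraining run could then restore, say, the learning rate of the pretrained model rather than re-reading config.py (the TODO noted in Retrain.__init__). A minimal sketch, assuming model_dir points at the pretrained run's directory and ends with "/":

import pickle

with open(model_dir + "config.p", "rb") as f:  # written by the pickle.dump above
    pretrained_config = pickle.load(f)
learning_rate = pretrained_config['learning_rate']  # hyperparameters of the pretrained run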
@@ -12,6 +12,7 @@ from Xception.Regression_xception import Regression_XCEPTION
from DeepEye.Regression_deepeye import Regression_DEEPEYE
from PyramidalCNN.Regression_PyramidalCNN import Regression_PyramidalCNN
from Siamese.Siamese import Siamese_ConvNet
+from Pretrained.Pretrained import Pretrained_Model
#TODO: rewrite the other classes
#from DeepEyeRNN.deepeyeRNN import Classifier_DEEPEYE_RNN
@@ -36,7 +37,10 @@ def run(trainX, trainY):
    for i in range(config['ensemble']):
        print('Beginning model number {}/{} ...'.format(i+1, config['ensemble']))
-        if config['model'] == 'cnn':
+        if config['pretrained']:
+            reg = Pretrained_Model(input_shape=config['inception']['input_shape'], batch_size=config['batch_size'],
+                                   epochs=config['epochs'], learning_rate=config['learning_rate'])
+        elif config['model'] == 'cnn':
            reg = Regression_CNN(input_shape=config['cnn']['input_shape'], kernel_size=64, epochs=config['epochs'],
                                 nb_filters=16, verbose=True, batch_size=config['batch_size'], use_residual=True, depth=10,
                                 learning_rate=config['learning_rate'])
@@ -120,10 +124,10 @@ def run(trainX, trainY):
    hist.history['val_loss'] = np.array(loss)
    # plot ensemble loss
    plot_loss(hist, config['model_dir'], config['model'], val=True)
-    if not config['gaze-reg']:
+    #if config['task'] == 'prosaccade-clf':
        # Plot also accuracy for the saccade classification task
-        hist.history['val_accuracy'] = accuracy
-        plot_acc(hist, config['model_dir'], config['model'], val=True)
+    #    hist.history['val_accuracy'] = accuracy
+    #    plot_acc(hist, config['model_dir'], config['model'], val=True)
    #if config['split']:
    #config['model'] = config['model'] + '_cluster'
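For context, the ensemble val_loss assigned above is derived from the members' epoch-wise predictions recorded by prediction_history; a hedged sketch of one way that aggregation can look (member_callbacks, y_val and the angle_loss call are placeholders/assumptions, not code from run()):

import numpy as np
from utils.losses import angle_loss

# average the epoch-e predictions of all members, then score against the targets
n_epochs = len(member_callbacks[0].predhis)  # member_callbacks: one prediction_history per member
loss = [angle_loss(y_val, np.mean([cb.predhis[e] for cb in member_callbacks], axis=0))
        for e in range(n_epochs)]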
@@ -18,10 +18,10 @@ else:
    raise Exception("Choose valid task in config.py")

def main():
    # Do some logging
    logging.basicConfig(filename=config['info_log'], level=logging.INFO)
    logging.info('Started the Logging')
    start_time = time.time()
    if config['task'] == 'gaze-reg':
        logging.info("Running the gaze regression task")
        logging.info("Using data from {}".format(config['dataset']))
@@ -38,19 +38,18 @@ def main():
        logging.info("Using saccades between {} ms and {} ms, 1 sample equals 2ms".format((2 * config['min_saccade']), (2 * config['max_saccade'])))
    else:
        logging.info("Running the saccade classification (prosaccade) task")
    # Log some parameters to better distinguish between tasks
    logging.info("Loss: {}".format(config['loss']))
    logging.info("Learning rate: {}".format(config['learning_rate']))
    logging.info("Regularization: {}".format(config['regularization']))
    logging.info("Batch size: {}".format(config['batch_size']))
    logging.info("Maximal number of epochs per model: {}".format(config['epochs']))
    if config['run'] == "kerastuner":
        logging.info("Running the keras-tuner")
    else:
        logging.info("Running the ensemble with {} ensemble models".format(config['ensemble']))
    # Load the data
    try:
        trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
    except:
[data file elided: 341 floating-point values in roughly (-pi, pi], one per line; consistent with the validation angle predictions/labels (calib_preds_val.txt / calib_labels_val.txt) written by the retraining script above]