Commit 09213b8e authored by Lukas Wolf's avatar Lukas Wolf
Browse files

tf explain

parent a4e1a499
...@@ -6,6 +6,8 @@ import tensorflow.keras as keras ...@@ -6,6 +6,8 @@ import tensorflow.keras as keras
from config import config from config import config
from tensorflow.keras.callbacks import CSVLogger from tensorflow.keras.callbacks import CSVLogger
import logging import logging
import numpy as np
from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback
class prediction_history(tf.keras.callbacks.Callback): class prediction_history(tf.keras.callbacks.Callback):
def __init__(self,validation_data): def __init__(self,validation_data):
...@@ -114,12 +116,40 @@ class ConvNet(ABC): ...@@ -114,12 +116,40 @@ class ConvNet(ABC):
return self.model return self.model
def fit(self, x, y): def fit(self, x, y):
# Define general callbacks used for training
csv_logger = CSVLogger(config['batches_log'], append=True, separator=';') csv_logger = CSVLogger(config['batches_log'], append=True, separator=';')
# early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=20) early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=20)
ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5' ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_accuracy', save_best_only=True, mode='auto') ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_accuracy', save_best_only=True, mode='auto')
X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42) X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
prediction_ensemble = prediction_history((X_val,y_val)) prediction_ensemble = prediction_history((X_val,y_val))
# Define callbacks for model intepretability experiments
# For integrated gradients, create a validation set with 0 and 1 labels and treat them separately
validation_class_zero = (np.array([el for el, label in zip(X_val, y_val) if np.all(np.argmax(label) == 0)][0:5]))
validation_class_one = (np.array([el for el, label in zip(X_val, y_val) if np.all(np.argmax(label) == 1)][0:5]))
logging.info(validation_class_zero)
logging.info(validation_class_one)
# Define the integrated gradient callbacks for the two classes
#integrated_grad_zero = IntegratedGradientsCallback(validation_class_zero, class_index=0, n_steps=20, output_dir=config['model_dir'] + '/integrated_grad_zero/')
#integrated_grad_one = IntegratedGradientsCallback(validation_class_one, class_index=1, n_steps=20, output_dir=config['model_dir'] + '/integrated_grad_one/')
hist = self.model.fit(X_train, y_train, verbose=2, batch_size=self.batch_size, validation_data=(X_val,y_val), hist = self.model.fit(X_train, y_train, verbose=2, batch_size=self.batch_size, validation_data=(X_val,y_val),
epochs=self.epochs, callbacks=[csv_logger, ckpt, prediction_ensemble]) epochs=self.epochs, callbacks=[csv_logger, ckpt, prediction_ensemble, early_stop""", integrated_grad_one, integrated_grad_zero"""])
return hist, prediction_ensemble return hist, prediction_ensemble
def get_val_sets(X, y, label, num_elements):
"""
Return a small validation set with zero labels for the integrated gradients callback
X_val = []
y_val = []
for i in range(len(X)):
if len(X_val) > num_elements: # collected enough elements
return np.array(X_val), np.array(y_val)\
if y[i][0] != label:
continue
X_val.append(X[i])
y_val.append(y[i])
"""
...@@ -29,8 +29,8 @@ TODO: write a proper description how to set the fields in the config ...@@ -29,8 +29,8 @@ TODO: write a proper description how to set the fields in the config
""" """
# Choose which task to run # Choose which task to run
#config['task'] = 'prosaccade-clf' config['task'] = 'prosaccade-clf'
config['task'] = 'gaze-reg' #config['task'] = 'gaze-reg'
#config['task'] = 'angle-reg' #config['task'] = 'angle-reg'
# Choose from which experiment the dataset to load. Can only be chosen for angle-pred and gaze-reg # Choose from which experiment the dataset to load. Can only be chosen for angle-pred and gaze-reg
...@@ -69,8 +69,8 @@ Cluster can be set to clustering(), clustering2() or clustering3(), where differ ...@@ -69,8 +69,8 @@ Cluster can be set to clustering(), clustering2() or clustering3(), where differ
config['pretrained'] = False config['pretrained'] = False
# Choose model # Choose model
config['model'] = 'cnn' #config['model'] = 'cnn'
#config['model'] = 'inception' config['model'] = 'inception'
#config['model'] = 'eegnet' #config['model'] = 'eegnet'
#config['model'] = 'deepeye' #config['model'] = 'deepeye'
#config['model'] = 'xception' #config['model'] = 'xception'
......
...@@ -12,7 +12,7 @@ if config['task'] == 'gaze-reg' or config['task'] == 'angle-reg': ...@@ -12,7 +12,7 @@ if config['task'] == 'gaze-reg' or config['task'] == 'angle-reg':
from ensemble_regression import run # gaze regression task from ensemble_regression import run # gaze regression task
from kerasTuner_regression import tune from kerasTuner_regression import tune
elif config['task'] == 'prosaccade-clf': elif config['task'] == 'prosaccade-clf':
from ensemble import run # (anti-) saccade task from ensemble import run # (pro-)saccade task
from kerasTuner import tune from kerasTuner import tune
else: else:
raise Exception("Choose valid task in config.py") raise Exception("Choose valid task in config.py")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment