Commit 09213b8e authored by Lukas Wolf's avatar Lukas Wolf
Browse files

tf explain

parent a4e1a499
......@@ -6,6 +6,8 @@ import tensorflow.keras as keras
from config import config
from tensorflow.keras.callbacks import CSVLogger
import logging
import numpy as np
from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback
class prediction_history(tf.keras.callbacks.Callback):
def __init__(self,validation_data):
......@@ -114,12 +116,40 @@ class ConvNet(ABC):
return self.model
def fit(self, x, y):
    """
    Train self.model on (x, y) using an 80/20 train/validation split.

    Args:
        x: input samples.
        y: labels (one-hot encoded — np.argmax is used below to recover
           the class index; TODO confirm against the data pipeline).

    Returns:
        hist: the Keras History object returned by model.fit.
        prediction_ensemble: the prediction_history callback, holding the
            per-epoch predictions on the validation set.
    """
    # General callbacks used for every training run
    csv_logger = CSVLogger(config['batches_log'], append=True, separator=';')
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=20)
    ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
    ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_accuracy',
                                              save_best_only=True, mode='auto')
    X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
    prediction_ensemble = prediction_history((X_val, y_val))

    # Model-interpretability experiments: build a small validation set per
    # class (first 5 samples of each) for the integrated-gradients callbacks.
    validation_class_zero = np.array(
        [el for el, label in zip(X_val, y_val) if np.all(np.argmax(label) == 0)][0:5])
    validation_class_one = np.array(
        [el for el, label in zip(X_val, y_val) if np.all(np.argmax(label) == 1)][0:5])
    logging.info(validation_class_zero)
    logging.info(validation_class_one)
    # Integrated-gradients callbacks are currently disabled; to re-enable,
    # uncomment these two lines and append the callbacks to the list below.
    # integrated_grad_zero = IntegratedGradientsCallback(validation_class_zero, class_index=0, n_steps=20, output_dir=config['model_dir'] + '/integrated_grad_zero/')
    # integrated_grad_one = IntegratedGradientsCallback(validation_class_one, class_index=1, n_steps=20, output_dir=config['model_dir'] + '/integrated_grad_one/')

    hist = self.model.fit(X_train, y_train, verbose=2, batch_size=self.batch_size,
                          validation_data=(X_val, y_val), epochs=self.epochs,
                          callbacks=[csv_logger, ckpt, prediction_ensemble, early_stop])
    return hist, prediction_ensemble
def get_val_sets(X, y, label, num_elements):
    """
    Return a small validation subset for the integrated-gradients callbacks.

    Selects at most ``num_elements`` samples from (X, y) whose first label
    entry equals ``label`` (presumably y is one-hot and y[i][0] distinguishes
    the two classes — confirm against the data pipeline).

    Args:
        X: sequence of input samples.
        y: sequence of label vectors; y[i][0] is compared against ``label``.
        label: class value to select (e.g. 0 or 1).
        num_elements: maximum number of samples to return.

    Returns:
        Tuple (np.ndarray, np.ndarray) of the selected samples and labels.
        Both arrays are empty if no sample matches.
    """
    # NOTE: the previous version had its entire body trapped inside the
    # docstring quotes, so it always returned None; it also never returned
    # when the loop finished before collecting num_elements samples.
    X_val = []
    y_val = []
    for sample, lab in zip(X, y):
        if len(X_val) >= num_elements:  # collected enough elements
            break
        if lab[0] != label:
            continue
        X_val.append(sample)
        y_val.append(lab)
    return np.array(X_val), np.array(y_val)
......@@ -29,8 +29,8 @@ TODO: write a proper description how to set the fields in the config
"""
# Choose which task to run
#config['task'] = 'prosaccade-clf'
config['task'] = 'gaze-reg'
config['task'] = 'prosaccade-clf'
#config['task'] = 'gaze-reg'
#config['task'] = 'angle-reg'
# Choose from which experiment the dataset to load. Can only be chosen for angle-pred and gaze-reg
......@@ -69,8 +69,8 @@ Cluster can be set to clustering(), clustering2() or clustering3(), where differ
config['pretrained'] = False
# Choose model
config['model'] = 'cnn'
#config['model'] = 'inception'
#config['model'] = 'cnn'
config['model'] = 'inception'
#config['model'] = 'eegnet'
#config['model'] = 'deepeye'
#config['model'] = 'xception'
......
......@@ -12,7 +12,7 @@ if config['task'] == 'gaze-reg' or config['task'] == 'angle-reg':
from ensemble_regression import run # gaze regression task
from kerasTuner_regression import tune
elif config['task'] == 'prosaccade-clf':
from ensemble import run # (anti-) saccade task
from ensemble import run # (pro-)saccade task
from kerasTuner import tune
else:
raise Exception("Choose valid task in config.py")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment