Commit 341469b4 authored by Lukas Wolf's avatar Lukas Wolf
Browse files

bug fix

parent d459de8e
......@@ -4,7 +4,7 @@ from sklearn.model_selection import train_test_split
import tensorflow as tf
import tensorflow.keras as keras
from config import config
from keras.callbacks import CSVLogger
from tensorflow.keras.callbacks import CSVLogger
import logging
class prediction_history(tf.keras.callbacks.Callback):
......
......@@ -112,7 +112,7 @@ class Regression_ConvNet(ABC):
# Optional: add some more dense layers here
#gap_layer = tf.keras.layers.Dense(300)(gap_layer)
gap_layer = tf.keras.layers.Dense(50)(gap_layer)
#gap_layer = tf.keras.layers.Dense(50)(gap_layer)
if config['data_mode'] == "fix_sacc_fix":
output_layer = tf.keras.layers.Dense(1, activation='linear')(gap_layer) # only predict the angle in this task
......
......@@ -46,11 +46,11 @@ Cluster can be set to clustering(), clustering2() or clustering3(), where differ
# Hyper-parameters and training configuration.
config['learning_rate'] = 1e-3 # fix only: 1e-2, sac only: 1e-3, sac_fix: 1e-3 , fix_sac_fix: 1e-4
config['regularization'] = 1e-1 # fix only: 1e-3, sac only: 1e-2, sac_fix: 1e-1, fix_sac_fix: 5
config['epochs'] = 5
config['epochs'] = 2
config['batch_size'] = 64
# Choose experiment
config['gaze-reg'] = True # Set to False if you want to run the saccade classification task
config['gaze-reg'] = False # Set to False if you want to run the saccade classification task
config['data-fraction'] = 0.1 # Set to 1.0 if you want to use the whole dataset, experimental feature only for regression task \
# Choose which dataset to run the gaze regression on
......@@ -71,11 +71,11 @@ config['model'] = 'cnn'
# Choose the kerastuner or an ensemble of models
#config['run'] = 'kerastuner'
config['run'] = 'ensemble'
config['ensemble'] = 5 #number of models in the ensemble method
config['ensemble'] = 1 #number of models in the ensemble method
# Other functions that can be chosen optionally
config['sanity_check'] = True
config['plot_filters'] = True
config['sanity_check'] = False
config['plot_filters'] = False
# Set loss automatically depending on the dataset/task to run
if config['data_mode'] == 'fix_sacc_fix':
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -8,8 +8,8 @@ from CNN.CNN import Classifier_CNN
from PyramidalCNN.PyramidalCNN import Classifier_PyramidalCNN
from DeepEye.deepeye import Classifier_DEEPEYE
from DeepEyeRNN.deepeyeRNN import Classifier_DEEPEYE_RNN
from Xception.xception import Classifier_XCEPTION
from InceptionTime.inception import Classifier_INCEPTION
from Xception.Xception import Classifier_XCEPTION
from InceptionTime.Inception import Classifier_INCEPTION
from EEGNet.eegNet import Classifier_EEGNet
import numpy as np
......
......@@ -60,8 +60,9 @@ def run(trainX, trainY):
learning_rate=config['learning_rate'], regularization=config['regularization'])
elif config['model'] == 'pyramidal_cnn':
reg = Regression_PyramidalCNN(input_shape=config['cnn']['input_shape'], epochs=config['epochs'], batch_size=config['batch_size'],
learning_rate=config['learning_rate'], regularization=config['regularization'])
reg = Regression_PyramidalCNN(input_shape=config['cnn']['input_shape'], epochs=config['epochs'], depth=8,
batch_size=config['batch_size'], learning_rate=config['learning_rate'],
regularization=config['regularization'])
elif config['model'] == 'siamese':
reg = Siamese_ConvNet(input_shape=None, use_residual=True, batch_size=config['batch_size'],
......
......@@ -6,8 +6,8 @@ import logging
from CNN.CNN import Classifier_CNN
from DeepEye.deepeye import Classifier_DEEPEYE
from DeepEyeRNN.deepeyeRNN import Classifier_DEEPEYE_RNN
from Xception.xception import Classifier_XCEPTION
from InceptionTime.inception import Classifier_INCEPTION
from Xception.Xception import Classifier_XCEPTION
from InceptionTime.Inception import Classifier_INCEPTION
from EEGNet.eegNet import Classifier_EEGNet
......
......@@ -4,7 +4,7 @@
#SBATCH --output=log/%j.out # where to store the output (%j is the JOBID), subdirectory must exist
#SBATCH --error=log/%j.err # where to store error messages
#SBATCH --gres=gpu:1
#SBATCH --mem=60G
#SBATCH --mem=20G
echo "Running on host: $(hostname)"
echo "In directory: $(pwd)"
......
......@@ -102,8 +102,8 @@ def get_fix_data(verbose=True):
else:
raise Exception("Choose a valid padding scheme in config.py")
x_list.append(x_datapoint)
y_list.append(y_datapoint)
x_list.append([x_datapoint])
y_list.append([y_datapoint])
X = np.asarray(x_list)
# Reshape data and normalize it
......@@ -286,9 +286,9 @@ def get_sacc_fix_data(verbose=True):
y_datapoint = np.array([fix_avg_x, fix_avg_y])
# Append to X and y
x_list.append(x_datapoint)
y_list.append(y_datapoint)
x_list.append([x_datapoint])
y_list.append([y_datapoint])
X = np.asarray(x_list)
X = X[:,:,:129] # Cut off the last 4 columns (time, x, y, pupil size)
# Normalize the data
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment