# Configuration used by the training and evaluation methods.
# We keep it in one place so that the code of the other methods we try stays clean.
import os
import time

from cluster import clustering
config = {}

##################################################################
# Please note that the following fields will be set by our scripts to re-train and re-evaluate our model.

# Directory where experiment results are stored.
config['log_dir'] = './runs/'
# Directory for intermediate results, in case the pre/post-processing scripts generate any.
config['tmp_dir'] = './tmp/'
# Path to the training, validation and test data folders.
# NOTE: the data has to be uploaded to the server as well.
# On the cluster this would be e.g. '/cluster/home/<your_username>/data/'.
config['data_dir'] = './'
# Data location on the cluster (server).
config['data_dir_server'] = '/cluster/project/infk/zigeng/preprocessed2/'
# Path of the project root.
config['root_dir'] = '.'
##################################################################
# You can modify the rest or add new fields as you need.

# Hyper-parameters and training configuration.
# Learning rate used for training.
config['learning_rate'] = 1e-4

# Available models (select one through config['model'] below):
#   cnn:        First try: CNN to predict movement towards left or right
#               (prosaccade) with 1 second of data.
#   inception:  The InceptionTime baseline.
#   eegnet:     The other baseline (EEGNet).
#   deepeye:    Our method.
#   deepeye2, deepeye3, xception, deepeye-lstm: further models that also get
#               a configuration section below.
# Choosing model: one of the model names listed above.
config['model'] = 'cnn'
# When True, the downsampled data file and input shapes (125 instead of 500
# samples) are used.
config['downsampled'] = True
# When True, '_cluster' is appended to the model name (see below).
config['split'] = True
# Result of cluster.clustering(); evaluated eagerly when this module is imported.
# NOTE(review): this runs even when config['split'] is False — confirm intended.
config['cluster'] = clustering()

if config['split']:
    config['model'] += '_cluster'

# .mat data files for X (inputs) and Y (presumably labels / trial info), and the
# name of the variable to read from each file (presumably the MATLAB variable
# name — confirm against IOHelper). Downsampled inputs come from a separate file.
config['trainX_file'] = 'noweEEG.mat' if config['downsampled'] else 'all_EEGprocuesan.mat'
config['trainY_file'] = 'all_trialinfoprosan.mat'
config['trainX_variable'] = 'noweEEG' if config['downsampled'] else 'all_EEGprocuesan'
config['trainY_variable'] = 'all_trialinfoprosan'

# Per-model configuration dictionaries, filled in below.
# CNN
config['cnn'] = {}
# InceptionTime
config['inception'] = {}
# DeepEye
config['deepeye'] = {}
# DeepEye2
config['deepeye2'] = {}
# DeepEye3
config['deepeye3'] = {}
# Xception
config['xception'] = {}
# EEGNet
config['eegnet'] = {}
# LSTM-DeepEye
config['deepeye-lstm'] = {}
# Input dimensions: 129 channels and 125 (downsampled) or 500 samples.
# Private temporaries, deleted at the end of this section.
_n_channels = 129
_n_samples = 125 if config['downsampled'] else 500

# These models take (samples, channels) inputs.
config['cnn']['input_shape'] = (_n_samples, _n_channels)
config['inception']['input_shape'] = (_n_samples, _n_channels)
config['deepeye2']['input_shape'] = (_n_samples, _n_channels)
config['deepeye3']['input_shape'] = (_n_samples, _n_channels)
config['xception']['input_shape'] = (_n_samples, _n_channels)

# EEGNet receives the two dimensions as separate fields.
config['eegnet']['channels'] = _n_channels
config['eegnet']['samples'] = _n_samples
# DeepEye uses the transposed (channels, samples) layout, plus separate fields.
config['deepeye']['input_shape'] = (_n_channels, _n_samples)
config['deepeye']['channels'] = _n_channels
config['deepeye']['samples'] = _n_samples
# LSTM-DeepEye takes (samples, channels) inputs.
config['deepeye-lstm']['input_shape'] = (_n_samples, _n_channels)

del _n_channels, _n_samples

# Create a unique output directory for this experiment, named
# "<unix timestamp>[_<model>][_downsampled]" inside config['log_dir'].
timestamp = str(int(time.time()))
model_folder_name = timestamp if not config['model'] else timestamp + "_" + config['model']

if config['downsampled']:
    model_folder_name += '_downsampled'

config['model_dir'] = os.path.abspath(os.path.join(config['log_dir'], model_folder_name))
# exist_ok=True avoids the race between a separate exists() check and makedirs().
os.makedirs(config['model_dir'], exist_ok=True)

config['info_log'] = os.path.join(config['model_dir'], 'info.log')
config['batches_log'] = os.path.join(config['model_dir'], 'batches.log')


# Previously we loaded the data directly from the server, or used code to merge it.
# Deprecated now: we no longer really use the IOHelper functions that read the
# following settings. They are identical for every model that still carries them.
_legacy_io_fields = {
    'trainX_variable1': "EEGprocue",
    'trainX_variable2': "data",
    'trainX_filename': "EEGprocue",
    'trainY_variable1': "trialinfopro",
    'trainY_variable2': "cues",
    'trainY_filename': "trialinfocuelocked",
}
for _model_name in ('cnn', 'inception', 'deepeye'):
    # update() inserts the keys in the order listed above.
    config[_model_name].update(_legacy_io_fields)
del _legacy_io_fields, _model_name