# Configuration used by the training and evaluation methods.
# It is kept in one place so that the code of the other methods we try stays clean.
import time
import os

config = {}

##################################################################
# Please note that the following fields will be set by our scripts to re-train and re-evaluate our model.

# Where experiment results are stored.
config['log_dir'] = './runs/'
# In case the pre/post-processing scripts generate intermediate results, we may use config['tmp_dir'] to store them.
config['tmp_dir'] = './tmp/'
# Path to training, validation and test data folders.
# PS: Note that we have to upload the data to the server!!!
# Example for a cluster home directory:
# config['data_dir'] = '/cluster/home/your_username/data/'
config['data_dir'] = './'
# Data location on the cluster (presumably used instead of data_dir when running on the server).
config['data_dir_server'] = '/cluster/project/infk/zigeng/preprocessed2/'
# Path of root
config['root_dir'] = '.'

##################################################################
# You can modify the rest or add new fields as you need.

# Hyper-parameters and training configuration.
config['learning_rate'] = 1e-4

# Models:
#   cnn:       First try: CNN to predict movement towards left or right (prosaccade) with 1 second data.
#   inception: The InceptionTime baseline.
#   eegnet:    The other baseline.
#   deepeye:   Our method.

# Choosing model: 'cnn', 'inception', 'eegnet' or 'deepeye'.
config['model'] = 'deepeye'
# Whether to use the downsampled recordings. This selects the data files below;
# the per-model input sizes and the output folder name also depend on it.
config['downsampled'] = True

# Files holding the training inputs (X) and targets (Y), and (presumably) the
# name of the variable to read from each .mat file.
config['trainX_file'] = 'noweEEG.mat' if config['downsampled'] else 'all_EEGprocuesan.mat'
config['trainY_file'] = 'all_trialinfoprosan.mat'
config['trainX_variable'] = 'noweEEG' if config['downsampled'] else 'all_EEGprocuesan'
config['trainY_variable'] = 'all_trialinfoprosan'

# Per-model settings: each model gets its own sub-dictionary.
# CNN
config['cnn'] = {}
# InceptionTime
config['inception'] = {}
# DeepEye
config['deepeye'] = {}
# EEGNet
config['eegnet'] = {}

# Input size of one trial, defined once so the models below cannot disagree:
# 129 channels, and 125 samples when downsampled or 500 otherwise.
_n_channels = 129
_n_samples = 125 if config['downsampled'] else 500

# Note the opposite axis orders: InceptionTime uses (samples, channels),
# DeepEye uses (channels, samples).
config['inception']['input_shape'] = (_n_samples, _n_channels)
config['eegnet']['channels'] = _n_channels
config['eegnet']['samples'] = _n_samples
config['deepeye']['input_shape'] = (_n_channels, _n_samples)
config['deepeye']['channels'] = _n_channels
config['deepeye']['samples'] = _n_samples

# You can set descriptive experiment names or simply set empty string ''.
# config['model_name'] = 'skeleton_code'

# Create an output directory for this experiment:
#   <log_dir>/<unix timestamp>[_<model>][_downsampled]
# NOTE: the timestamp has one-second resolution, so two runs started in the
# same second with the same settings share (and write into) the same folder.
timestamp = str(int(time.time()))
model_folder_name = timestamp if config['model'] == '' else timestamp + "_" + config['model']

if config['downsampled']:
    model_folder_name += '_downsampled'

config['model_dir'] = os.path.abspath(os.path.join(config['log_dir'], model_folder_name))
# exist_ok=True replaces the racy "if not os.path.exists(...): os.makedirs(...)"
# check-then-create: it cannot fail if the folder appears in between.
os.makedirs(config['model_dir'], exist_ok=True)

# Log files inside the experiment folder (platform-correct separators).
config['info_log'] = os.path.join(config['model_dir'], 'info.log')
config['batches_log'] = os.path.join(config['model_dir'], 'batches.log')


# Previously we loaded the data directly from the server, or used code to merge it.
# Deprecated now: the IOHelper functions that read the following configuration are no longer really used.
# The cnn, inception and deepeye models all share the same legacy file/variable names.
_legacy_io_config = {
    'trainX_variable1': "EEGprocue",
    'trainX_variable2': "data",
    'trainX_filename': "EEGprocue",
    'trainY_variable1': "trialinfopro",
    'trainY_variable2': "cues",
    'trainY_filename': "trialinfocuelocked",
}
# Values are immutable strings, so sharing them between the model dicts is safe.
for _legacy_model in ('cnn', 'inception', 'deepeye'):
    config[_legacy_model].update(_legacy_io_config)