# Configuration used by the training and evaluation methods.
# It is kept here, in one place, so that the code of the other methods we try stays clean.

import time
import os
5
6
7
from Clusters.cluster import clustering as clustering
from Clusters.cluster2 import clustering as clustering2
from Clusters.cluster3 import clustering as clustering3
config = {}

##################################################################
# Please note that the following fields will be set by our scripts to re-train and re-evaluate our model.

# Where experiment results are stored.
config['log_dir'] = './runs/'

# Path to the training, validation and test data folders.
# PS: Note that the data has to be uploaded to the server!
# Example for the cluster: config['data_dir'] = '/cluster/home/your_username/data/'
config['data_dir'] = './'

# Path of the project root.
config['root_dir'] = '.'

##################################################################
# You can modify the rest or add new fields as you need.

# Hyper-parameters and training configuration.
config['learning_rate'] = 1e-4

"""
Models that can be chosen for config['model']:

cnn: The simple CNN architecture
inception: The InceptionTime architecture
eegnet: The EEGNet architecture
deepeye: The Deep Eye architecture
xception: The Xception architecture
deepeye-rnn: The DeepEye RNN architecture

If downsampled is set to True, the downsampled data (which have length 125) will be used. Otherwise, the full data with length 500 is used.

If split is set to True, the data will be clustered and each cluster fed to a separate architecture. The extracted features are concatenated
and finally used for classification.
Cluster can be set to clustering(), clustering2() or clustering3(), where different clusters based on the literature are used.
"""

# Choosing the model: one of the names listed in the string above.
config['model'] = 'deepeye'
# True: use the downsampled data (length 125); False: use the full data (length 500).
config['downsampled'] = False
# True: cluster the data and feed each cluster to a separate architecture.
config['split'] = False
# Clustering used when split is True: clustering(), clustering2() or clustering3().
# NOTE(review): clustering() is evaluated even when split is False -- confirm it is cheap / side-effect free.
config['cluster'] = clustering()

# With split enabled the '_cluster' variant of the chosen model is used
# (presumably resolved by the training script -- verify against the caller).
if config['split']:
    config['model'] += '_cluster'

config['ensemble'] = 5  # number of models in the ensemble method

# Training data files and the name of the variable read from each of them;
# the X data depends on whether the downsampled data is used.
config['trainX_file'] = 'noweEEG.mat' if config['downsampled'] else 'all_EEGprocuesan.mat'
config['trainY_file'] = 'all_trialinfoprosan.mat'
config['trainX_variable'] = 'noweEEG' if config['downsampled'] else 'all_EEGprocuesan'
config['trainY_variable'] = 'all_trialinfoprosan'

# One configuration dictionary per architecture: CNN, InceptionTime, DeepEye,
# Xception, EEGNet and DeepEye-RNN (created in this order).
config.update({name: {} for name in ('cnn', 'inception', 'deepeye', 'xception', 'eegnet', 'deepeye-rnn')})

# Time samples per trial (125 for the downsampled data, 500 for the full data)
# and number of channels; defined once so all models stay consistent.
_n_samples = 125 if config['downsampled'] else 500
_n_channels = 129

# The Keras-style models take an input of shape (samples, channels).
_input_shape = (_n_samples, _n_channels)
for _name in ('cnn', 'inception', 'deepeye', 'deepeye-rnn', 'xception'):
    config[_name]['input_shape'] = _input_shape

# EEGNet is configured with channels and samples separately.
config['eegnet']['channels'] = _n_channels
config['eegnet']['samples'] = _n_samples

# Create a unique output directory for this experiment, named
# <timestamp>[_<model>][_cluster][_downsampled][_ensemble].
timestamp = str(int(time.time()))
model_folder_name = timestamp if config['model'] == '' else timestamp + "_" + config['model']

# When split is enabled, config['model'] has already been given the '_cluster'
# suffix; only add it here if it is missing, to avoid '_cluster_cluster'.
if config['split'] and not model_folder_name.endswith('_cluster'):
    model_folder_name += '_cluster'

if config['downsampled']:
    model_folder_name += '_downsampled'

if config['ensemble'] > 1:
    model_folder_name += '_ensemble'

config['model_dir'] = os.path.abspath(os.path.join(config['log_dir'], model_folder_name))
# exist_ok=True replaces the separate existence check, which could race with
# another process creating the same directory.
os.makedirs(config['model_dir'], exist_ok=True)

config['info_log'] = os.path.join(config['model_dir'], 'info.log')
config['batches_log'] = os.path.join(config['model_dir'], 'batches.log')


# Previously the data was loaded directly from the server, or merged with dedicated code.
# Deprecated now: the IOHelper functions that use the following configurations are no longer really used.
# The same legacy file/variable names are shared by the cnn, inception and deepeye models.
_legacy_io_config = {
    'trainX_variable1': "EEGprocue",
    'trainX_variable2': "data",
    'trainX_filename': "EEGprocue",
    'trainY_variable1': "trialinfopro",
    'trainY_variable2': "cues",
    'trainY_filename': "trialinfocuelocked",
}
for _legacy_model in ('cnn', 'inception', 'deepeye'):
    config[_legacy_model].update(_legacy_io_config)