# Configuration used by the training and evaluation methods.
# It is kept in this separate module so that the code of the other methods we try stays clean.
import time
import os
5
6
7
from Clusters.cluster import clustering as clustering
from Clusters.cluster2 import clustering as clustering2
from Clusters.cluster3 import clustering as clustering3
config = dict()

##################################################################
# NOTE: the following fields are set by our scripts when re-training and
# re-evaluating our model.

# Where experiment results are stored.
config['log_dir'] = './runs/'
# Path to the folder holding the training, validation and test data.
# PS: the data has to be uploaded to the server first!
# Example value when running on the cluster:
# config['data_dir'] = '/cluster/home/your_username/data/'
config['data_dir'] = './'
# Path of the project root.
config['root_dir'] = '.'

##################################################################
# You can modify the rest or add new fields as you need.

# Hyper-parameters and training configuration.
config['learning_rate'] = 1e-4
# Parameters that can be chosen:
#
# model:
#   cnn:          the simple CNN architecture
#   inception:    the InceptionTime architecture
#   eegnet:       the EEGNet architecture
#   deepeye:      the DeepEye architecture
#   xception:     the Xception architecture
#   deepeye-rnn:  the DeepEye RNN architecture
#
# downsampled: if True, the downsampled data (length 125) is used;
#   otherwise the full data (length 500) is used.
# split: if True, the data is clustered and each cluster is fed to a
#   separate architecture. The extracted features are concatenated and
#   finally used for classification.
# cluster: clustering(), clustering2() or clustering3(); each uses a
#   different clustering taken from the literature.

# Choosing model
config['model'] = 'inception'
config['downsampled'] = False
config['split'] = False
# NOTE(review): the clustering object is built even when split is False.
config['cluster'] = clustering2()

# Split models get a '_cluster' suffix; config['model'] is also used in the
# experiment folder name, so split runs are distinguishable there.
if config['split']:
    config['model'] = config['model'] + '_cluster'

# Input files and (presumably) the variable names stored inside the .mat
# files; only the X data differs between the downsampled and full variants.
config['trainX_file'] = 'noweEEG.mat' if config['downsampled'] else 'all_EEGprocuesan.mat'
config['trainY_file'] = 'all_trialinfoprosan.mat'
config['trainX_variable'] = 'noweEEG' if config['downsampled'] else 'all_EEGprocuesan'
config['trainY_variable'] = 'all_trialinfoprosan'

# One (initially empty) settings dict per architecture, inserted in this order.
config.update({arch: {} for arch in (
    'cnn',          # CNN
    'inception',    # InceptionTime
    'deepeye',      # DeepEye
    'xception',     # Xception
    'eegnet',       # EEGNet
    'deepeye-rnn',  # DeepEye-RNN
)})

# Input shape is (samples, channels): 125 samples for the downsampled data,
# 500 for the full data, always 129 channels.
for _arch in ('cnn', 'inception', 'deepeye', 'deepeye-rnn', 'xception'):
    config[_arch]['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
del _arch  # keep the loop variable out of the module namespace

# EEGNet receives the channel and sample counts as separate fields.
config['eegnet']['channels'] = 129
config['eegnet']['samples'] = 125 if config['downsampled'] else 500

# Create a unique output directory for this experiment, named
# <unix timestamp>[_<model>][_downsampled] inside config['log_dir'].
timestamp = str(int(time.time()))
model_folder_name = timestamp if config['model'] == '' else timestamp + "_" + config['model']

if config['downsampled']:
    model_folder_name += '_downsampled'

config['model_dir'] = os.path.abspath(os.path.join(config['log_dir'], model_folder_name))
# exist_ok=True replaces the racy "exists? then makedirs" pair: another run
# could create the folder in between and makedirs would then raise.
# NOTE: the timestamp has one-second resolution, so runs of the same model
# started within the same second share a folder.
os.makedirs(config['model_dir'], exist_ok=True)

config['info_log'] = os.path.join(config['model_dir'], 'info.log')
config['batches_log'] = os.path.join(config['model_dir'], 'batches.log')


# Legacy settings: the data used to be loaded directly from the server, or
# merged by our own code. Deprecated -- the IOHelper functions that read the
# following keys are not really used anymore.
for _arch in ('cnn', 'inception', 'deepeye'):
    config[_arch].update(
        trainX_variable1="EEGprocue",
        trainX_variable2="data",
        trainX_filename="EEGprocue",
        trainY_variable1="trialinfopro",
        trainY_variable2="cues",
        trainY_filename="trialinfocuelocked",
    )
del _arch  # keep the loop variable out of the module namespace