main.py 3.18 KB
Newer Older
1
from config import config
2
import time
3
4
from CNN import CNN
from utils import IOHelper
okiss's avatar
okiss committed
5

6
from DeepEye import deepEye, RNNdeep, deepEye2, deepEye3, Xception
7
8
from InceptionTime import inception
from EEGNet import eegNet
9
import numpy as np
Ard Kastrati's avatar
Ard Kastrati committed
10
import logging
okiss's avatar
okiss committed
11
12
import scipy
import h5py
13

14
def _load_train_data(n_samples=20):
    """Load the training arrays from ``trainX.mat`` / ``trainY.mat`` in the CWD.

    ``trainX`` is reshaped to ``(-1, 500, 129)``. Both arrays are then cut to
    their first ``n_samples`` entries along axis 0.

    NOTE(review): the hard-coded file names and the 20-sample cut look like a
    debugging leftover standing in for
    ``IOHelper.get_mat_data(config['data_dir'], verbose=True)``. Confirm
    before training on the full dataset.

    Args:
        n_samples: number of leading samples to keep (default 20, as before).

    Returns:
        A ``(trainX, trainY)`` tuple of NumPy arrays.
    """
    # `import scipy` alone does not load the `io` submodule on older SciPy
    # releases, so `scipy.io.loadmat` could fail with AttributeError.
    import scipy.io

    trainX = scipy.io.loadmat('trainX.mat')['trainX'].reshape(-1, 500, 129)[:n_samples, ...]
    trainY = scipy.io.loadmat('trainY.mat')['trainY'][:n_samples]
    return trainX, trainY


def main():
    """Load the training data and run the model selected by ``config['model']``.

    ``config['model']`` must be one of the base names below, optionally with
    a ``_cluster`` suffix (both spellings select the same runner). If the name
    is not recognised, an error is logged and no model runs. The total
    runtime is logged to ``config['info_log']`` either way.
    """
    logging.basicConfig(filename=config['info_log'], level=logging.INFO)
    logging.info('Started the Logging')
    start_time = time.time()

    trainX, trainY = _load_train_data()
    print(np.shape(trainX))

    # base model name -> (label for the start-up message, runner,
    #                     whether the runner receives trainX with axes 1 and 2 swapped)
    # The lambdas preserve each original call exactly (CNN.run takes
    # positional arguments; the other runners take trainX=/trainY= keywords).
    runners = {
        'cnn': ("CNN", lambda x, y: CNN.run(x, y), False),
        'inception': ("InceptionTime", lambda x, y: inception.run(trainX=x, trainY=y), False),
        'eegnet': ("EEGNet", lambda x, y: eegNet.run(trainX=x, trainY=y), True),
        'deepeye': ("DeepEye", lambda x, y: deepEye.run(trainX=x, trainY=y), True),
        'deepeye2': ("DeepEye2", lambda x, y: deepEye2.run(trainX=x, trainY=y), False),
        'deepeye3': ("DeepEye3", lambda x, y: deepEye3.run(trainX=x, trainY=y), False),
        'xception': ("XceptionTime", lambda x, y: Xception.run(trainX=x, trainY=y), False),
        'deepeye-lstm': ("deepeye-lstm", lambda x, y: RNNdeep.run(trainX=x, trainY=y), False),
    }
    # Accept both "<name>" and "<name>_cluster" for every model.
    dispatch = {
        key: spec
        for name, spec in runners.items()
        for key in (name, name + '_cluster')
    }

    spec = dispatch.get(config['model'])
    if spec is None:
        # A misconfiguration is a failure, so log it at ERROR level, not INFO.
        logging.error('Cannot start the program. Please choose one model in the config.py file')
    else:
        label, run, transpose = spec
        logging.info(
            "Started running %s. If you want to run other methods please choose another model in the config.py file.",
            label,
        )
        model_x = trainX
        if transpose:
            model_x = np.transpose(trainX, (0, 2, 1))
            logging.info("Transposed trainX shape: %s", model_x.shape)
        run(model_x, trainY)

    logging.info("--- Runtime: %s seconds ---", time.time() - start_time)
    logging.info('Finished Logging')
67

68
if __name__ == "__main__":
    # Entry point: run the pipeline only when executed as a script,
    # not when this module is imported.
    main()