# ensemble.py: train an ensemble of classifiers and evaluate their averaged predictions.
import tensorflow as tf
from config import config
from utils.utils import *
import logging

from CNN.CNN import Classifier_CNN
from PyramidalCNN.PyramidalCNN import Classifier_PyramidalCNN
from DeepEye.deepeye import Classifier_DEEPEYE
from DeepEyeRNN.deepeyeRNN import Classifier_DEEPEYE_RNN
from Xception.xception import Classifier_XCEPTION
from InceptionTime.inception import Classifier_INCEPTION
from EEGNet.eegNet import Classifier_EEGNet
import numpy as np



def _make_classifier(model_name):
    """Build one untrained classifier for ``model_name`` with the hyper-parameters fixed here.

    Raises:
        ValueError: if ``model_name`` is not one of the supported model names.
    """
    # Lambdas defer both the config lookups and the (expensive) model construction
    # until the selected model is actually requested.
    factories = {
        'deepeye': lambda: Classifier_DEEPEYE(input_shape=config['deepeye']['input_shape']),
        'cnn': lambda: Classifier_CNN(input_shape=config['cnn']['input_shape']),
        # NOTE(review): reuses the 'cnn' input shape; confirm this is intended.
        'pyramidal_cnn': lambda: Classifier_PyramidalCNN(input_shape=config['cnn']['input_shape'], epochs=50),
        'eegnet': lambda: Classifier_EEGNet(dropoutRate=0.5, kernLength=250, F1=16,
                                            D=4, F2=256, norm_rate=0.5, dropoutType='Dropout',
                                            epochs=50),
        'inception': lambda: Classifier_INCEPTION(input_shape=config['inception']['input_shape'],
                                                  use_residual=True, kernel_size=64, nb_filters=16,
                                                  depth=12, bottleneck_size=16, epochs=50),
        # NOTE(review): reuses the 'inception' input shape; confirm this is intended.
        'xception': lambda: Classifier_XCEPTION(input_shape=config['inception']['input_shape'],
                                                use_residual=True, kernel_size=40, nb_filters=64,
                                                depth=18, epochs=50),
        'deepeye-rnn': lambda: Classifier_DEEPEYE_RNN(input_shape=config['deepeye-rnn']['input_shape']),
    }
    factory = factories.get(model_name)
    if factory is None:
        logging.error('Cannot start the program. Please choose one model in the config.py file')
        raise ValueError("Unknown model {!r}; expected one of {}".format(model_name, sorted(factories)))
    return factory()


def _accumulate_epoch_predictions(pred_sums, pred_counts, member_preds):
    """Add one ensemble member's per-epoch predictions into running per-epoch sums (in place).

    ``pred_sums[j]`` holds the element-wise sum of all epoch-``j`` predictions seen so far and
    ``pred_counts[j]`` the number of members that contributed to it. Both lists grow when a
    member reports more epochs than any earlier member, so members with different epoch counts
    are still averaged correctly. ``member_preds`` itself is never modified.
    """
    for j, epoch_pred in enumerate(member_preds):
        # np.array copies, so the member's own prediction history is never aliased/overwritten.
        epoch_pred = np.array(epoch_pred, dtype=float)
        if j < len(pred_sums):
            pred_sums[j] = pred_sums[j] + epoch_pred
            pred_counts[j] += 1
        else:
            pred_sums.append(epoch_pred)
            pred_counts.append(1)


def _epoch_metrics(mean_preds, targets, loss_fn):
    """Return ``(losses, accuracies)``, one entry per epoch, for averaged ensemble predictions.

    ``loss_fn(targets, preds)`` must return a scalar convertible with ``float`` (e.g. a Keras loss).
    Accuracy rounds each averaged prediction to the nearest integer and compares it element-wise
    with the flattened targets (assumes binary 0/1 labels; TODO confirm for other label encodings).
    """
    flat_targets = np.asarray(targets).reshape(-1)
    losses, accuracies = [], []
    for mean_pred in mean_preds:
        losses.append(float(loss_fn(targets, mean_pred.tolist())))
        hard_pred = np.round(mean_pred, 0).reshape(-1)
        accuracies.append(float(np.mean(hard_pred == flat_targets)))
    return losses, accuracies


def run(trainX, trainY):
    """
    Train ``config['ensemble']`` classifiers of type ``config['model']`` and report ensemble metrics.

    Each member is trained via ``classifier.fit(trainX, trainY)``, which returns ``(hist, pred_ensemble)``.
    ``pred_ensemble.predhis`` holds the member's validation predictions (presumably one entry per
    epoch) and ``pred_ensemble.targets`` the matching labels. Predictions are averaged across members
    epoch by epoch, and the binary cross-entropy and accuracy of the averaged predictions replace
    ``val_loss`` / ``val_accuracy`` in the history of the LAST member. That history is then plotted
    and saved, so the training curves shown are those of the last member only.

    Side effect: appends '_ensemble' (more than one member) and '_cluster' (``config['split']``)
    to ``config['model']`` so that output files are named accordingly.

    Raises:
        ValueError: if ``config['ensemble'] < 1`` or ``config['model']`` is not a known model.
    """
    logging.info("Started running "+config['model']+". If you want to run other methods please choose another model in the config.py file.")

    n_models = config['ensemble']
    if n_models < 1:
        raise ValueError("config['ensemble'] must be >= 1, got {}".format(n_models))

    bce = tf.keras.losses.BinaryCrossentropy()

    pred_sums = []
    pred_counts = []
    for i in range(n_models):
        print('beginning model number {}/{} ...'.format(i + 1, n_models))
        classifier = _make_classifier(config['model'])
        hist, pred_ensemble = classifier.fit(trainX, trainY)
        _accumulate_epoch_predictions(pred_sums, pred_counts, pred_ensemble.predhis)

    mean_preds = [pred_sum / count for pred_sum, count in zip(pred_sums, pred_counts)]
    # Targets come from the last member; assumes every member is validated on the same split.
    loss, accuracy = _epoch_metrics(mean_preds, pred_ensemble.targets, bce)

    if n_models > 1:
        config['model'] += '_ensemble'
    if config['split']:
        config['model'] = config['model'] + '_cluster'

    hist.history['val_loss'] = loss
    hist.history['val_accuracy'] = accuracy
    plot_loss(hist, config['model_dir'], config['model'], val = True)
    plot_acc(hist, config['model_dir'], config['model'], val = True)
    save_logs(hist, config['model_dir'], config['model'], pytorch = False)