# ensemble.py — ensemble training driver (committed by okiss).
import tensorflow as tf
from config import config
from utils.utils import *
import logging

from CNN.CNN import Classifier_CNN
from DeepEye.deepeye import Classifier_DEEPEYE
from DeepEyeRNN.deepeyeRNN import Classifier_DEEPEYE_RNN
from Xception.xception import Classifier_XCEPTION
from InceptionTime.inception import Classifier_INCEPTION
from EEGNet.eegNet import Classifier_EEGNet
import numpy as np



def _build_classifier(model_name):
    """Instantiate the classifier selected by ``model_name`` (normally ``config['model']``).

    Raises:
        ValueError: if ``model_name`` is not one of the supported model keys.
    """
    if model_name == 'deepeye':
        return Classifier_DEEPEYE(input_shape=config['deepeye']['input_shape'])
    if model_name == 'cnn':
        return Classifier_CNN(input_shape=config['cnn']['input_shape'])
    if model_name == 'eegnet':
        return Classifier_EEGNet()
    if model_name == 'inception':
        return Classifier_INCEPTION(input_shape=config['inception']['input_shape'], use_residual=True,
                                    kernel_size=64, nb_filters=16, depth=12, bottleneck_size=16)
    if model_name == 'xception':
        # NOTE(review): reads the *inception* input shape; this may be a copy-paste slip
        # for config['xception'] — confirm against config.py before changing.
        return Classifier_XCEPTION(input_shape=config['inception']['input_shape'])
    if model_name == 'deepeye-rnn':
        return Classifier_DEEPEYE_RNN(input_shape=config['deepeye-rnn']['input_shape'])
    logging.info('Cannot start the program. Please choose one model in the config.py file')
    # Fail fast: previously execution continued and crashed with a NameError on `classifier`.
    raise ValueError("Unknown model '{}' in config['model']".format(model_name))


def _accumulate_predictions(total, predhis):
    """Add one member's per-epoch predictions to the running per-epoch sums.

    Args:
        total: list of per-epoch summed predictions so far, or None for the first member.
        predhis: that member's per-epoch predictions (one array-like per epoch).

    Returns:
        A list of per-epoch numpy arrays. The arrays in ``predhis`` are never mutated:
        each sum creates a new array and the returned list is not ``predhis`` itself.
    """
    if total is None:
        return [np.asarray(epoch_pred) for epoch_pred in predhis]
    # NOTE(review): assumes every member reports the same number of epochs. A member with
    # more epochs than the first raises IndexError; one with fewer leaves the trailing
    # epochs summed over fewer models but still divided by the full ensemble size.
    for j, epoch_pred in enumerate(predhis):
        total[j] = total[j] + np.asarray(epoch_pred)
    return total


def _ensemble_metrics(summed_preds, targets, n_models, loss_fn):
    """Average the summed predictions per epoch and score them against ``targets``.

    Args:
        summed_preds: per-epoch sums of member predictions.
        targets: ground-truth labels, expected to be 0/1.
        n_models: number of members the sums were accumulated over.
        loss_fn: callable ``loss_fn(targets, predictions_as_list) -> scalar``.

    Returns:
        ``(losses, accuracies)``, two lists with one entry per epoch.
    """
    losses, accuracies = [], []
    flat_targets = np.asarray(targets).reshape(-1)
    for summed in summed_preds:
        # np.asarray also covers the single-model case, where entries may still be
        # plain lists that do not support division.
        mean_pred = np.asarray(summed) / n_models
        losses.append(loss_fn(targets, mean_pred.tolist()))
        rounded = np.round(mean_pred, 0).reshape(-1)
        # For 0/1 labels this matches the former ((p + t - 1) ** 2).mean() formulation.
        accuracies.append(np.mean(rounded == flat_targets))
    return losses, accuracies


def run(trainX, trainY):
    """Train an ensemble of ``config['ensemble']`` classifiers and log ensemble metrics.

    Each member is built from ``config['model']`` and trained with
    ``classifier.fit(trainX, trainY)``, which returns ``(hist, pred_ensemble)``.
    The members' per-epoch predictions (``pred_ensemble.predhis``) are averaged.
    For every epoch the average is scored against ``pred_ensemble.targets`` with
    binary cross-entropy and rounded accuracy.

    Those curves replace ``val_loss`` / ``val_accuracy`` in the LAST member's
    history. The plotted validation curves therefore describe the ensemble, while
    the training curves describe only the last member.

    Side effect: mutates the global ``config['model']``. It appends ``'_ensemble'``
    when more than one member is trained and ``'_cluster'`` when ``config['split']``
    is truthy. The suffixed name is then used for the plots and logs.

    Args:
        trainX, trainY: training data, passed unchanged to each member's ``fit``.

    Raises:
        ValueError: if ``config['ensemble'] < 1`` or ``config['model']`` is unknown.
    """
    logging.info("Started running "+config['model']+". If you want to run other methods please choose another model in the config.py file.")

    n_models = config['ensemble']
    if n_models < 1:
        raise ValueError("config['ensemble'] must be >= 1, got {}".format(n_models))

    bce = tf.keras.losses.BinaryCrossentropy()

    summed_preds = None
    for i in range(n_models):
        print('beginning model number {}/{} ...'.format(i + 1, n_models))
        classifier = _build_classifier(config['model'])
        hist, pred_ensemble = classifier.fit(trainX, trainY)
        summed_preds = _accumulate_predictions(summed_preds, pred_ensemble.predhis)

    # NOTE(review): targets come from the last member; this assumes every member was
    # evaluated on the same targets — confirm in the classifiers' fit().
    loss, accuracy = _ensemble_metrics(summed_preds, pred_ensemble.targets, n_models,
                                       lambda y, p: bce(y, p).numpy())

    if n_models > 1:
        config['model'] += '_ensemble'
    if config['split']:
        config['model'] = config['model'] + '_cluster'

    hist.history['val_loss'] = loss
    hist.history['val_accuracy'] = accuracy

    plot_loss(hist, config['model_dir'], config['model'], val = True)
    plot_acc(hist, config['model_dir'], config['model'], val = True)
    save_logs(hist, config['model_dir'], config['model'], pytorch = False)