"""EEGNet baseline model (Lawhern et al.) and training helpers."""

import tensorflow as tf
import tensorflow.keras as keras

from config import config

from utils.utils import *

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm

# NOTE(review): this mixes the standalone `keras` package with
# `tensorflow.keras`; consider `from tensorflow.keras.callbacks import
# CSVLogger` for consistency.
from keras.callbacks import CSVLogger

import numpy as np
import logging

from sklearn.model_selection import train_test_split

class prediction_history(tf.keras.callbacks.Callback):
    """Keras callback that records the validation targets and the model's
    predictions on the validation set after every training batch, so the
    per-batch prediction history can be inspected after training.

    NOTE(review): predicting on the full validation set after *every* batch
    is expensive; this is intentional here (ensembling over batches) but
    worth knowing.
    """

    def __init__(self, val_data):
        # val_data must be an (x_val, y_val) pair — it is unpacked that way
        # in on_batch_end.
        self.val_data = val_data
        self.predhis = []   # one prediction array per finished batch
        self.targets = []   # y_val appended once per finished batch

    def on_batch_end(self, epoch, logs=None):
        # Bug fix: original used a mutable default argument (logs={});
        # `logs` is never read here, so None is the safe default.
        # NOTE(review): despite the parameter name, Keras passes the *batch*
        # index to on_batch_end, not the epoch.
        x_val, y_val = self.val_data
        self.targets.append(y_val)
        prediction = self.model.predict(x_val)
        self.predhis.append(prediction)
class Classifier_EEGNet:
    """
    The EEGNet architecture used as baseline. This is the architecture explained in the paper

    'EEGNet: A Compact Convolutional Network for EEG-based Brain-Computer Interfaces' with authors
    Vernon J. Lawhern, Amelia J. Solon, Nicholas R. Waytowich, Stephen M. Gordon, Chou P. Hung, Brent J. Lance
    """

    def __init__(self, nb_classes=1, chans=config['eegnet']['channels'],
                 samples=config['eegnet']['samples'], dropoutRate=0.5, kernLength=250, F1=16,
                 D=4, F2=256, norm_rate=0.5, dropoutType='Dropout', epochs=50, verbose=True,
                 build=True, X=None):
        """Store the hyper-parameters and optionally build the Keras model.

        Parameters
        ----------
        nb_classes : int
            Units of the final Dense layer (1 for the binary setup used here).
        chans, samples : int
            Input dimensions (EEG channels x time samples); defaults come from config.
        dropoutRate : float
            Dropout rate applied after each pooling stage.
        kernLength : int
            Temporal kernel length of the first Conv2D.
        F1, D, F2 : int
            Temporal filters, depthwise depth multiplier, and separable-conv
            filters, as named in the EEGNet paper.
        norm_rate : float
            max_norm constraint on the final Dense kernel.
        dropoutType : str
            'Dropout' or 'SpatialDropout2D'.
        epochs : int
            Training epochs used by fit().
        verbose : bool
            If True, print the model summary after building.
        build : bool
            If True, build the model now; config['split'] selects the
            clustered variant (split_model) over plain EEGNet (build_model).
        X :
            Accepted but unused here; kept for backward compatibility.
        """
        self.nb_classes = nb_classes
        self.chans = chans
        self.samples = samples
        self.dropoutRate = dropoutRate
        self.kernLength = kernLength
        self.F1 = F1
        self.D = D
        self.F2 = F2
        self.norm_rate = norm_rate
        self.dropoutType = dropoutType
        self.epochs = epochs
        self.verbose = verbose

        logging.info('Parameters...')
        logging.info('--------------- chans            : ' + str(self.chans))
        logging.info('--------------- samples          : ' + str(self.samples))
        logging.info('--------------- dropoutRate      : ' + str(self.dropoutRate))
        logging.info('--------------- kernLength       : ' + str(self.kernLength))
        logging.info('--------------- F1               : ' + str(self.F1))
        logging.info('--------------- D                : ' + str(self.D))
        logging.info('--------------- F2               : ' + str(self.F2))
        logging.info('--------------- norm_rate        : ' + str(self.norm_rate))

        if build:
            if config['split']:
                self.model = self.split_model()
            else:
                self.model = self.build_model()
            if verbose:
                self.model.summary()

    def split_model(self):
        """
        This method is added to make use of clustering idea in EEGNet as well. It divides the input into different clusters.
        Then it builds a model of EEGNet for each cluster, concatenates the extracted features and uses a Dense layer to finally
        classify the data.
        """
        input_layer = keras.layers.Input((config['eegnet']['channels'], config['eegnet']['samples']))
        output = []

        # Build one EEGNet feature extractor per cluster: select the cluster's
        # channels via embedding_lookup on the (transposed) channel axis, then
        # append a trailing singleton dimension so the branch sees a 2-D "image".
        for c in config['cluster'].keys():
            output.append(self.build_model(X=tf.expand_dims(tf.transpose(tf.nn.embedding_lookup(
                tf.transpose(input_layer, (1, 0, 2)), config['cluster'][c]), (1, 0, 2)), axis=-1), c=c))

        # Concatenate the per-cluster features, then one hidden Dense layer
        # and the sigmoid output layer.
        x = tf.keras.layers.Concatenate(axis=1)(output)
        dense = tf.keras.layers.Dense(32, activation='relu')(x)
        output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(dense)

        model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)

        return model

    def build_model(self, X=None, c=None):
        """
        The model of EEGNet (Taken from the implementation of EEGNet paper).

        In split mode (config['split'] truthy), `X` is the already-selected
        cluster tensor and `c` the cluster key; the flattened feature tensor
        is returned instead of a compiled Model.
        """
        if self.dropoutType == 'SpatialDropout2D':
            dropoutType = SpatialDropout2D
        elif self.dropoutType == 'Dropout':
            dropoutType = Dropout
        else:
            raise ValueError('dropoutType must be one of SpatialDropout2D '
                             'or Dropout, passed as a string.')

        if config['split']:
            input1 = X
            # NOTE(review): mutates self.chans as a side effect so the layer
            # shapes below match the current cluster; split_model relies on
            # this being re-set on every call.
            self.chans = len(config['cluster'][c])
        else:
            input1 = Input(shape=(self.chans, self.samples, 1))

        # Block 1: temporal convolution, then a depthwise convolution that
        # spans all channels at once (spatial filtering, per the paper).
        block1 = Conv2D(self.F1, (1, self.kernLength), padding='same',
                        input_shape=(self.chans, self.samples, 1),
                        use_bias=False)(input1)
        block1 = BatchNormalization()(block1)
        block1 = DepthwiseConv2D((self.chans, 1), use_bias=False,
                                 depth_multiplier=self.D,
                                 depthwise_constraint=max_norm(1.))(block1)
        block1 = BatchNormalization()(block1)
        block1 = Activation('elu')(block1)
        block1 = AveragePooling2D((1, 16))(block1)
        block1 = dropoutType(self.dropoutRate)(block1)

        # Block 2: separable convolution followed by pooling and dropout.
        block2 = SeparableConv2D(self.F2, (1, 64),
                                 use_bias=False, padding='same')(block1)
        block2 = BatchNormalization()(block2)
        block2 = Activation('elu')(block2)
        block2 = AveragePooling2D((1, 6))(block2)
        block2 = dropoutType(self.dropoutRate)(block2)

        flatten = Flatten()(block2)

        if config['split']:
            # Split mode: return raw features; split_model() owns the head.
            return flatten
        else:
            dense = Dense(self.nb_classes, name='dense',
                          kernel_constraint=max_norm(self.norm_rate))(flatten)
            # Renamed local from `softmax`: the activation is a sigmoid.
            sigmoid = Activation('sigmoid', name='sigmoid')(dense)

            return Model(inputs=input1, outputs=sigmoid)

    def fit(self, eegnet_x, y):
        """Compile and train the model.

        Splits (eegnet_x, y) 80/20 into train/validation (fixed seed 42),
        logs per-batch metrics to config['batches_log'], checkpoints the best
        model by val_accuracy, and stops early after 20 epochs without
        improvement.

        Returns
        -------
        (hist, pred_ensemble)
            The Keras History object and the prediction_history callback
            holding the per-batch validation predictions.
        """
        self.model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(), metrics=['accuracy'])
        csv_logger = CSVLogger(config['batches_log'], append=True, separator=';')

        early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=20)
        ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
        ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_accuracy', save_best_only=True, mode='auto')

        X_train, X_val, y_train, y_val = train_test_split(eegnet_x, y, test_size=0.2, random_state=42)
        pred_ensemble = prediction_history((X_val, y_val))
        hist = self.model.fit(X_train, y_train, verbose=1, validation_data=(X_val, y_val),
                              epochs=self.epochs, callbacks=[csv_logger, ckpt, early_stop, pred_ensemble])

        return hist, pred_ensemble