# deepeye.py — DeepEye classifier combining InceptionTime, Xception and EEGNet ideas.
import tensorflow as tf
from config import config
from utils.utils import *
import logging
from ConvNet import ConvNet
from tensorflow.keras.constraints import max_norm

class Classifier_DEEPEYE(ConvNet):
    """
    The Classifier_DeepEye is the architecture that combines many ideas from
    InceptionTime, Xception and EEGNet.

    It implements the class ConvNet, which is made of modules with a specific
    depth. Optionally an EEGNet-inspired preprocessing stage (temporal filtering
    followed by depthwise spatial filtering) is applied before the modules.
    """

    def __init__(self, input_shape, kernel_size=40, nb_filters=32, verbose=True, batch_size=64, use_residual=True,
                 depth=6, bottleneck_size=32, preprocessing=True, preprocessing_F1=8, preprocessing_D=2,
                 preprocessing_kernLength=250, use_simple_convolution=True, use_separable_convolution=True, epochs=1):
        """
        Build the DeepEye classifier.

        The DeepEye architecture offers the possibility to do a preprocessing
        inspired by EEGNet. It is made of modules of specific depth. Each module
        is made of the InceptionTime submodule, a separable convolution and a
        simple convolution with max pooling for stability reasons.

        Parameters beyond the ConvNet base class:
            bottleneck_size: number of filters of the 1x1 bottleneck convolution
                inside each module.
            preprocessing: if True, the EEGNet-style preprocessing is prepended
                and the input gains a trailing channel dimension.
            preprocessing_F1: number of temporal filters in the preprocessing.
            preprocessing_D: depth multiplier of the depthwise (spatial) filter.
            preprocessing_kernLength: temporal kernel length of the preprocessing.
            use_simple_convolution: enable the max-pool + wide-kernel branch.
            use_separable_convolution: enable the separable-convolution branch.
        """
        self.preprocessing_F1 = preprocessing_F1
        self.preprocessing_D = preprocessing_D
        self.preprocessing_kernLength = preprocessing_kernLength
        self.bottleneck_size = bottleneck_size
        self.use_simple_convolution = use_simple_convolution
        self.use_separable_convolution = use_separable_convolution
        if preprocessing:
            # The preprocessing stage uses Conv2D layers, so the input needs an
            # explicit trailing channel axis.
            input_shape = input_shape + (1,)
        super(Classifier_DEEPEYE, self).__init__(input_shape=input_shape, kernel_size=kernel_size,
                                                 nb_filters=nb_filters, verbose=verbose, batch_size=batch_size,
                                                 use_residual=use_residual, depth=depth,
                                                 preprocessing=preprocessing, epochs=epochs)

        if preprocessing:
            logging.info('--------------- preprocessing_F1         : ' + str(self.preprocessing_F1))
            logging.info('--------------- preprocessing_D          : ' + str(self.preprocessing_D))
            logging.info('--------------- preprocessing_kernLength : ' + str(self.preprocessing_kernLength))
        logging.info('--------------- bottleneck_size : ' + str(self.bottleneck_size))
        logging.info('--------------- use_simple_convolution : ' + str(self.use_simple_convolution))
        logging.info('--------------- use_separable_convolution   : ' + str(self.use_separable_convolution))

    def _preprocessing(self, input_tensor):
        """
        EEGNet-inspired preprocessing.

        It offers a way to filter the signal into spatially specific band-pass
        frequencies: a temporal convolution first, then a depthwise convolution
        spanning all channels.

        NOTE(review): assumes input_tensor is rank-4, (batch, time, channels, 1)
        — TODO confirm against the ConvNet base class. Returns a rank-3 tensor
        after the channel axis is collapsed.
        """
        # Use the module logger instead of a stray debug print.
        logging.debug('preprocessing input shape: %s', input_tensor.shape)

        # Temporal filtering: the kernel slides along the time axis only.
        horizontal_tensor = tf.keras.layers.Conv2D(self.preprocessing_F1, (self.preprocessing_kernLength, 1),
                                                   padding='same', use_bias=False)(input_tensor)
        horizontal_tensor = tf.keras.layers.BatchNormalization()(horizontal_tensor)

        # Spatial filtering: the depthwise kernel spans all channels at once,
        # with a max-norm constraint as in EEGNet.
        vertical_tensor = tf.keras.layers.DepthwiseConv2D((1, input_tensor.shape[2]), use_bias=False,
                                                          depth_multiplier=self.preprocessing_D,
                                                          depthwise_constraint=max_norm(1.))(horizontal_tensor)
        vertical_tensor = tf.keras.layers.BatchNormalization()(vertical_tensor)

        eeg_tensor = tf.keras.layers.Activation('elu')(vertical_tensor)
        eeg_tensor = tf.keras.layers.Dropout(0.5)(eeg_tensor)
        # The depthwise convolution reduced the channel axis to size 1; drop it
        # so the module stack can operate on rank-3 tensors.
        output_tensor = eeg_tensor[:, :, 0, :]

        return output_tensor

    def _module(self, input_tensor, current_depth):
        """
        The module of DeepEye.

        It starts with the bottleneck of InceptionTime, followed by parallel
        filters of decreasing kernel size (default [40, 20, 10]). In parallel it
        uses a simple convolution and a separable convolution to make use of
        'extreme' convolutions as explained in the Xception paper.
        """
        # Bottleneck only pays off when there is more than one input channel.
        if int(input_tensor.shape[-1]) > 1:
            input_inception = tf.keras.layers.Conv1D(filters=self.bottleneck_size, kernel_size=1,
                                                     padding='same', use_bias=False)(input_tensor)
        else:
            input_inception = input_tensor

        # InceptionTime branch: kernel sizes [k, k // 2, k // 4].
        kernel_size_s = [self.kernel_size // (2 ** i) for i in range(3)]
        conv_list = [tf.keras.layers.Conv1D(filters=self.nb_filters, kernel_size=k,
                                            padding='same', use_bias=False)(input_inception)
                     for k in kernel_size_s]

        # Max-pool branch followed by a 1x1 convolution.
        max_pool_1 = tf.keras.layers.MaxPool1D(pool_size=10, strides=1, padding='same')(input_tensor)
        conv_6 = tf.keras.layers.Conv1D(filters=self.nb_filters, kernel_size=1, padding='same', use_bias=False)(max_pool_1)
        conv_list.append(conv_6)

        if self.use_simple_convolution:
            max_pool_2 = tf.keras.layers.MaxPool1D(pool_size=10, strides=1, padding='same')(input_tensor)
            # filters must be an int: use floor division (was `nb_filters / 8`,
            # which passes a float to Conv1D) and guard against nb_filters < 8.
            conv_7 = tf.keras.layers.Conv1D(filters=max(1, self.nb_filters // 8), kernel_size=16,
                                            padding='same', use_bias=False)(max_pool_2)
            conv_list.append(conv_7)

        if self.use_separable_convolution:
            conv_8 = tf.keras.layers.SeparableConv1D(filters=self.nb_filters, kernel_size=32, padding='same',
                                                     use_bias=False, depth_multiplier=1)(input_tensor)
            conv_list.append(conv_8)

        # Concatenate all branches along the feature axis, then normalize.
        x = tf.keras.layers.Concatenate(axis=2)(conv_list)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation(activation='relu')(x)
        return x