Commit 51ba7fe3 authored by Martyna Plomecka's avatar Martyna Plomecka
Browse files

Merged conflicts

parents 10f5890a bf70e31c
......@@ -18,16 +18,14 @@ import logging
from sklearn.model_selection import train_test_split
class prediction_history(tf.keras.callbacks.Callback):
def __init__(self, val_data):
self.val_data = val_data
def __init__(self, validation_data):
self.validation_data = validation_data
self.predhis = []
self.targets = []
self.targets = validation_data[1]
def on_batch_end(self, epoch, logs={}):
x_val, y_val = self.val_data
self.targets.append(y_val)
prediction = self.model.predict(x_val)
self.predhis.append(prediction)
def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.validation_data[0])
self.predhis.append(y_pred)
class Classifier_EEGNet:
"""
......
import tensorflow as tf
from config import config
from utils.utils import *
import logging
from ConvNet import ConvNet
from tensorflow.keras.constraints import max_norm
class Classifier_PyramidalCNN(ConvNet):
"""
The Classifier_PyramidalCNN is one of the simplest classifiers. It implements the class ConvNet, which is made of modules with a
specific depth, where for each depth the number of filters is increased.
"""
def __init__(self, input_shape, kernel_size=16, epochs = 50, nb_filters=16, verbose=True, batch_size=64, use_residual=False, depth=6):
super(Classifier_PyramidalCNN, self).__init__(input_shape, kernel_size=kernel_size, epochs=epochs, nb_filters=nb_filters,
verbose=verbose, batch_size=batch_size, use_residual=use_residual, depth=depth)
def _module(self, input_tensor, current_depth):
"""
The module of CNN is made of a simple convolution with batch normalization and ReLu activation. Finally, MaxPooling is also used.
"""
x = tf.keras.layers.Conv1D(filters=self.nb_filters*(current_depth + 1), kernel_size=self.kernel_size, padding='same', use_bias=False)(input_tensor)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation(activation='relu')(x)
x = tf.keras.layers.MaxPool1D(pool_size=2, strides=2)(x)
return x
......@@ -42,7 +42,7 @@ Cluster can be set to clustering(), clustering2() or clustering3(), where differ
"""
# Choosing model
config['model'] = 'eegnet'
config['model'] = 'pyramidal_cnn'
config['downsampled'] = False
config['split'] = False
config['cluster'] = clustering()
......@@ -57,6 +57,8 @@ config['trainY_variable'] = 'all_trialinfoprosan'
# CNN
config['cnn'] = {}
# CNN
config['pyramidal_cnn'] = {}
# InceptionTime
config['inception'] = {}
# DeepEye
......@@ -69,6 +71,7 @@ config['eegnet'] = {}
config['deepeye-rnn'] = {}
config['cnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['pyramidal_cnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['inception']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['deepeye']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['deepeye-rnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
......
import tensorflow as tf
from config import config
from utils.utils import *
import logging
from CNN.CNN import Classifier_CNN
from PyramidalCNN.PyramidalCNN import Classifier_PyramidalCNN
from DeepEye.deepeye import Classifier_DEEPEYE
from DeepEyeRNN.deepeyeRNN import Classifier_DEEPEYE_RNN
from Xception.xception import Classifier_XCEPTION
......@@ -29,15 +31,19 @@ def run(trainX, trainY):
for i in range(config['ensemble']):
print('beginning model number {}/{} ...'.format(i,config['ensemble']))
if config['model'] == 'deepeye':
classifier = Classifier_DEEPEYE(input_shape = config['deepeye']['input_shape'])
classifier = Classifier_DEEPEYE(input_shape=config['deepeye']['input_shape'])
elif config['model'] == 'cnn':
classifier = Classifier_CNN(input_shape = config['cnn']['input_shape'])
classifier = Classifier_CNN(input_shape=config['cnn']['input_shape'])
elif config['model'] == 'pyramidal_cnn':
classifier = Classifier_PyramidalCNN(input_shape=config['cnn']['input_shape'], epochs=50)
elif config['model'] == 'eegnet':
classifier = Classifier_EEGNet()
classifier = Classifier_EEGNet(dropoutRate = 0.5, kernLength = 250, F1 = 16,
D = 4, F2 = 256, norm_rate = 0.5, dropoutType = 'Dropout',
epochs = 50)
elif config['model'] == 'inception':
classifier = Classifier_INCEPTION(input_shape=config['inception']['input_shape'], use_residual=True,
kernel_size=64, nb_filters=16, depth=12, bottleneck_size=16, epochs=50)
elif config['model'] == 'xception' :
elif config['model'] == 'xception':
classifier = Classifier_XCEPTION(input_shape=config['inception']['input_shape'], use_residual=True,
kernel_size=40, nb_filters=64, depth=18, epochs=50)
elif config['model'] == 'deepeye-rnn':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment