Commit ee8c2ef1 authored by Ard Kastrati's avatar Ard Kastrati
Browse files

Implemented a variant of XceptionTime

parent 18cf1bba
import tensorflow as tf
import numpy as np
import tensorflow.keras as keras
from config import config
from utils.utils import *
import logging
from keras.callbacks import CSVLogger
def run(trainX, trainY):
    """Train an XceptionTime classifier on (trainX, trainY) and save all artifacts."""
    logging.info("Starting XceptionTime.")
    clf = Classifier_Xception(input_shape=config['xception']['input_shape'])
    history = clf.fit(trainX, trainY)
    out_dir, model_name = config['model_dir'], config['model']
    # Persist training curves, metric logs and the learned parameters.
    plot_loss(history, out_dir, model_name, True)
    plot_acc(history, out_dir, model_name, True)
    save_logs(history, out_dir, model_name, pytorch=False)
    save_model_param(clf.model, out_dir, model_name, pytorch=False)
class Classifier_Xception:
    """Xception-style 1D CNN for binary time-series classification.

    The network stacks ``depth`` depthwise-separable convolution modules;
    when ``use_residual`` is set, a 1x1-conv shortcut is added after every
    third module. A global-average-pooling layer feeds a single sigmoid
    unit trained with binary cross-entropy.

    NOTE(review): the constructor previously accepted batch_size,
    nb_filters, use_residual, depth and nb_epochs but silently ignored
    all of them (hard-coded values were used instead). They are now
    honored, and the defaults below are set to the values the original
    code actually used at run time, so default behavior is unchanged.
    """

    def __init__(self, input_shape, verbose=True, build=True, batch_size=32, nb_filters=128,
                 use_residual=True, depth=9, nb_epochs=35):
        # Stored hyper-parameters are now used by _build_model() and fit().
        self.nb_filters = nb_filters
        self.use_residual = use_residual
        self.depth = depth
        self.callbacks = None
        self.batch_size = batch_size
        self.nb_epochs = nb_epochs
        self.verbose = verbose
        if build:
            self.model = self._build_model(input_shape)
            if self.verbose:
                self.model.summary()

    def _xception_module(self, input_tensor, nb_filters=None, activation='linear'):
        """One module: SeparableConv1D -> BatchNorm -> ReLU.

        ``nb_filters`` defaults to the value given at construction.
        """
        if nb_filters is None:
            nb_filters = self.nb_filters
        x = tf.keras.layers.SeparableConv1D(filters=nb_filters, kernel_size=32, padding='same',
                                            activation=activation, use_bias=False,
                                            depth_multiplier=1)(input_tensor)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation(activation='relu')(x)
        return x

    def _shortcut_layer(self, input_tensor, out_tensor):
        """Residual connection: project input with a 1x1 conv, add, ReLU."""
        shortcut_y = tf.keras.layers.Conv1D(filters=int(out_tensor.shape[-1]), kernel_size=1,
                                            padding='same', use_bias=False)(input_tensor)
        shortcut_y = tf.keras.layers.BatchNormalization()(shortcut_y)
        x = keras.layers.Add()([shortcut_y, out_tensor])
        x = keras.layers.Activation('relu')(x)
        return x

    def _build_model(self, input_shape, use_residual=None, depth=None):
        """Assemble and compile the Keras model for ``input_shape``.

        ``use_residual`` and ``depth`` default to the values given at
        construction (previously they were shadowed by hard-coded
        defaults that disagreed with the constructor's).
        """
        if use_residual is None:
            use_residual = self.use_residual
        if depth is None:
            depth = self.depth
        input_layer = tf.keras.layers.Input(input_shape)
        x = input_layer
        input_res = input_layer
        for d in range(depth):
            x = self._xception_module(x)
            # Add a residual shortcut after every third module.
            if use_residual and d % 3 == 2:
                x = self._shortcut_layer(input_res, x)
                input_res = x
        gap_layer = tf.keras.layers.GlobalAveragePooling1D()(x)
        # Single sigmoid unit: binary classification head.
        output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(gap_layer)
        model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
        model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(),
                      metrics=['accuracy'])
        return model

    def fit(self, xception_x, y):
        """Train on ``xception_x``/``y`` with a 20% validation split.

        Returns the Keras ``History`` object. Batch size and epoch count
        now come from the constructor instead of being hard-coded.
        """
        csv_logger = CSVLogger(config['batches_log'], append=True, separator=';')
        hist = self.model.fit(xception_x, y, verbose=1, validation_split=0.2,
                              batch_size=self.batch_size, epochs=self.nb_epochs,
                              callbacks=[csv_logger])
        return hist
......@@ -35,7 +35,7 @@ deepeye: Our method
"""
# Choosing model
config['model'] = 'deepeye2'
config['model'] = 'xception'
config['downsampled'] = False
config['split'] = False
config['cluster'] = clustering()
......@@ -53,6 +53,8 @@ config['inception'] = {}
config['deepeye'] = {}
# DeepEye2
config['deepeye2'] = {}
# Xception
config['xception'] = {}
# EEGNet
config['eegnet'] = {}
# LSTM-DeepEye
......@@ -60,6 +62,8 @@ config['deepeye-lstm'] = {}
config['inception']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['deepeye2']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['xception']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['eegnet']['channels'] = 129
config['eegnet']['samples'] = 125 if config['downsampled'] else 500
config['deepeye']['input_shape'] = (129, 125) if config['downsampled'] else (129, 500)
......@@ -67,9 +71,6 @@ config['deepeye']['channels'] = 129
config['deepeye']['samples'] = 125 if config['downsampled'] else 500
config['deepeye-lstm']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
# You can set descriptive experiment names or simply set empty string ''.
# config['model_name'] = 'skeleton_code'
# Create a unique output directory for this experiment.
timestamp = str(int(time.time()))
model_folder_name = timestamp if config['model'] == '' else timestamp + "_" + config['model']
......
......@@ -2,7 +2,7 @@ from config import config
import time
from CNN import CNN
from utils import IOHelper
from DeepEye import deepEye, RNNdeep, deepEye2
from DeepEye import deepEye, RNNdeep, deepEye2, Xception
from InceptionTime import inception
from EEGNet import eegNet
import numpy as np
......@@ -15,11 +15,6 @@ def main():
start_time = time.time()
# try:
trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
# trainX, trainY = IOHelper.get_pickle_data(config['data_dir'], verbose=True)
# IOHelper.store(trainX, trainY)
#except:
# print("ERROR while loading data.")
# return
if config['model'] == 'cnn':
logging.info("Started running CNN-1. If you want to run other methods please choose another model in the config.py file.")
......@@ -44,6 +39,10 @@ def main():
elif config['model'] == 'deepeye2':
logging.info("Started running DeepEye2. If you want to run other methods please choose another model in the config.py file.")
deepEye2.run(trainX=trainX, trainY=trainY)
elif config['model'] == 'xception':
logging.info("Started running XceptionTime. If you want to run other methods please choose another model in the config.py file.")
Xception.run(trainX=trainX, trainY=trainY)
elif config['model'] == 'deepeye-lstm':
logging.info("Started running deepeye-lstm. If you want to run other methods please choose another model in the config.py file.")
......@@ -54,5 +53,6 @@ def main():
logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
logging.info('Finished Logging')
if __name__=='__main__':
main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment