Commit db2492e5 authored by Ard Kastrati's avatar Ard Kastrati
Browse files

Added CNN with the same structure as the others

parent 0992c8c8
# -*- coding: utf-8 -*- import tensorflow as tf
import torch import tensorflow.keras as keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from utils.utils import *
from config import config from config import config
from utils.utils import *
import logging import logging
from keras.callbacks import CSVLogger
def run(trainX, trainY):
logging.info("Starting CNN.")
classifier = Classifier_CNN(input_shape=config['cnn']['input_shape'])
hist = classifier.fit(trainX, trainY)
plot_loss(hist, config['model_dir'], config['model'], True)
plot_acc(hist, config['model_dir'], config['model'], True)
save_logs(hist, config['model_dir'], config['model'], pytorch=False)
# save_model_param(classifier.model, config['model_dir'], config['model'], pytorch=False)
class Classifier_CNN:
def __init__(self, input_shape, verbose=True, build=True, batch_size=64, nb_filters=32,
use_residual=True, depth=6, kernel_size=40, nb_epochs=1500):
self.nb_filters = nb_filters
self.use_residual = use_residual
self.depth = depth
self.kernel_size = kernel_size
self.callbacks = None
self.batch_size = batch_size
self.bottleneck_size = 32
self.nb_epochs = nb_epochs
self.verbose = verbose
if build:
if config['split']:
self.model = self.split_model(input_shape)
else:
self.model = self._build_model(input_shape)
if self.verbose:
self.model.summary()
# self.model.save_weights(self.output_directory + 'model_init.hdf5')
def split_model(self, input_shape):
input_layer = tf.keras.layers.Input(input_shape)
output = []
class Net(nn.Module): # run CNN over the cluster
def __init__(self): for c in config['cluster'].keys():
super(Net, self).__init__() a = [input_shape[0]]
self.conv1 = nn.Conv1d(in_channels=129, out_channels=258, kernel_size=5, stride=1, padding=2) a.append(len(config['cluster'][c]))
self.pool = nn.MaxPool1d(5) input_shape = tuple(a)
self.conv2 = nn.Conv1d(in_channels=258, out_channels=64, kernel_size=5, stride=1, padding=2)
self.fc1 = nn.Linear(64*20, 120) output.append(self._build_model(input_shape,
self.fc2 = nn.Linear(120, 60) X=tf.transpose(tf.nn.embedding_lookup(tf.transpose(input_layer),
self.fc3 = nn.Linear(60, 2) config['cluster'][c]))))
def forward(self, x): # append the results and perform 1 dense layer with last_channel dimension and the output layer
x = self.pool(F.relu(self.conv1(x))) x = tf.keras.layers.Concatenate()(output)
x = self.pool(F.relu(self.conv2(x))) dense = tf.keras.layers.Dense(32, activation='relu')(x)
x = x.view(-1, 64*20) output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(dense)
x = F.relu(self.fc1(x)) model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
x = F.relu(self.fc2(x)) model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(), metrics=['accuracy'])
x = self.fc3(x) return model
def _CNN_module(self, input_tensor, nb_filters=128, activation='linear'):
x = tf.keras.layers.Conv1D(filters=nb_filters, kernel_size=128, padding='same', activation=activation, use_bias=False)(input_tensor)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation(activation='relu')(x)
return x return x
def run(trainX, trainY):
#load the data def _build_model(self, input_shape, X=[], depth=6):
dataset = torch.utils.data.TensorDataset(trainX, trainY) if config['split']:
trainloader = torch.utils.data.DataLoader(dataset, batch_size=2) input_layer = X
else:
# define the network input_layer = tf.keras.layers.Input(input_shape)
net = Net()
x = input_layer
# define the optimizer
criterion = nn.CrossEntropyLoss() for d in range(depth):
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) x = self._CNN_module(x)
# train gap_layer = tf.keras.layers.GlobalAveragePooling1D()(x)
hist = train(trainloader=trainloader, net=net, optimizer=optimizer, criterion=criterion) if config['split']:
return gap_layer
# save our trained model output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(gap_layer)
# PATH = '../cifar_net.pth' model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(), metrics=['accuracy'])
# Newly added lines below return model
save_logs(hist, config['model_dir'], config['model'], pytorch=True)
plot_loss_torch(hist) # Require debugging def fit(self, CNN_x, y):
torch.save(net.state_dict(), config['model_dir'] + '/' + config['model'] + '_' + 'model.pth') csv_logger = CSVLogger(config['batches_log'], append=True, separator=';')
# -------------- SEE BELOW ----------------------------------- early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=20)
# Should incoporate the best model function during training ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
# @Oriel: you can check this link for reference: ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_accuracy', save_best_only=True,
# https://discuss.pytorch.org/t/save-the-best-model/30430 mode='auto')
hist = self.model.fit(CNN_x, y, verbose=1, validation_split=0.2, epochs=35,
def train(trainloader, net, optimizer, criterion, nb_epoch=50): callbacks=[csv_logger, ckpt, early_stop])
loss=[] return hist
for epoch in range(nb_epoch): # loop over the dataset multiple times
running_loss = 0.0
loss_values=[]
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels.squeeze(1))
loss.backward()
optimizer.step()
# logging.info statistics
run_loss = loss.item()
running_loss+=run_loss
loss_values.append(run_loss)
if i % 200 == 0: # logging.info every 200 mini-batches
logging.info('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 200))
running_loss = 0.0
los=np.mean(loss_values)
loss.append(los)
return loss
logging.info('Finished Training')
...@@ -35,9 +35,9 @@ deepeye: Our method ...@@ -35,9 +35,9 @@ deepeye: Our method
""" """
# Choosing model # Choosing model
config['model'] = 'deepeye3' config['model'] = 'cnn'
config['downsampled'] = False config['downsampled'] = True
config['split'] = False config['split'] = True
config['cluster'] = clustering() config['cluster'] = clustering()
if config['split']: if config['split']:
config['model'] = config['model'] + '_cluster' config['model'] = config['model'] + '_cluster'
...@@ -65,6 +65,7 @@ config['eegnet'] = {} ...@@ -65,6 +65,7 @@ config['eegnet'] = {}
# LSTM-DeepEye # LSTM-DeepEye
config['deepeye-lstm'] = {} config['deepeye-lstm'] = {}
config['cnn']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['inception']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129) config['inception']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['deepeye2']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129) config['deepeye2']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
config['deepeye3']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129) config['deepeye3']['input_shape'] = (125, 129) if config['downsampled'] else (500, 129)
......
...@@ -17,7 +17,7 @@ def main(): ...@@ -17,7 +17,7 @@ def main():
trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True) trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
if config['model'] == 'cnn' or config['model'] == 'cnn_cluster': if config['model'] == 'cnn' or config['model'] == 'cnn_cluster':
logging.info("Started running CNN-1. If you want to run other methods please choose another model in the config.py file.") logging.info("Started running CNN. If you want to run other methods please choose another model in the config.py file.")
CNN.run(trainX, trainY) CNN.run(trainX, trainY)
elif config['model'] == 'inception' or config['model'] == 'inception_cluster': elif config['model'] == 'inception' or config['model'] == 'inception_cluster':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment