To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

Commit 2595cdf6 authored by Ard Kastrati's avatar Ard Kastrati
Browse files

Implemented simple classifiers from Sklearn and cross validation

parent 008292db
# Some baselines to test
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_moons, make_circles, make_classification
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC, LinearSVC
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
import logging
import time
import csv
from sklearn.metrics import accuracy_score
import sys
from config import config
def cross_validate_kNN(X, y):
    """Run grid-search cross-validation for a k-nearest-neighbours classifier."""
    logging.info("Cross-validation KNN...")
    knn = KNeighborsClassifier(weights='uniform', algorithm='auto', n_jobs=-1)
    param_grid = {'n_neighbors': [5, 10, 50], 'leaf_size': [10, 30]}
    cross_validate(classifier=knn, parameters=param_grid, X=X, y=y)
def cross_validate_SVC(X, y):
    """Run grid-search cross-validation for a linear support-vector classifier."""
    logging.info("Cross-validation SVC...")
    svc = LinearSVC(tol=1e-5)
    param_grid = {'C': [1, 10, 100, 1000, 10000], 'max_iter': [1000, 10000, 100000]}
    cross_validate(classifier=svc, parameters=param_grid, X=X, y=y)
def cross_validate_RFC(X, y):
    """Run grid-search cross-validation for a random forest classifier."""
    logging.info("Cross-validation RFC...")
    # max_features='sqrt' replaces the deprecated (and later removed)
    # 'auto' option; for classifiers the two were equivalent in sklearn.
    classifier = RandomForestClassifier(max_features='sqrt', random_state=42, n_jobs=-1)
    parameters_RFC = {'n_estimators': [10, 100, 1000, 5000], 'max_depth': [5, 10, 50, 100, 500]}
    cross_validate(classifier=classifier, parameters=parameters_RFC, X=X, y=y)
def cross_validate(classifier, parameters, X, y):
    """Exhaustively grid-search `parameters` for `classifier` on (X, y).

    X is flattened to 2-D (n_samples, n_features) because the sklearn
    estimators used here expect tabular input; y is flattened to 1-D.
    Full CV results and the best configuration are exported as CSV files
    into the configured model directory.
    """
    # Infer the sample count instead of hard-coding (100, 500 * 129) so the
    # routine works for any number of trials / timesteps / channels.
    X = X.reshape((X.shape[0], -1))
    clf = GridSearchCV(classifier, parameters, scoring='accuracy', n_jobs=-1, verbose=3)
    clf.fit(X, y.ravel())
    results = clf.cv_results_
    export_dict(results['mean_fit_time'], results['std_fit_time'],
                results['mean_score_time'], results['std_score_time'],
                results['mean_test_score'], results['std_test_score'],
                results['params'], file_name='CrossValidation_results.csv',
                first_row=('mean_fit_time', 'std_fit_time', 'mean_score_time', 'std_score_time',
                           'mean_test_score', 'std_test_score', 'params'))
    # Persist the winning configuration separately for quick lookup.
    best_columns = ('BEST ESTIMATOR', 'BEST SCORE', 'BEST PARAMS')
    export_dict([clf.best_estimator_], [clf.best_score_], [clf.best_params_],
                file_name='best_results.csv', first_row=best_columns)
def try_sklearn_classifiers(X, y):
    """Train and score several standard sklearn classifiers on (X, y).

    Each model is fit on an 80/20 train/test split of the flattened data;
    per-model accuracy and runtime are logged and exported to
    'classifiers_results.csv' in the configured model directory.
    """
    logging.info("Training the simple classifiers: kNN, Linear SVM, Random Forest and Naive Bayes.")
    names = ["Nearest Neighbors",
             "Linear SVM",
             "Random Forest",
             "Naive Bayes",
             ]
    classifiers = [
        KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, n_jobs=-1),
        LinearSVC(tol=1e-5, C=1, random_state=42, max_iter=10000),
        # 'sqrt' replaces the deprecated max_features='auto' (equivalent for classifiers).
        RandomForestClassifier(n_estimators=100, max_depth=5, max_features='sqrt', random_state=42, n_jobs=-1),
        GaussianNB()
    ]
    # Flatten each trial to a feature vector; infer the sample count instead
    # of hard-coding (100, 500 * 129) so other dataset sizes work too.
    X = X.reshape((X.shape[0], -1))
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2, random_state=42)
    scores = []
    runtimes = []
    for name, clf in zip(names, classifiers):
        logging.info(name)
        start_time = time.time()
        clf.fit(X_train, y_train.ravel())
        score = clf.score(X_test, y_test.ravel())
        scores.append(score)
        runtime = (time.time() - start_time)
        runtimes.append(runtime)
        logging.info("--- Score: %s " % score)
        # Log-message typo fixed: was "%s for seconds".
        logging.info("--- Runtime: %s seconds ---" % runtime)
    export_dict(names, scores, runtimes, first_row=('Model', 'Score', 'Runtime'), file_name='classifiers_results.csv')
def export_dict(*columns, first_row, file_name):
    """Write parallel columns as CSV rows into `file_name` in the model dir.

    Args:
        *columns: Equal-length sequences; row i is built from element i of
            every column.
        first_row: Header tuple written before the data rows.
        file_name: Name of the CSV file created inside config['model_dir'].
    """
    rows = zip(*columns)
    path = config['model_dir'] + '/' + file_name
    # newline='' is required by the csv module; without it Windows output
    # gets an extra blank line after every row.
    with open(path, "w", newline='') as f:
        writer = csv.writer(f)
        writer.writerow(first_row)
        writer.writerows(rows)
\ No newline at end of file
......@@ -34,6 +34,7 @@ eegnet: The EEGNet architecture
deepeye: The Deep Eye architecture
xception: The Xception architecture
deepeye-rnn: The DeepEye RNN architecture
simple_classifier: Tries the simple classifiers from sklearn.
If downsampled set to true, the downsampled data (which have length 125) will be used. Otherwise, full data with length 500 is used.
If split set to true, the data will be clustered and fed each to a separate architecture. The extracted features are concatenated
......@@ -42,7 +43,7 @@ Cluster can be set to clustering(), clustering2() or clustering3(), where differ
"""
# Choosing model
config['model'] = 'xception'
config['model'] = 'simple_classifier'
config['downsampled'] = False
config['split'] = False
config['cluster'] = clustering()
......@@ -83,11 +84,12 @@ config['eegnet']['samples'] = 125 if config['downsampled'] else 500
# Create a unique output directory for this experiment.
timestamp = str(int(time.time()))
model_folder_name = timestamp if config['model'] == '' else timestamp + "_" + config['model']
if config['split']:
if config['split'] and config['model'] != 'simple_classifier':
model_folder_name += '_cluster'
if config['downsampled']:
if config['downsampled'] and config['model'] != 'simple_classifier':
model_folder_name += '_downsampled'
if config['ensemble']>1:
if config['ensemble'] > 1 and config['model'] != 'simple_classifier':
model_folder_name += '_ensemble'
config['model_dir'] = os.path.abspath(os.path.join(config['log_dir'], model_folder_name))
......@@ -97,7 +99,6 @@ if not os.path.exists(config['model_dir']):
config['info_log'] = config['model_dir'] + '/' + 'info.log'
config['batches_log'] = config['model_dir'] + '/' + 'batches.log'
# Before we were loading directly the data from the server, or used the code to merge them.
# Deprecated now. We don't really use the functions anymore in the IOHelper that make use of the following configurations.
......
......@@ -30,13 +30,17 @@ def run(trainX, trainY):
for i in range(config['ensemble']):
print('beginning model number {}/{} ...'.format(i,config['ensemble']))
if config['model'] == 'deepeye':
classifier = Classifier_DEEPEYE(input_shape=config['deepeye']['input_shape'])
elif config['model'] == 'cnn':
classifier = Classifier_CNN(input_shape=config['cnn']['input_shape'], kernel_size=64, epochs = 50,
nb_filters=16, verbose=True, batch_size=64, use_residual=True, depth=12)
elif config['model'] == 'pyramidal_cnn':
classifier = Classifier_PyramidalCNN(input_shape=config['cnn']['input_shape'], epochs=50)
elif config['model'] == 'eegnet':
classifier = Classifier_EEGNet(dropoutRate = 0.5, kernLength = 64, F1 = 32,
D = 8, F2 = 512, norm_rate = 0.5, dropoutType = 'Dropout',
......@@ -49,6 +53,7 @@ def run(trainX, trainY):
kernel_size=40, nb_filters=64, depth=18, epochs=50)
elif config['model'] == 'deepeye-rnn':
classifier = Classifier_DEEPEYE_RNN(input_shape=config['deepeye-rnn']['input_shape'])
else:
logging.info('Cannot start the program. Please choose one model in the config.py file')
......
......@@ -68,7 +68,8 @@ def build_model(hp):
elif config['model'] == 'deepeye-rnn':
classifier = Classifier_DEEPEYE_RNN(input_shape=config['deepeye-rnn']['input_shape'])
else:
logging.info('Cannot start the program. Please choose one model in the config.py file')
logging.info('Cannot start the program. The model chosen in the config.py file '
'either does not exist or it is (still) not supported for the tuner.')
return classifier.get_model()
......
from scipy.io import savemat
from config import config
from ensemble import run
from SimpleClassifiers.sklearnclassifier import try_sklearn_classifiers, cross_validate_kNN, cross_validate_RFC, \
cross_validate_SVC
import numpy as np
import scipy
from utils.utils import select_best_model, comparison_plot_loss, comparison_plot_accuracy
......@@ -13,25 +17,49 @@ from kerasTuner import tune
from utils import IOHelper
import sys
class Tee(object):
    """Duplicate every write to several file-like objects (like Unix `tee`)."""

    def __init__(self, *files):
        # Remember the target streams; each write fans out to all of them.
        self.files = files

    def write(self, obj):
        # Forward the payload to every stream, flushing immediately so
        # progress output is visible as it happens.
        for stream in self.files:
            stream.write(obj)
            stream.flush()

    def flush(self):
        for stream in self.files:
            stream.flush()
def main():
    """Entry point: set up logging/stdout capture, load data, run the model."""
    # Log into the experiment's info.log inside the model directory.
    logging.basicConfig(filename=config['info_log'], level=logging.INFO)
    logging.info('Started the Logging')

    # Mirror stdout into console.out so verbose progress printed by
    # third-party libraries is captured alongside the terminal output.
    f = open(config['model_dir'] + '/console.out', 'w')
    original = sys.stdout
    sys.stdout = Tee(sys.stdout, f)
    try:
        # Single timer for the whole run (was previously assigned twice).
        start_time = time.time()
        trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
        print(trainX.shape)
        print(trainY.shape)

        if config['model'] == 'simple_classifier':
            # Alternatives: try_sklearn_classifiers, cross_validate_SVC,
            # cross_validate_RFC — swap in as needed.
            cross_validate_kNN(trainX, trainY)
        else:
            # tune(trainX, trainY) runs the hyper-parameter tuner instead.
            run(trainX, trainY)

        logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
        logging.info('Finished Logging')
    finally:
        # Restore stdout and close the console log; the handle was
        # previously leaked and stdout never restored.
        sys.stdout = original
        f.close()
# Script entry point; guard prevents execution on import.
# (The scraped diff showed main() twice — it must run exactly once.)
if __name__ == '__main__':
    main()
\ No newline at end of file
......@@ -22,7 +22,8 @@ def get_mat_data(data_dir, verbose=True):
if verbose: logging.info("Setting the shapes")
X = np.transpose(X, (2, 1, 0))
y = np.transpose(y, (1, 0))
if config['downsampled']: X = np.transpose(X, (0, 2, 1))
if config['model'] == 'eegnet' or config['model'] == 'eegnet_cluster':
X = np.transpose(X, (0, 2, 1))
if verbose:
logging.info(X.shape)
logging.info(y.shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment