To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

Commit 4c570f13 authored by Ard Kastrati's avatar Ard Kastrati
Browse files

small changes

parent f090705f
...@@ -39,12 +39,12 @@ def cross_validate_SVC(X, y): ...@@ -39,12 +39,12 @@ def cross_validate_SVC(X, y):
def cross_validate_RFC(X, y): def cross_validate_RFC(X, y):
logging.info("Cross-validation RFC...") logging.info("Cross-validation RFC...")
classifier = RandomForestClassifier(max_features='auto', random_state=42, n_jobs=-1) classifier = RandomForestClassifier(max_features='auto', random_state=42, n_jobs=-1)
parameters_RFC = {'n_estimators': [10, 100, 1000, 5000], 'max_depth': [5, 10, 50, 100, 500]} parameters_RFC = {'n_estimators': [10, 50, 100, 1000], 'max_depth': [5, 10, 50, 100, 500], 'min_samples_split' : [0.1, 0.4, 0.7, 1.0], 'min_samples_leaf': [0.1, 0.5]}
cross_validate(classifier=classifier, parameters=parameters_RFC, X=X, y=y) cross_validate(classifier=classifier, parameters=parameters_RFC, X=X, y=y)
def cross_validate(classifier, parameters, X, y): def cross_validate(classifier, parameters, X, y):
X = X.reshape((100, 500 * 129)) X = X.reshape((100, 500 * 129))
clf = GridSearchCV(classifier, parameters, scoring='accuracy', n_jobs=-1, verbose=3) clf = GridSearchCV(classifier, parameters, scoring='accuracy', n_jobs=-1, verbose=3, cv=2)
clf.fit(X, y.ravel()) clf.fit(X, y.ravel())
export_dict(clf.cv_results_['mean_fit_time'], clf.cv_results_['std_fit_time'], clf.cv_results_['mean_score_time'], export_dict(clf.cv_results_['mean_fit_time'], clf.cv_results_['std_fit_time'], clf.cv_results_['mean_score_time'],
...@@ -58,16 +58,16 @@ def cross_validate(classifier, parameters, X, y): ...@@ -58,16 +58,16 @@ def cross_validate(classifier, parameters, X, y):
def try_sklearn_classifiers(X, y): def try_sklearn_classifiers(X, y):
logging.info("Training the simple classifiers: kNN, Linear SVM, Random Forest and Naive Bayes.") logging.info("Training the simple classifiers: kNN, Linear SVM, Random Forest and Naive Bayes.")
names = ["Nearest Neighbors", names = [# "Nearest Neighbors",
"Linear SVM", #"Linear SVM",
"Random Forest", "Random Forest",
"Naive Bayes", "Naive Bayes",
] ]
classifiers = [ classifiers = [
KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, n_jobs=-1), # KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, n_jobs=-1),
LinearSVC(tol=1e-5, C=1, random_state=42, max_iter=10000), LinearSVC(tol=1e-5, C=1, random_state=42, max_iter=1000, verbose=1),
RandomForestClassifier(n_estimators=100, max_depth=5, max_features='auto', random_state=42, n_jobs=-1), RandomForestClassifier(n_estimators=100, max_depth=5, max_features='auto', random_state=42, n_jobs=-1, verbose=1),
GaussianNB() GaussianNB()
] ]
......
...@@ -49,9 +49,9 @@ def main(): ...@@ -49,9 +49,9 @@ def main():
if config['model'] == 'simple_classifier': if config['model'] == 'simple_classifier':
# try_sklearn_classifiers(trainX, trainY) # try_sklearn_classifiers(trainX, trainY)
cross_validate_kNN(trainX, trainY) # cross_validate_kNN(trainX, trainY)
# cross_validate_SVC(trainX, trainY) # cross_validate_SVC(trainX, trainY)
# cross_validate_RFC(trainX, trainY) cross_validate_RFC(trainX, trainY)
else: else:
# tune(trainX,trainY) # tune(trainX,trainY)
run(trainX,trainY) run(trainX,trainY)
...@@ -62,4 +62,4 @@ def main(): ...@@ -62,4 +62,4 @@ def main():
logging.info('Finished Logging') logging.info('Finished Logging')
if __name__=='__main__': if __name__=='__main__':
main() main()
\ No newline at end of file
...@@ -3,7 +3,6 @@ matplotlib.use('Agg') ...@@ -3,7 +3,6 @@ matplotlib.use('Agg')
import pandas as pd import pandas as pd
from config import config from config import config
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np import numpy as np
import torch import torch
import pandas as pd import pandas as pd
...@@ -13,7 +12,6 @@ from subprocess import call ...@@ -13,7 +12,6 @@ from subprocess import call
import operator import operator
import shutil import shutil
sns.set_style('darkgrid')
import logging import logging
def plot_acc(hist, output_directory, model, val=False): def plot_acc(hist, output_directory, model, val=False):
...@@ -77,7 +75,7 @@ def comparison_plot_accuracy(): ...@@ -77,7 +75,7 @@ def comparison_plot_accuracy():
plt.grid(True) plt.grid(True)
plt.xlabel('epochs') plt.xlabel('epochs')
plt.ylabel('accuracy (%)') plt.ylabel('accuracy (%)')
for experiment in os.listdir(run_dir): for experiment in os.listdir(run_dir):
name = experiment name = experiment
print(name) print(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment