To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

Commit 4c570f13 authored by Ard Kastrati's avatar Ard Kastrati
Browse files

small changes

parent f090705f
......@@ -39,12 +39,12 @@ def cross_validate_SVC(X, y):
def cross_validate_RFC(X, y):
logging.info("Cross-validation RFC...")
classifier = RandomForestClassifier(max_features='auto', random_state=42, n_jobs=-1)
parameters_RFC = {'n_estimators': [10, 100, 1000, 5000], 'max_depth': [5, 10, 50, 100, 500]}
parameters_RFC = {'n_estimators': [10, 50, 100, 1000], 'max_depth': [5, 10, 50, 100, 500], 'min_samples_split' : [0.1, 0.4, 0.7, 1.0], 'min_samples_leaf': [0.1, 0.5]}
cross_validate(classifier=classifier, parameters=parameters_RFC, X=X, y=y)
def cross_validate(classifier, parameters, X, y):
X = X.reshape((100, 500 * 129))
clf = GridSearchCV(classifier, parameters, scoring='accuracy', n_jobs=-1, verbose=3)
clf = GridSearchCV(classifier, parameters, scoring='accuracy', n_jobs=-1, verbose=3, cv=2)
clf.fit(X, y.ravel())
export_dict(clf.cv_results_['mean_fit_time'], clf.cv_results_['std_fit_time'], clf.cv_results_['mean_score_time'],
......@@ -58,16 +58,16 @@ def cross_validate(classifier, parameters, X, y):
def try_sklearn_classifiers(X, y):
logging.info("Training the simple classifiers: kNN, Linear SVM, Random Forest and Naive Bayes.")
names = ["Nearest Neighbors",
"Linear SVM",
names = [# "Nearest Neighbors",
#"Linear SVM",
"Random Forest",
"Naive Bayes",
]
classifiers = [
KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, n_jobs=-1),
LinearSVC(tol=1e-5, C=1, random_state=42, max_iter=10000),
RandomForestClassifier(n_estimators=100, max_depth=5, max_features='auto', random_state=42, n_jobs=-1),
# KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, n_jobs=-1),
LinearSVC(tol=1e-5, C=1, random_state=42, max_iter=1000, verbose=1),
RandomForestClassifier(n_estimators=100, max_depth=5, max_features='auto', random_state=42, n_jobs=-1, verbose=1),
GaussianNB()
]
......
......@@ -49,9 +49,9 @@ def main():
if config['model'] == 'simple_classifier':
# try_sklearn_classifiers(trainX, trainY)
cross_validate_kNN(trainX, trainY)
# cross_validate_kNN(trainX, trainY)
# cross_validate_SVC(trainX, trainY)
# cross_validate_RFC(trainX, trainY)
cross_validate_RFC(trainX, trainY)
else:
# tune(trainX,trainY)
run(trainX,trainY)
......@@ -62,4 +62,4 @@ def main():
logging.info('Finished Logging')
if __name__=='__main__':
main()
\ No newline at end of file
main()
......@@ -3,7 +3,6 @@ matplotlib.use('Agg')
import pandas as pd
from config import config
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import pandas as pd
......@@ -13,7 +12,6 @@ from subprocess import call
import operator
import shutil
sns.set_style('darkgrid')
import logging
def plot_acc(hist, output_directory, model, val=False):
......@@ -77,7 +75,7 @@ def comparison_plot_accuracy():
plt.grid(True)
plt.xlabel('epochs')
plt.ylabel('accuracy (%)')
for experiment in os.listdir(run_dir):
name = experiment
print(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment