To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

Commit 8aa74153 authored by Martyna Plomecka
Browse files

Cross validated Random Forest

parent f090705f
......@@ -39,12 +39,12 @@ def cross_validate_SVC(X, y):
def cross_validate_RFC(X, y):
logging.info("Cross-validation RFC...")
classifier = RandomForestClassifier(max_features='auto', random_state=42, n_jobs=-1)
parameters_RFC = {'n_estimators': [10, 100, 1000, 5000], 'max_depth': [5, 10, 50, 100, 500]}
parameters_RFC = {'n_estimators': [10, 50, 100, 1000], 'max_depth': [5, 10, 50, 100, 500], 'min_samples_split' : [0.1, 0.4, 0.7, 1.0], 'min_samples_leaf' : [0.1, 0.5]}
cross_validate(classifier=classifier, parameters=parameters_RFC, X=X, y=y)
def cross_validate(classifier, parameters, X, y):
X = X.reshape((100, 500 * 129))
clf = GridSearchCV(classifier, parameters, scoring='accuracy', n_jobs=-1, verbose=3)
X = X.reshape((36223, 500 * 129))
clf = GridSearchCV(classifier, parameters, scoring='accuracy', n_jobs=-1, verbose=3, cv=2)
clf.fit(X, y.ravel())
export_dict(clf.cv_results_['mean_fit_time'], clf.cv_results_['std_fit_time'], clf.cv_results_['mean_score_time'],
......@@ -58,20 +58,21 @@ def cross_validate(classifier, parameters, X, y):
def try_sklearn_classifiers(X, y):
logging.info("Training the simple classifiers: kNN, Linear SVM, Random Forest and Naive Bayes.")
names = ["Nearest Neighbors",
"Linear SVM",
names = [# "Nearest Neighbors",
# "Linear SVM",
"Random Forest",
"Naive Bayes",
"Linear SVM"
]
classifiers = [
KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, n_jobs=-1),
LinearSVC(tol=1e-5, C=1, random_state=42, max_iter=10000),
RandomForestClassifier(n_estimators=100, max_depth=5, max_features='auto', random_state=42, n_jobs=-1),
GaussianNB()
# KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, n_jobs=-1),
RandomForestClassifier(n_estimators=30, max_depth=20, max_features='auto', random_state=42, n_jobs=-1),
GaussianNB(),
LinearSVC(tol=1e-4, C=5, random_state=42, max_iter=500)
]
X = X.reshape((100, 500 * 129))
X = X.reshape((36223, 500 * 129))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2, random_state=42)
scores = []
......@@ -98,4 +99,4 @@ def export_dict(*columns, first_row, file_name):
writer = csv.writer(f)
writer.writerow(first_row)
for row in rows:
writer.writerow(row)
\ No newline at end of file
writer.writerow(row)
This diff is collapsed.
......@@ -49,9 +49,9 @@ def main():
if config['model'] == 'simple_classifier':
# try_sklearn_classifiers(trainX, trainY)
cross_validate_kNN(trainX, trainY)
# cross_validate_kNN(trainX, trainY)
# cross_validate_SVC(trainX, trainY)
# cross_validate_RFC(trainX, trainY)
cross_validate_RFC(trainX, trainY)
else:
# tune(trainX,trainY)
run(trainX,trainY)
......@@ -62,4 +62,4 @@ def main():
logging.info('Finished Logging')
if __name__=='__main__':
main()
\ No newline at end of file
main()
......@@ -3,7 +3,6 @@ matplotlib.use('Agg')
import pandas as pd
from config import config
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import pandas as pd
......@@ -13,7 +12,6 @@ from subprocess import call
import operator
import shutil
sns.set_style('darkgrid')
import logging
def plot_acc(hist, output_directory, model, val=False):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment