To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

Commit 0e2ae3ef authored by Feliks Kiszkurno's avatar Feliks Kiszkurno
Browse files

Fixed hyperparameter optimization

parent c82c0f98
......@@ -13,6 +13,10 @@ import slopestabilityML.run_classification
import numpy as np
from sklearn.ensemble import AdaBoostClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.neural_network import MLPClassifier
#from sklearn.tree import DecisionTreeClassifier
......@@ -25,7 +29,8 @@ def adaboost_run(test_results, random_seed):
if settings.settings['optimize_ml'] is True:
hyperparameters = {'base_estimator': ['SVM', 'GBC', 'KNN'],
hyperparameters = {'base_estimator': [GradientBoostingClassifier(), KNeighborsClassifier(),
MLPClassifier(), SGDClassifier()],
'n_estimators': list(np.arange(10, 90, 10)),
'learning_rate': list(np.arange(0.2, 1.5, 0.1))}
......
......@@ -6,7 +6,7 @@ Created on 19.01.2021
@author: Feliks Kiszkurno
"""
from sklearn import ensemble
from sklearn.ensemble import GradientBoostingClassifier
import numpy as np
......@@ -28,13 +28,13 @@ def gbc_run(test_results, random_seed):
'learning_rate': list(np.arange(0.2, 1.5, 0.1)),
'n_estimators': list(np.arange(70, 160, 10))}
clf_base = ensemble.GradientBoostingClassifier()
clf_base = GradientBoostingClassifier()
clf = slopestabilityML.select_search_type(clf_base, hyperparameters)
else:
# Create classifier
clf = ensemble.GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, max_depth=1, random_state=0)
clf = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, max_depth=1, random_state=0)
# Train classifier
results, result_class = \
......
......@@ -22,7 +22,7 @@ def knn_run(test_results, random_seed):
if settings.settings['optimize_ml'] is True:
hyperparameters = {'neighbours': list(np.arange(2, 9, 1)),
hyperparameters = {'n_neighbors': list(np.arange(2, 9, 1)),
'weights': ['uniform', 'distance'],
'p': [1, 2]}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment