Commit 008292db authored by Ard Kastrati's avatar Ard Kastrati
Browse files

Plotted the results

parent 0e0631e9
...@@ -2,7 +2,7 @@ from config import config ...@@ -2,7 +2,7 @@ from config import config
from ensemble import run from ensemble import run
import numpy as np import numpy as np
import scipy import scipy
from utils.utils import select_best_model, comparison_plot_loss from utils.utils import select_best_model, comparison_plot_loss, comparison_plot_accuracy
from utils import IOHelper from utils import IOHelper
from scipy import io from scipy import io
import h5py import h5py
...@@ -18,7 +18,7 @@ def main(): ...@@ -18,7 +18,7 @@ def main():
logging.info('Started the Logging') logging.info('Started the Logging')
start_time = time.time() start_time = time.time()
# try: # try:
trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True) # trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
# if config['model'] == 'eegnet' or config['model'] == 'eegnet_cluster': # if config['model'] == 'eegnet' or config['model'] == 'eegnet_cluster':
# trainX = np.transpose(trainX, (0, 2, 1)) # trainX = np.transpose(trainX, (0, 2, 1))
...@@ -26,9 +26,9 @@ def main(): ...@@ -26,9 +26,9 @@ def main():
# tune(trainX,trainY) # tune(trainX,trainY)
run(trainX,trainY) # run(trainX,trainY)
# select_best_model() # select_best_model()
# comparison_plot_loss() comparison_plot_accuracy()
logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time)) logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
logging.info('Finished Logging') logging.info('Finished Logging')
......
...@@ -70,7 +70,7 @@ def cp_dir(source, target): ...@@ -70,7 +70,7 @@ def cp_dir(source, target):
def comparison_plot_accuracy(): def comparison_plot_accuracy():
run_dir = './results/OHBM/' run_dir = './results/ETRA/'
print(run_dir) print(run_dir)
plt.figure() plt.figure()
plt.title('Comparison of the validation accuracy' ) plt.title('Comparison of the validation accuracy' )
...@@ -91,7 +91,7 @@ def comparison_plot_accuracy(): ...@@ -91,7 +91,7 @@ def comparison_plot_accuracy():
def comparison_plot_loss(): def comparison_plot_loss():
run_dir = './results/OHBM/' run_dir = './results/ETRA/'
print(run_dir) print(run_dir)
plt.figure() plt.figure()
plt.title('Comparison of the validation loss') plt.title('Comparison of the validation loss')
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment