Commit 127a5b95 authored by Ard Kastrati's avatar Ard Kastrati
Browse files

Added loss plot

parent 74c226cc
...@@ -2,7 +2,7 @@ from config import config ...@@ -2,7 +2,7 @@ from config import config
from ensemble import run from ensemble import run
import numpy as np import numpy as np
import scipy import scipy
from utils.utils import select_best_model, comparison_plot from utils.utils import select_best_model, comparison_plot, comparison_plot_loss
from utils import IOHelper from utils import IOHelper
from scipy import io from scipy import io
import h5py import h5py
...@@ -18,17 +18,17 @@ def main(): ...@@ -18,17 +18,17 @@ def main():
logging.info('Started the Logging') logging.info('Started the Logging')
start_time = time.time() start_time = time.time()
# try: # try:
trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True) # trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
if config['model'] == 'eegnet' or config['model'] == 'eegnet_cluster': # if config['model'] == 'eegnet' or config['model'] == 'eegnet_cluster':
trainX = np.transpose(trainX, (0, 2, 1)) # trainX = np.transpose(trainX, (0, 2, 1))
logging.info(trainX.shape) # logging.info(trainX.shape)
# tune(trainX,trainY) # tune(trainX,trainY)
run(trainX,trainY) # run(trainX,trainY)
# select_best_model() # select_best_model()
# comparison_plot() comparison_plot_loss()
logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time)) logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
logging.info('Finished Logging') logging.info('Finished Logging')
......
...@@ -81,14 +81,36 @@ def comparison_plot(): ...@@ -81,14 +81,36 @@ def comparison_plot():
for experiment in os.listdir(run_dir): for experiment in os.listdir(run_dir):
name = experiment name = experiment
print(name) print(name)
summary = pd.read_csv(run_dir+experiment+'/'+name+'_history.csv') if(name != 'eegnet'):
acc = 100 * summary['val_accuracy'] summary = pd.read_csv(run_dir+experiment+'/'+name+'_history.csv')
plt.plot(acc, '-' , label=name) acc = 100 * summary['val_accuracy']
plt.plot(acc, '-' , label=name)
plt.legend() plt.legend()
plt.savefig(run_dir+'/comparison_accuracy.png') plt.savefig(run_dir+'/comparison_accuracy.png')
def comparison_plot_loss():
run_dir = './results/OHBM/'
print(run_dir)
plt.figure()
plt.title('Comparison of the validation loss')
plt.grid(True)
plt.xlabel('epochs')
plt.ylabel('loss (%)')
for experiment in os.listdir(run_dir):
name = experiment
print(name)
if (name != 'eegnet'):
summary = pd.read_csv(run_dir + experiment + '/' + name + '_history.csv')
acc = 100 * summary['val_loss']
plt.plot(acc, '-', label=name)
plt.legend()
plt.savefig(run_dir + '/comparison_accuracy_loss.png')
def select_best_model(): def select_best_model():
results = {} results = {}
model = {} model = {}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment