Commit 127a5b95 authored by Ard Kastrati's avatar Ard Kastrati
Browse files

Added loss plot

parent 74c226cc
......@@ -2,7 +2,7 @@ from config import config
from ensemble import run
import numpy as np
import scipy
from utils.utils import select_best_model, comparison_plot
from utils.utils import select_best_model, comparison_plot, comparison_plot_loss
from utils import IOHelper
from scipy import io
import h5py
......@@ -18,17 +18,17 @@ def main():
logging.info('Started the Logging')
start_time = time.time()
# try:
trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
# trainX, trainY = IOHelper.get_mat_data(config['data_dir'], verbose=True)
if config['model'] == 'eegnet' or config['model'] == 'eegnet_cluster':
trainX = np.transpose(trainX, (0, 2, 1))
logging.info(trainX.shape)
# if config['model'] == 'eegnet' or config['model'] == 'eegnet_cluster':
# trainX = np.transpose(trainX, (0, 2, 1))
# logging.info(trainX.shape)
# tune(trainX,trainY)
run(trainX,trainY)
# run(trainX,trainY)
# select_best_model()
# comparison_plot()
comparison_plot_loss()
logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
logging.info('Finished Logging')
......
......@@ -81,14 +81,36 @@ def comparison_plot():
for experiment in os.listdir(run_dir):
name = experiment
print(name)
summary = pd.read_csv(run_dir+experiment+'/'+name+'_history.csv')
acc = 100 * summary['val_accuracy']
plt.plot(acc, '-' , label=name)
if(name != 'eegnet'):
summary = pd.read_csv(run_dir+experiment+'/'+name+'_history.csv')
acc = 100 * summary['val_accuracy']
plt.plot(acc, '-' , label=name)
plt.legend()
plt.savefig(run_dir+'/comparison_accuracy.png')
def comparison_plot_loss():
run_dir = './results/OHBM/'
print(run_dir)
plt.figure()
plt.title('Comparison of the validation loss')
plt.grid(True)
plt.xlabel('epochs')
plt.ylabel('loss (%)')
for experiment in os.listdir(run_dir):
name = experiment
print(name)
if (name != 'eegnet'):
summary = pd.read_csv(run_dir + experiment + '/' + name + '_history.csv')
acc = 100 * summary['val_loss']
plt.plot(acc, '-', label=name)
plt.legend()
plt.savefig(run_dir + '/comparison_accuracy_loss.png')
def select_best_model():
results = {}
model = {}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment