Commit ed5bf9a1 authored by zpgeng's avatar zpgeng
Browse files

Added save log and model function, did not test yet @Oriel

parent ab6cbb46
......@@ -37,11 +37,14 @@ def run(trainX, trainY):
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# train
train(trainloader=trainloader, net=net, optimizer=optimizer, criterion=criterion)
hist = train(trainloader=trainloader, net=net, optimizer=optimizer, criterion=criterion)
# save our trained model
PATH = '../cifar_net.pth'
torch.save(net.state_dict(), PATH)
# PATH = '../cifar_net.pth'
# torch.save(net.state_dict(), PATH)
# Newly added lines below
save_logs(hist, pytorch=True)
save_model_param(pytorch=True)
def train(trainloader, net, optimizer, criterion, epoch=50):
for epoch in range(2): # loop over the dataset multiple times
......
......@@ -17,6 +17,9 @@ def run(trainX,trainY):
hist = classifier.fit(deepeye_x=trainX, y=trainY)
plot_loss(hist, 'DeepEye', True)
plot_acc(hist, 'DeepEye', True)
# Newly added lines below
save_logs(hist, pytorch=False)
save_model_param(pytorch=False)
class Classifier_DEEPEYE:
"""
......
......@@ -17,6 +17,9 @@ def run(trainX, trainY):
hist = classifier.fit(trainX, trainY)
plot_loss(hist, 'EEGNet', True)
plot_acc(hist, 'EEGNet', True)
# Newly added lines below
save_logs(hist, pytorch=False)
save_model_param(pytorch=False)
class Classifier_EEGNet:
def __init__(self, output_directory, nb_classes=1, chans = 129, samples = 500, dropoutRate = 0.5, kernLength = 64, F1 = 8,
......
......@@ -10,6 +10,9 @@ def run(trainX, trainY):
hist = classifier.fit(trainX, trainY)
plot_loss(hist, 'Inception', True)
plot_acc(hist, 'Inception', True)
# Newly added lines below
save_logs(hist, pytorch=False)
save_model_param(pytorch=False)
class Classifier_INCEPTION:
def __init__(self, output_directory, input_shape, verbose=False, build=True, batch_size=64, nb_filters=32,
......
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
sns.set_style('darkgrid')
def plot_acc(hist,name,val=False):
......@@ -35,3 +36,37 @@ def plot_loss(hist,name,val=False):
plt.savefig('loss_'+name+'.png')
plt.show()
# Save the logs (newly added without debugging)
# haven't write about the pytorch=True part
def save_logs(hist, output_directory=config['model_dir'], pytorch=False):
if pytorch:
# Maybe Oriel can work on this
else:
hist_df = pd.DataFrame(hist.history)
hist_df.to_csv(output_directory + config['model'] + '_history.csv', index=False)
df_metrics = {'Accuracy': hist_df['accuracy'], 'Loss': hist_df['loss']}
df_metrics = pd.DataFrame(df_metrics)
df_metrics.to_csv(output_directory + config['model'] + '_df_metrics.csv', index=False)
index_best_model = hist_df['loss'].idxmin()
row_best_model = hist_df.loc[index_best_model]
df_best_model = pd.DataFrame(data=np.zeros((1, 4), dtype=np.float), index=[0],
columns=['best_model_train_loss', 'best_model_val_loss', 'best_model_train_acc', 'best_model_val_acc')
df_best_model['best_model_train_loss'] = row_best_model['loss']
df_best_model['best_model_val_loss'] = row_best_model['val_loss']
df_best_model['best_model_train_acc'] = row_best_model['acc']
df_best_model['best_model_val_acc'] = row_best_model['val_acc']
df_best_model.to_csv(output_directory + config['model'] + '_df_best_model.csv', index=False)
# Save the model parameters (newly added without debugging)
def save_model_param(output_directory=config['model_dir'], pytorch=False):
if pytorch:
torch.save(net.state_dict(), output_directory + config['model'] + '_model.pth')
else:
classifier.save(output_directory + config['model'] + '_model.h5')
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment