# utils.py: plotting and training-log helpers.
import matplotlib.pyplot as plt
2
3
import seaborn as sns
import numpy as np
4
import torch
5
import pandas as pd
6
import os
# Set seaborn's 'darkgrid' style globally at import time. This updates
# matplotlib's rcParams, so it affects every figure drawn afterwards in the
# process, including the plots produced by this module.
sns.set_style('darkgrid')


def plot_acc(hist, name, val=False):
    '''
    Plot training (and optionally validation) accuracy against the epoch index.

    The figure is saved as ``accuracy_<name>.png`` in the current working
    directory, displayed with ``plt.show()`` and then closed. A separator line
    of asterisks is printed afterwards.

    Args:
        hist: object with a ``history`` dict (presumably a Keras ``History``)
            holding an ``'accuracy'`` list and, when ``val`` is true, a
            ``'val_accuracy'`` list of the same length.
        name: label used in the plot title and in the output file name.
        val: if true, also plot ``hist.history['val_accuracy']``.
    '''
    history = hist.history
    epochs = np.arange(len(history['accuracy']))  # 0-based epoch index

    fig = plt.figure()
    plt.title(name + ' accuracy')
    plt.plot(epochs, history['accuracy'], 'b-', label='training')
    if val:
        plt.plot(epochs, history['val_accuracy'], 'g-', label='validation')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel('Accuracy')
    plt.savefig('accuracy_' + name + '.png')
    plt.show()
    # pyplot keeps every figure alive until it is closed; without this,
    # repeated calls under a non-interactive backend accumulate open figures.
    plt.close(fig)
    print(10 * '*' + '\n')


def plot_loss(hist, name, val=False, ylabel='Binary Cross Entropy'):
    '''
    Plot training (and optionally validation) loss against the epoch index.

    The figure is saved as ``loss_<name>.png`` in the current working
    directory, displayed with ``plt.show()`` and then closed.

    Args:
        hist: object with a ``history`` dict (presumably a Keras ``History``)
            holding a ``'loss'`` list and, when ``val`` is true, a
            ``'val_loss'`` list of the same length.
        name: label used in the plot title and in the output file name.
        val: if true, also plot ``hist.history['val_loss']``.
        ylabel: y-axis label; defaults to the loss name used so far.
    '''
    history = hist.history
    # Size the x axis from the loss series being plotted, so histories that
    # do not record an 'accuracy' metric can still be plotted.
    epochs = np.arange(len(history['loss']))

    fig = plt.figure()
    plt.title(name + ' loss')
    plt.plot(epochs, history['loss'], 'b-', label='training')
    if val:
        plt.plot(epochs, history['val_loss'], 'g-', label='validation')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel(ylabel)
    plt.savefig('loss_' + name + '.png')
    plt.show()
    # pyplot keeps every figure alive until it is closed; release it here.
    plt.close(fig)


def plot_loss_torch(loss, name='CNN', ylabel='Binary Cross Entropy'):
    '''
    Plot a per-epoch training-loss sequence (e.g. collected in a PyTorch loop).

    The figure is saved as ``loss_<name>.png`` in the current working
    directory, displayed with ``plt.show()`` and then closed.

    Args:
        loss: sequence of loss values, one per epoch.
        name: label used in the plot title and in the output file name.
        ylabel: y-axis label; defaults to the loss name used so far.
    '''
    epochs = np.arange(len(loss))  # 0-based epoch index

    fig = plt.figure()
    plt.title(name + ' loss')
    plt.plot(epochs, loss, 'b-', label='training')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel(ylabel)
    plt.savefig('loss_' + name + '.png')
    plt.show()
    # pyplot keeps every figure alive until it is closed; release it here.
    plt.close(fig)


def save_logs(hist, output_directory, model, pytorch=False):
    '''
    Write training metrics to CSV files in ``output_directory``.

    PyTorch mode (``pytorch=True``): ``hist`` is a sequence of per-epoch loss
    values. It is written as a single ``Loss`` column to
    ``<model>_df_metrics.csv``.

    Otherwise ``hist`` must expose a ``history`` dict (presumably a Keras
    ``History``) containing at least ``'loss'`` and ``'accuracy'``. Three
    files are written:

    * ``<model>_history.csv``: every recorded metric, one row per epoch.
    * ``<model>_df_metrics.csv``: ``Accuracy`` and ``Loss`` columns.
    * ``<model>_df_best_model.csv``: one row with the train/val loss and
      accuracy of the epoch with the lowest *training* loss. The validation
      values are written as NaN when the history has no ``val_loss`` /
      ``val_accuracy`` entries.

    Args:
        hist: list of losses (PyTorch mode) or object with a ``history`` dict.
        output_directory: directory the CSV files are written to. It must
            already exist; it is not created here.
        model: model name used as the file-name prefix.
        pytorch: selects the PyTorch-style input described above.
    '''
    def _csv_path(suffix):
        # os.path.join avoids writing to the filesystem root when
        # output_directory is '' and a doubled separator when it ends in '/'.
        return os.path.join(output_directory, model + '_' + suffix)

    if pytorch:
        pd.DataFrame({'Loss': hist}).to_csv(_csv_path('df_metrics.csv'), index=False)
        return

    hist_df = pd.DataFrame(hist.history)
    hist_df.to_csv(_csv_path('history.csv'), index=False)

    df_metrics = pd.DataFrame({'Accuracy': hist_df['accuracy'], 'Loss': hist_df['loss']})
    df_metrics.to_csv(_csv_path('df_metrics.csv'), index=False)

    # "Best" epoch = lowest training loss (first occurrence on ties).
    row_best_model = hist_df.loc[hist_df['loss'].idxmin()]
    # Validation metrics are absent when training ran without validation
    # data; record NaN for them instead of raising KeyError.
    df_best_model = pd.DataFrame(
        {
            'best_model_train_loss': [row_best_model['loss']],
            'best_model_val_loss': [row_best_model.get('val_loss', np.nan)],
            'best_model_train_acc': [row_best_model['accuracy']],
            'best_model_val_acc': [row_best_model.get('val_accuracy', np.nan)],
        },
        index=[0],
    )
    df_best_model.to_csv(_csv_path('df_best_model.csv'), index=False)


# Save the model parameters (draft, added but not yet debugged).
# NOTE: as written, the draft references `net` and `classifier`, which are
# neither parameters nor defined inside the function. The trained model object
# has to be passed in before this can be enabled.
# def save_model_param(output_directory, model, pytorch=False):
#    if pytorch:
#        torch.save(net.state_dict(), output_directory + '/' + model + '_' + 'model.pth')
#    else:
#        classifier.save(output_directory + '/' + model + '_' + 'model.h5')