# utils.py: helpers for plotting training curves and saving training logs / model weights.
import matplotlib.pyplot as plt
2
3
import seaborn as sns
import numpy as np
4
import torch
5
import pandas as pd
# Apply seaborn's "darkgrid" theme globally so every figure drawn afterwards uses it.
sns.set_style(style='darkgrid')

def plot_acc(hist,name,val=False):
    '''
    Plot training (and optionally validation) accuracy against the epoch index.

    Parameters
    ----------
    hist : object with a ``history`` dict attribute
        ``hist.history`` maps metric names to per-epoch lists. The training
        curve is read from ``'accuracy'`` (or ``'acc'`` if that key is
        absent); the validation curve from ``'val_accuracy'`` (or
        ``'val_acc'``).
    name : str
        Model name, used in the plot title and in the output file name.
    val : bool
        If True, also plot the validation accuracy curve.

    Side effects: saves the figure to ``accuracy_<name>.png`` (relative to
    the current working directory), shows it, then prints a separator line.
    '''
    history = hist.history
    # Accept both metric spellings ('accuracy' and the shorter 'acc') so
    # histories logged either way can be plotted.
    train_key = 'accuracy' if 'accuracy' in history else 'acc'
    val_key = 'val_accuracy' if 'val_accuracy' in history else 'val_acc'
    epochs = np.arange(len(history[train_key]))

    plt.figure()
    plt.title(name + ' accuracy')
    plt.plot(epochs, history[train_key], 'b-', label='training')
    if val:
        plt.plot(epochs, history[val_key], 'g-', label='validation')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel('Accuracy')
    plt.savefig('accuracy_' + name + '.png')
    plt.show()
    print(10 * '*' + '\n')

def plot_loss(hist,name,val=False):
    '''
    Plot training (and optionally validation) loss against the epoch index.

    Parameters
    ----------
    hist : object with a ``history`` dict attribute
        ``hist.history['loss']`` (and ``hist.history['val_loss']`` when
        ``val`` is True) must hold one value per epoch.
    name : str
        Model name, used in the plot title and in the output file name.
    val : bool
        If True, also plot the validation loss curve.

    Side effects: saves the figure to ``loss_<name>.png`` (relative to the
    current working directory) and shows it.
    '''
    train_loss = hist.history['loss']
    # Size the x-axis from the loss series itself: the history need not
    # contain an 'accuracy' entry (e.g. when no accuracy metric was tracked).
    epochs = np.arange(len(train_loss))

    plt.figure()
    plt.title(name + ' loss')
    plt.plot(epochs, train_loss, 'b-', label='training')
    if val:
        plt.plot(epochs, hist.history['val_loss'], 'g-', label='validation')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel('Binary Cross Entropy')
    plt.savefig('loss_' + name + '.png')
    plt.show()
def plot_loss_torch(loss,name='CNN'):
    '''
    Plot a sequence of per-epoch training losses, save it and show it.

    Parameters
    ----------
    loss : sized sequence of numbers
        One loss value per epoch; plotted against 0 .. len(loss) - 1.
    name : str
        Model name used in the title and in the output file name
        ``loss_<name>.png`` (relative to the current working directory).
    '''
    fig, ax = plt.subplots()
    ax.set_title(name + ' loss')
    x_values = np.arange(len(loss))
    ax.plot(x_values, loss, 'b-', label='training')
    ax.legend()
    ax.set_xlabel('epochs')
    ax.set_ylabel('Binary Cross Entropy')
    fig.savefig('loss_' + name + '.png')
    plt.show()


def save_logs(hist, output_directory=None, pytorch=False, model_name=None):
    '''
    Write training logs to CSV files named ``<output_directory><model_name>_*.csv``.

    Parameters
    ----------
    hist :
        If ``pytorch`` is True: a sequence of per-epoch loss values.
        Otherwise: an object whose ``history`` attribute is a dict of
        per-epoch metric lists containing at least ``'loss'`` and an accuracy
        entry (``'accuracy'`` or ``'acc'``); ``'val_loss'`` and
        ``'val_accuracy'`` / ``'val_acc'`` are optional.
    output_directory : str, optional
        Prefix the file names are appended to (plain string concatenation,
        so it should end with a path separator). Defaults to
        ``config['model_dir']``, looked up at call time.
    pytorch : bool
        Selects which of the two ``hist`` formats above is expected.
    model_name : str, optional
        Name inserted into the file names. Defaults to ``config['model']``,
        looked up at call time.

    Files written
    -------------
    pytorch=True  : ``*_df_metrics.csv`` with a single ``Loss`` column.
    pytorch=False : ``*_history.csv`` (the full history),
                    ``*_df_metrics.csv`` (``Accuracy``, ``Loss``) and
                    ``*_df_best_model.csv`` (train/val loss and accuracy at
                    the epoch with the lowest training loss; validation
                    values that were not logged are written as NaN).
    '''
    # Defaults are resolved here rather than in the signature, where
    # ``config[...]`` would be evaluated at import time.
    # NOTE(review): ``config`` is not imported in the visible part of this
    # module -- TODO confirm where it is defined.
    if output_directory is None:
        output_directory = config['model_dir']
    if model_name is None:
        model_name = config['model']
    prefix = output_directory + model_name

    if pytorch:
        df_metrics = pd.DataFrame({'Loss': hist})
        df_metrics.to_csv(prefix + '_df_metrics.csv', index=False)
        return

    hist_df = pd.DataFrame(hist.history)
    hist_df.to_csv(prefix + '_history.csv', index=False)

    # Accept both metric spellings ('accuracy' and the shorter 'acc').
    acc_key = 'accuracy' if 'accuracy' in hist_df.columns else 'acc'
    val_acc_key = 'val_accuracy' if 'val_accuracy' in hist_df.columns else 'val_acc'

    df_metrics = pd.DataFrame({'Accuracy': hist_df[acc_key], 'Loss': hist_df['loss']})
    df_metrics.to_csv(prefix + '_df_metrics.csv', index=False)

    # "Best" epoch = the one with the lowest training loss.
    row_best_model = hist_df.loc[hist_df['loss'].idxmin()]

    df_best_model = pd.DataFrame(
        {
            'best_model_train_loss': [row_best_model['loss']],
            # Validation metrics are absent when training ran without
            # validation data; record NaN instead of failing.
            'best_model_val_loss': [row_best_model.get('val_loss', np.nan)],
            'best_model_train_acc': [row_best_model[acc_key]],
            'best_model_val_acc': [row_best_model.get(val_acc_key, np.nan)],
        },
        index=[0],
        dtype=np.float64,
    )
    df_best_model.to_csv(prefix + '_df_best_model.csv', index=False)

def save_model_param(output_directory=None, pytorch=False, model=None, model_name=None):
    '''
    Save a trained model to ``<output_directory><model_name>_model.<ext>``.

    Parameters
    ----------
    output_directory : str, optional
        Prefix the file name is appended to (plain string concatenation, so
        it should end with a path separator). Defaults to
        ``config['model_dir']``, looked up at call time.
    pytorch : bool
        If True, ``model.state_dict()`` is written with ``torch.save`` to
        ``*_model.pth``; otherwise ``model.save(...)`` is called with a
        ``*_model.h5`` path.
    model : object, optional
        The model to save. When omitted, falls back to the module-level
        global ``net`` (pytorch=True) or ``classifier`` (pytorch=False) for
        callers that set those globals.
        NOTE(review): neither global is defined in the visible part of this
        module -- prefer passing ``model`` explicitly.
    model_name : str, optional
        Name inserted into the file name. Defaults to ``config['model']``,
        looked up at call time.
    '''
    # Defaults are resolved here rather than in the signature, where
    # ``config[...]`` would be evaluated at import time.
    if output_directory is None:
        output_directory = config['model_dir']
    if model_name is None:
        model_name = config['model']
    prefix = output_directory + model_name

    if pytorch:
        if model is None:
            model = net
        torch.save(model.state_dict(), prefix + '_model.pth')
    else:
        if model is None:
            model = classifier
        model.save(prefix + '_model.h5')