utils.py 1.05 KB
Newer Older
1
import matplotlib.pyplot as plt
2
3
4
5
6
7
8
9
10
11
import seaborn as sns
import numpy as np
# Select seaborn's 'darkgrid' style once, at import time, so it is already
# set before any of the plotting helpers in this module are called.
sns.set_style('darkgrid')

def plot_acc(hist,name,val=False):
    '''
    Plot the accuracy against the epochs during training.

    A new figure is created, saved to ``'accuracy_' + name + '.png'``
    (relative to the current working directory) and then shown with
    ``plt.show()``. A line of ten asterisks is printed afterwards.

    Parameters
    ----------
    hist :
        Object exposing a ``history`` mapping. ``history['accuracy']`` must
        hold one value per epoch; ``history['val_accuracy']`` is also read
        when ``val`` is true. Presumably a Keras ``History`` object --
        confirm against the callers.
    name : str
        Label used both in the figure title and in the output file name.
    val : bool, optional
        When true, the validation accuracy curve is drawn as well.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If a required key is missing from ``hist.history``.
    '''
    # Read the training series once; its length defines the x axis
    # (epochs are numbered from 0).
    train_acc = hist.history['accuracy']
    epochs = np.arange(len(train_acc))

    plt.figure()
    plt.title(name + ' accuracy')
    plt.plot(epochs, train_acc, 'b-', label='training')
    if val:
        plt.plot(epochs, hist.history['val_accuracy'], 'g-', label='validation')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel('Accuracy')
    # Save before show(): the original order is kept so the file is written
    # regardless of what happens to the figure once it is displayed.
    plt.savefig('accuracy_' + name + '.png')
    plt.show()
    print(10*'*'+'\n')

def plot_loss(hist,name,val=False):
    '''
    Plot the loss against the epochs during training.

    A new figure is created, saved to ``'loss_' + name + '.png'`` (relative
    to the current working directory) and then shown with ``plt.show()``.

    Parameters
    ----------
    hist :
        Object exposing a ``history`` mapping. ``history['loss']`` must hold
        one value per epoch; ``history['val_loss']`` is also read when
        ``val`` is true. Presumably a Keras ``History`` object -- confirm
        against the callers.
    name : str
        Label used both in the figure title and in the output file name.
    val : bool, optional
        When true, the validation loss curve is drawn as well.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If a required key is missing from ``hist.history``.
    '''
    # The x axis is sized from the loss series itself, so the function does
    # not depend on an 'accuracy' entry being present in the history.
    train_loss = hist.history['loss']
    epochs = np.arange(len(train_loss))

    plt.figure()
    plt.title(name + ' loss')
    plt.plot(epochs, train_loss, 'b-', label='training')
    if val:
        plt.plot(epochs, hist.history['val_loss'], 'g-', label='validation')
    plt.legend()
    plt.xlabel('epochs')
    # NOTE(review): the label is hard-coded; it is only accurate if the model
    # was trained with a binary cross-entropy loss -- confirm with callers.
    plt.ylabel('Binary Cross Entropie')
    # Save before show(): the original order is kept so the file is written
    # regardless of what happens to the figure once it is displayed.
    plt.savefig('loss_' + name + '.png')
    plt.show()
37