utils.py 6.15 KB
Newer Older
zpgeng's avatar
zpgeng committed
1
2
import matplotlib
matplotlib.use('Agg')
okiss's avatar
okiss committed
3
4
import pandas as pd
from config import config
5
import matplotlib.pyplot as plt
6
import numpy as np
7
import torch
8
import pandas as pd
okiss's avatar
okiss committed
9

10
import os
okiss's avatar
okiss committed
11
12
13
14
from subprocess import call
import operator
import shutil

Ard Kastrati's avatar
Ard Kastrati committed
15
import logging
16

okiss's avatar
okiss committed
17
def plot_acc(hist, output_directory, model, val=False):
    '''
    Plot the accuracy against the epochs during training.

    Args:
        hist: object whose ``history`` dict holds a per-epoch list under
            ``'accuracy'`` and, when ``val`` is True, under ``'val_accuracy'``.
        output_directory: directory the PNG is written to.
        model: model name, used in the title and as the file-name prefix.
        val: if True, also plot the validation accuracy curve.

    Writes ``<output_directory>/<model>_accuracy.png``.
    '''
    epochs = np.arange(len(hist.history['accuracy']))

    plt.figure()
    plt.title(model + ' accuracy')
    plt.plot(epochs, hist.history['accuracy'], 'b-', label='training')
    if val:
        plt.plot(epochs, hist.history['val_accuracy'], 'g-', label='validation')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel('Accuracy')
    plt.savefig(os.path.join(output_directory, model + '_accuracy.png'))
    # pyplot keeps every figure alive until it is closed; without this each
    # call would leak a figure.
    plt.close()

    logging.info(10*'*'+'\n')


def plot_loss(hist, output_directory, model, val=False):
    '''
    Plot the loss against the epochs during training.

    Args:
        hist: object whose ``history`` dict holds a per-epoch list under
            ``'loss'`` and, when ``val`` is True, under ``'val_loss'``.
        output_directory: directory the PNG is written to.
        model: model name, used in the title and as the file-name prefix.
        val: if True, also plot the validation loss curve.

    Writes ``<output_directory>/<model>_loss.png``.
    '''
    # Size the x axis from the loss history itself (previously it was taken
    # from 'accuracy', which raised KeyError for runs without that metric).
    epochs = np.arange(len(hist.history['loss']))

    plt.figure()
    plt.title(model + ' loss')
    plt.plot(epochs, hist.history['loss'], 'b-', label='training')
    if val:
        plt.plot(epochs, hist.history['val_loss'], 'g-', label='validation')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel('Binary Cross Entropy')
    plt.savefig(os.path.join(output_directory, model + '_loss.png'))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()


def plot_loss_torch(loss, output_directory, model):
    '''
    Plot a training-loss curve given directly as a sequence of per-epoch values.

    Args:
        loss: sequence of per-epoch training-loss values.
        output_directory: directory the PNG is written to.
        model: model name, used in the title and as the file-name prefix.

    Writes ``<output_directory>/<model>loss.png``.
    '''
    epochs = np.arange(len(loss))

    plt.figure()
    plt.title(model + ' loss')
    plt.plot(epochs, loss, 'b-', label='training')
    plt.legend()
    plt.xlabel('epochs')
    plt.ylabel('Binary Cross Entropy')
    # NOTE(review): unlike plot_loss there is no '_' between the model name and
    # 'loss.png'; kept as-is so existing output paths do not change. Confirm intent.
    plt.savefig(os.path.join(output_directory, model + 'loss.png'))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()


def cp_dir(source, target):
    '''
    Copy ``source`` to ``target`` by running ``cp -a source target``.

    Returns:
        The exit status of ``cp`` (0 on success). A non-zero status is logged
        as a warning instead of being raised, so callers that ignore the result
        keep the previous best-effort behaviour.
    '''
    status = call(['cp', '-a', source, target])
    if status != 0:
        logging.warning('cp -a %s %s failed with exit status %d', source, target, status)
    return status


def comparison_plot_accuracy(run_dir='./results/ETRA/', exclude=('eegnet',)):
    '''
    Overlay the validation-accuracy curves of all experiments in ``run_dir``.

    Each experiment is a sub-directory ``<run_dir>/<name>/`` containing
    ``<name>_history.csv`` with a ``val_accuracy`` column; that column is
    plotted in percent. The figure is saved as
    ``<run_dir>/comparison_accuracy.png``.

    Args:
        run_dir: directory holding one sub-directory per experiment.
        exclude: collection of experiment names to leave out of the plot.
    '''
    print(run_dir)
    plt.figure()
    plt.title('Comparison of the validation accuracy')
    plt.grid(True)
    plt.xlabel('epochs')
    plt.ylabel('accuracy (%)')

    # Sorted for a stable legend order.
    for name in sorted(os.listdir(run_dir)):
        print(name)
        experiment_dir = os.path.join(run_dir, name)
        # Plain files (e.g. the comparison PNGs written into run_dir by a
        # previous call) are not experiments and have no history to read.
        if name in exclude or not os.path.isdir(experiment_dir):
            continue
        history_file = os.path.join(experiment_dir, name + '_history.csv')
        if not os.path.isfile(history_file):
            logging.warning('Skipping %s: %s not found', name, history_file)
            continue
        summary = pd.read_csv(history_file)
        plt.plot(100 * summary['val_accuracy'], '-', label=name)

    plt.legend()
    plt.savefig(os.path.join(run_dir, 'comparison_accuracy.png'))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()


def comparison_plot_loss(run_dir='./results/ETRA/', exclude=('eegnet',)):
    '''
    Overlay the validation-loss curves of all experiments in ``run_dir``.

    Each experiment is a sub-directory ``<run_dir>/<name>/`` containing
    ``<name>_history.csv`` with a ``val_loss`` column. The figure is saved as
    ``<run_dir>/comparison_loss.png``.

    Args:
        run_dir: directory holding one sub-directory per experiment.
        exclude: collection of experiment names to leave out of the plot.
    '''
    print(run_dir)
    plt.figure()
    plt.title('Comparison of the validation loss')
    plt.grid(True)
    plt.xlabel('epochs')
    plt.ylabel('loss')

    # Sorted for a stable legend order.
    for name in sorted(os.listdir(run_dir)):
        print(name)
        experiment_dir = os.path.join(run_dir, name)
        # Plain files (e.g. the comparison PNGs written into run_dir by a
        # previous call) are not experiments and have no history to read.
        if name in exclude or not os.path.isdir(experiment_dir):
            continue
        history_file = os.path.join(experiment_dir, name + '_history.csv')
        if not os.path.isfile(history_file):
            logging.warning('Skipping %s: %s not found', name, history_file)
            continue
        summary = pd.read_csv(history_file)
        plt.plot(summary['val_loss'], '-', label=name)

    plt.legend()
    plt.savefig(os.path.join(run_dir, 'comparison_loss.png'))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()


def _read_best_val_accuracy(csv_path):
    '''
    Return the validation accuracy stored in a ``*_df_best_model.csv`` file.

    Accepts a ``val_accuracy`` column or the ``best_model_val_acc`` column
    that ``save_logs`` writes. Returns None if the file is missing or empty,
    or if it has neither column.
    '''
    try:
        summary = pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        return None
    for column in ('val_accuracy', 'best_model_val_acc'):
        if column in summary.columns and not summary.empty:
            return float(summary[column].iloc[0])
    logging.warning('No validation accuracy column found in %s', csv_path)
    return None


def select_best_model(run_dir=None, results_dir=None):
    '''
    Copy the best run of every model into ``<results_dir>/best_<name>``.

    Every sub-directory of ``run_dir`` named ``<number>_<name>`` is a run of
    model ``<name>``. Its score is the validation accuracy in
    ``<name>_df_best_model.csv``. For each model, the highest-scoring run is
    copied to ``<results_dir>/best_<name>``:
      * if no copy exists yet, it is created;
      * if one exists and scores lower, it is replaced;
      * if the existing copy scores at least as high, or has no readable
        score, it is left untouched.

    Args:
        run_dir: directory containing the runs; defaults to ``config['log_dir']``.
        results_dir: destination directory; defaults to ``<cwd>/results``.
    '''
    if run_dir is None:
        run_dir = config['log_dir']
    if results_dir is None:
        results_dir = os.path.join(os.getcwd(), 'results')

    best_acc = {}
    best_source = {}
    # Get the best run for every model name. Sorted so ties resolve
    # deterministically (the first run in sorted order wins).
    for experiment in sorted(os.listdir(run_dir)):
        experiment_dir = os.path.join(run_dir, experiment)
        # Files and names without the '<number>_<name>' form are not runs.
        if not os.path.isdir(experiment_dir) or '_' not in experiment:
            continue
        _, name = experiment.split('_', 1)
        acc = _read_best_val_accuracy(os.path.join(experiment_dir, name + '_df_best_model.csv'))
        if acc is None:
            continue
        if name not in best_acc or acc > best_acc[name]:
            best_acc[name] = acc
            best_source[name] = experiment_dir

    # Update the best model in the results folder with the ones in the runs.
    for name, source_dir in best_source.items():
        target_dir = os.path.join(results_dir, 'best_' + name)
        if os.path.isdir(target_dir):
            current_acc = _read_best_val_accuracy(os.path.join(target_dir, name + '_df_best_model.csv'))
            if current_acc is None or current_acc >= best_acc[name]:
                # Keep the existing copy: at least as good, or cannot be scored
                # (never delete data we cannot evaluate).
                continue
            shutil.rmtree(target_dir)
        elif os.path.exists(target_dir):
            # A non-directory with this name is left alone.
            continue
        shutil.copytree(source_dir, target_dir, symlinks=True)


# Save the logs
def save_logs(hist, output_directory, model, pytorch=False):
    '''
    Write the training history, and for Keras-style runs a best-epoch summary, to CSV.

    Files written into ``output_directory``:
      * ``<model>_history.csv``: one row per epoch, built from ``hist`` when
        ``pytorch`` is True, otherwise from ``hist.history``.
      * ``<model>_df_best_model.csv`` (non-pytorch only): the train/val loss
        and accuracy of the epoch with the highest ``val_accuracy``.

    This is best effort: any failure is logged with its traceback and the
    function returns None without raising.
    '''
    try:
        hist_df = pd.DataFrame(hist if pytorch else hist.history)
        hist_df.to_csv(os.path.join(output_directory, model + '_' + 'history.csv'), index=False)
        if pytorch:
            return

        row_best_model = hist_df.loc[hist_df['val_accuracy'].idxmax()]
        # Built from plain floats: the previous np.zeros(..., dtype=np.float)
        # raises AttributeError on NumPy >= 1.24, which the old bare `except`
        # swallowed, so this file was silently never written.
        df_best_model = pd.DataFrame({
            'best_model_train_loss': [row_best_model['loss']],
            'best_model_val_loss': [row_best_model['val_loss']],
            'best_model_train_acc': [row_best_model['accuracy']],
            'best_model_val_acc': [row_best_model['val_accuracy']],
        })
        df_best_model.to_csv(os.path.join(output_directory, model + '_' + 'df_best_model.csv'), index=False)
    except Exception:
        logging.exception('save_logs: could not save logs for %s in %s', model, output_directory)


# Save the model parameters (newly added; not yet debugged)
# def save_model_param(classifier, output_directory, model, pytorch=False):
#     try:
#         if pytorch:
#             torch.save(classifier.state_dict(), output_directory + '/' + model + '_' + 'model.pth')
#         else:
#             classifier.save(output_directory + '/' + model + '_' + 'model.h5')
#     except:
#         return