# Ensemble_tf.py
import tensorflow as tf
from config import config
from utils.utils import *
from utils.losses import angle_loss 
import logging

from tf_models.CNN.CNN import CNN
from tf_models.PyramidalCNN.PyramidalCNN import PyramidalCNN
from tf_models.DeepEye.deepeye import DEEPEYE
from tf_models.DeepEyeRNN.deepeyeRNN import DEEPEYE_RNN
from tf_models.Xception.Xception import XCEPTION
from tf_models.InceptionTime.Inception import INCEPTION
from tf_models.EEGNet.eegNet import EEGNet

class Ensemble_tf:
    """
    The Ensemble is a model itself, which contains a number of models that are averaged on prediction.
    Default: nb_models of the model_type model
    Optional: Initialize with a model list, create a more versatile Ensemble model
              (building from a model list is not implemented yet, the ensemble stays empty in that case)
    """

    def __init__(self, nb_models=1, model_type='cnn', model_list=None):
        """
        nb_models: Number of models to run in the ensemble
        model_type: Default model to run
        model_list: optional, give a list of models that should be contained in the Ensemble
                    (None or an empty list selects the default nb_models x model_type ensemble)
        """
        self.nb_models = nb_models
        self.model_type = model_type
        # A fresh list per instance: a `[]` default would be shared by every Ensemble_tf object.
        self.model_list = [] if model_list is None else model_list
        self.models = []
        self._build_ensemble_model(self.model_list)

    def __str__(self):
        return self.__class__.__name__

    def load_models(self, path_to_models):
        #TODO: implement
        # load all models into the model_list to predict with them
        pass

    def _build_ensemble_model(self, model_list):
        """
        Fill self.models with the members of the ensemble
        Default: create a list of self.nb_models many models of type self.model_type
        Optional: create a list of the models described in self.model_list (not implemented yet)
        """
        if model_list:
            #TODO: implement loading a list of different models
            logging.info("Built ensemble model with model(s): {}".format(self.model_list))
            return
        for i in range(self.nb_models):
            self.models.append(create_model(self.model_type, model_number=i))
        logging.info("Built ensemble model with {} {} model(s)".format(self.nb_models, self.model_type))

    @staticmethod
    def _binary_accuracy(predictions, targets):
        """
        Fraction of predictions that match the targets after rounding the predictions to 0/1.
        For binary values (p + t - 1)**2 is 1 when p == t and 0 otherwise, so the mean is the accuracy.
        """
        rounded = np.round(np.array(predictions).reshape(-1), 0)
        flat_targets = np.array(targets).reshape(-1)
        return np.mean((rounded + flat_targets - 1) ** 2)

    def run(self, X, y):
        """
        Fit all the models in the ensemble, then compute the loss (and, for 'prosaccade_clf',
        the accuracy) of the averaged per-epoch validation predictions, save the loss to
        config['model_dir'] and create the plots.

        X, y: training data, handed unchanged to the fit method of every model
        Raises ValueError if the ensemble contains no models.
        """
        if not self.models:
            raise ValueError("The ensemble contains no models to train.")

        logging.info("Started the training")
        # Metrics
        mse = tf.keras.losses.MeanSquaredError()
        bce = tf.keras.losses.BinaryCrossentropy()
        loss = []
        accuracy = []

        # Fit the models and sum up their per-epoch predictions on the validation set
        summed = None
        for i, model in enumerate(self.models):
            logging.info('Start training model number {}/{} ...'.format(i+1, self.nb_models))
            hist, pred_ensemble = model.fit(X, y)
            # Copy into arrays owned by this method so the prediction history of the
            # first model is not overwritten by the running sum.
            epoch_preds = [np.array(pred_epoch) for pred_epoch in pred_ensemble.predhis]
            if summed is None:
                summed = epoch_preds
            else:
                for j, pred_epoch in enumerate(epoch_preds):
                    summed[j] = summed[j] + pred_epoch
            logging.info('Finished training model number {}/{} ...'.format(i+1, self.nb_models))

        # Compute averaged metrics on the validation set.
        # hist and pred_ensemble (and therefore the targets) are those of the last fitted model.
        targets = pred_ensemble.targets
        nb_members = len(self.models)  # number of predictions that were actually summed
        task = config['task']
        for summed_epoch in summed:
            pred_epoch = (summed_epoch / nb_members).tolist()  # mean prediction of the ensemble
            if task == 'prosaccade_clf':
                loss.append(bce(targets, pred_epoch).numpy())
                # accuracy is computed on the rounded (0/1) predictions, the loss on the raw ones
                accuracy.append(self._binary_accuracy(pred_epoch, targets))
            elif task == 'angle_reg':
                loss.append(angle_loss(targets, pred_epoch).numpy())
            elif task == 'gaze-reg':
                loss.append(mse(targets, pred_epoch).numpy())

        # save the ensemble loss to the model directory
        loss_fname = config['model_dir'] + "/" + "ensemble_loss.txt"
        np.savetxt(loss_fname, loss, delimiter=',')

        logging.info("loss len {}".format(len(loss)))

        # Plot
        logging.info("Creating plots")
        if self.nb_models == 1:
            plot_loss(hist, config['model_dir'], config['model'], val = True)
        elif self.nb_models > 1:
            # NOTE: this renames the model in the shared config for everything that runs afterwards
            config['model'] += '_ensemble'
            # the history of the last model is reused as a container for the ensemble metrics
            hist.history['val_loss'] = np.array(loss)
            plot_loss(hist, config['model_dir'], config['model'], val = True)
            if task == 'prosaccade_clf':
                hist.history['val_accuracy'] = accuracy
                plot_acc(hist, config['model_dir'], config['model'], val = True)
                save_logs(hist, config['model_dir'], config['model'], pytorch = False)

        logging.info("Done with training and plotting.")
        return

    def predict(self, X):
        """
        Predict with all models on the dataset X
        Return the average prediction of all models in self.models
        Raises ValueError if the ensemble contains no models.
        """
        if not self.models:
            raise ValueError("The ensemble contains no models to predict with.")
        total = None
        for member in self.models:
            # Every member is read through the same `.model` attribute
            # (assumed to be what get_model() returns -- TODO confirm).
            member_pred = member.model.predict(X)
            total = member_pred if total is None else total + member_pred
        return total / len(self.models)

def create_model(model_type, model_number):
    """
    Returns the tensorflow model selected by model_type

    model_type: one of 'deepeye', 'cnn', 'pyramidal_cnn', 'eegnet', 'inception', 'xception', 'deepeye-rnn'
    model_number: index of the model inside the ensemble, passed on to the model constructor

    Raises ValueError if model_type is not one of the supported names.
    """
    if model_type == 'deepeye':
        model = DEEPEYE(input_shape=config['deepeye']['input_shape'], model_number=model_number, 
                                        epochs=config['epochs'])
    elif model_type == 'cnn':
        model = CNN(input_shape=config['cnn']['input_shape'], kernel_size=64, epochs = config['epochs'], 
                                        nb_filters=16, verbose=True, batch_size=64, use_residual=True, depth=12, 
                                        model_number=model_number)
    elif model_type == 'pyramidal_cnn':
        # NOTE(review): reuses the input shape configured for 'cnn' -- confirm this is intended
        model = PyramidalCNN(input_shape=config['cnn']['input_shape'], epochs=config['epochs'], 
                                        model_number=model_number)
    elif model_type == 'eegnet':
        model = EEGNet(dropoutRate = 0.5, kernLength = 64, F1 = 32, D = 8, F2 = 512, norm_rate = 0.5, 
                                        dropoutType = 'Dropout', model_number=model_number, epochs=config['epochs'])
    elif model_type == 'inception':
        model = INCEPTION(input_shape=config['inception']['input_shape'], use_residual=True, model_number=model_number,
                                        kernel_size=64, nb_filters=16, depth=12, bottleneck_size=16, epochs=config['epochs'])
    elif model_type == 'xception':
        # NOTE(review): reuses the input shape configured for 'inception' -- confirm this is intended
        model = XCEPTION(input_shape=config['inception']['input_shape'], use_residual=True, model_number=model_number,
                                        kernel_size=40, nb_filters=64, depth=18, epochs=config['epochs'])
    elif model_type == 'deepeye-rnn':
        # NOTE(review): the only constructor that is not given epochs -- confirm this is intended
        model = DEEPEYE_RNN(input_shape=config['deepeye-rnn']['input_shape'], model_number=model_number)
    else:
        # Without this branch an unknown name failed with UnboundLocalError at the return below.
        raise ValueError("Unknown model_type: {!r}".format(model_type))
    return model