# Ensemble_torch.py -- ensemble wrapper that averages the predictions of several models.
from config import config
import logging
import torch
import torch.nn as nn
# os is used by load_models to list the directory that holds the saved models
import os 
# Project-local model implementations

from torch_models.CNN.CNN import CNN

class Ensemble_torch:
    """
    The Ensemble is a model itself, which contains a number of models that are averaged on prediction.
    Default: nb_models of the model_type model
    Optional: Initialize with a model list, create a more versatile Ensemble model
    (building from a model list is not implemented yet, see _build_ensemble_model)
    """

    def __init__(self, nb_models=1, model_type='cnn', model_list=None):
        """
        nb_models: Number of models to run in the ensemble
        model_type: Default model to run
        model_list: optional, give a list of models that should be contained in the Ensemble;
                    None (the default) is stored as a new empty list
        """
        self.nb_models = nb_models
        self.model_type = model_type
        # Create the empty list per instance: a literal [] default would be one
        # shared object, so appending on one Ensemble would leak into all others.
        self.model_list = [] if model_list is None else model_list
        self.models = []
        # NOTE: the ensemble is not built here; self.models stays empty until
        # _build_ensemble_model (or a future load_models) fills it.

    def __str__(self):
        """Return the class name."""
        return self.__class__.__name__

    def load_models(self, path_to_models):
        """
        Placeholder for loading previously saved models from path_to_models.

        Currently this only iterates over the directory listing and does nothing
        with the entries; self.models, self.nb_models and self.model_list are
        left untouched. os.listdir raises if the path cannot be listed.
        """
        for model_name in os.listdir(path_to_models):
            # TODO: load models and save them in self.models
            # TODO: adapt nb_models
            # TODO: create model_list as string list for completeness
            pass

    def run(self, x, y):
        """
        Create one CNN from config['cnn']['input_shape'] and fit it on (x, y).

        NOTE(review): the intended behaviour is to fit all models of the ensemble
        and save them to the run directory, but at the moment a single CNN is
        trained, it is not stored in self.models and nothing is saved.
        """
        # Create model
        model = CNN(input_shape=config['cnn']['input_shape'])
        logging.info("Created model")

        model.fit(x, y)

        #TODO: create logs, create plots, etc.

    def _build_ensemble_model(self, model_list):
        """
        Create a list of compiled models
        Default: create a list of self.nb_models many models of type self.model_type
        Optional: create a list of the models described in self.model_list

        model_list: when non-empty, selects the (not yet implemented) custom path;
                    an empty list or None selects the default path
        """
        if model_list:
            #TODO: implement loading a list of different models
            # Nothing is appended to self.models on this path yet; only the log line is emitted.
            logging.info("Built ensemble model with model(s): {}".format(self.model_list))
        else:
            for i in range(self.nb_models):
                self.models.append(create_model(self.model_type, model_number=i))
            logging.info("Built ensemble model with {} {} model(s)".format(self.nb_models, self.model_type))

    def predict(self, X):
        """
        Predict with all models on the dataset X
        Return the average prediction of all models in self.models

        Each entry of self.models is expected to expose `.model.predict(X)`.
        Raises ValueError if self.models is empty.
        """
        if not self.models:
            raise ValueError("Cannot predict: the ensemble contains no models")

        total = None
        for member in self.models:
            member_pred = member.model.predict(X)
            # Out-of-place addition: an in-place `+=` would modify the object
            # returned by the first model, which that model may still reference.
            total = member_pred if total is None else total + member_pred
        return total / len(self.models)


def create_model(model_type, model_number):
    """
    Returns a torch model

    model_type: name of the architecture to build; only 'cnn' is available so far
    model_number: index of the model inside the ensemble, forwarded to the model constructor

    Raises ValueError for any other model_type (previously this fell through to
    `return model` with `model` unbound and failed with an UnboundLocalError).
    """
    if model_type != 'cnn':
        raise ValueError("Unsupported model_type '{}': only 'cnn' is implemented".format(model_type))

    # Architectures that are not available in this file yet, kept for reference:
    #
    # elif model_type == 'deepeye':
    #     model = DEEPEYE(input_shape=config['deepeye']['input_shape'], model_number=model_number,
    #                                     epochs=config['epochs'])
    # elif model_type == 'pyramidal_cnn':
    #     model = PyramidalCNN(input_shape=config['cnn']['input_shape'], epochs=config['epochs'],
    #                                     model_number=model_number)
    # elif model_type == 'eegnet':
    #     model = EEGNet(dropoutRate = 0.5, kernLength = 64, F1 = 32, D = 8, F2 = 512, norm_rate = 0.5,
    #                                     dropoutType = 'Dropout', model_number=model_number, epochs=config['epochs'])
    # elif model_type == 'inception':
    #     model = INCEPTION(input_shape=config['inception']['input_shape'], use_residual=True, model_number=model_number,
    #                                     kernel_size=64, nb_filters=16, depth=12, bottleneck_size=16, epochs=config['epochs'])
    # elif model_type == 'xception':
    #     model = XCEPTION(input_shape=config['inception']['input_shape'], use_residual=True, model_number=model_number,
    #                                     kernel_size=40, nb_filters=64, depth=18, epochs=config['epochs'])
    # elif model_type == 'deepeye-rnn':
    #     model = DEEPEYE_RNN(input_shape=config['deepeye-rnn']['input_shape'], model_number=model_number)

    return CNN(input_shape=config['cnn']['input_shape'], kernel_size=64, epochs=config['epochs'],
               nb_filters=16, verbose=True, batch_size=64, use_residual=True, depth=12,
               model_number=model_number)