CNN.py 4.15 KB
Newer Older
1
2
import tensorflow as tf
import tensorflow.keras as keras
Ard Kastrati's avatar
Ard Kastrati committed
3
from config import config
4
from utils.utils import *
Ard Kastrati's avatar
Ard Kastrati committed
5
import logging
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from keras.callbacks import CSVLogger


def run(trainX, trainY):
    """Train the CNN classifier and write out its training curves and logs.

    A ``Classifier_CNN`` is created with the input shape found under
    ``config['cnn']['input_shape']`` and fitted on ``trainX`` / ``trainY``. The
    history returned by ``fit`` is then passed to ``plot_loss``, ``plot_acc`` and
    ``save_logs`` together with ``config['model_dir']`` and ``config['model']``.
    """
    logging.info("Starting CNN.")

    cnn_input_shape = config['cnn']['input_shape']
    classifier = Classifier_CNN(input_shape=cnn_input_shape)
    history = classifier.fit(trainX, trainY)

    # Both plotting helpers take the same arguments, including the trailing True flag.
    plot_loss(history, config['model_dir'], config['model'], True)
    plot_acc(history, config['model_dir'], config['model'], True)

    save_logs(history, config['model_dir'], config['model'], pytorch=False)
    # save_model_param(classifier.model, config['model_dir'], config['model'], pytorch=False)


class Classifier_CNN:
    """Binary classifier built from stacked Conv1D blocks using tf.keras.

    Depending on ``config['split']`` the network is either a single convolutional
    stack over the whole input (``_build_model``) or one stack per entry of
    ``config['cluster']`` whose pooled outputs are merged by dense layers
    (``split_model``).
    """

    def __init__(self, input_shape, verbose=True, build=True, batch_size=64, nb_filters=32,
                 use_residual=True, depth=6, kernel_size=40, nb_epochs=1500):
        """Store the hyper-parameters and optionally build the Keras model.

        Args:
            input_shape: Shape handed to ``tf.keras.layers.Input``.
            verbose: When true, ``model.summary()`` is printed after building.
            build: When false, only the attributes below are set and no
                ``self.model`` attribute is created.
            batch_size, nb_filters, use_residual, depth, kernel_size, nb_epochs:
                Stored on the instance unchanged.

        NOTE(review): no method of this class reads ``batch_size``, ``nb_filters``,
        ``use_residual``, ``depth``, ``kernel_size``, ``nb_epochs``,
        ``bottleneck_size`` or ``callbacks``; the layer and training settings come
        from the defaults/literals in ``_CNN_module``, ``_build_model`` and ``fit``.
        """
        self.nb_filters = nb_filters
        self.use_residual = use_residual
        self.depth = depth
        self.kernel_size = kernel_size
        self.callbacks = None
        self.batch_size = batch_size
        self.bottleneck_size = 32
        self.nb_epochs = nb_epochs
        self.verbose = verbose

        if build:
            if config['split']:
                self.model = self.split_model(input_shape)
            else:
                self.model = self._build_model(input_shape)
            if self.verbose:
                self.model.summary()

    def split_model(self, input_shape):
        """Build and compile a model with one convolutional branch per cluster.

        For every entry of ``config['cluster']`` the listed indices are selected
        along the last axis of the input, the selection is run through
        ``_build_model`` (which returns the pooled branch output in split mode),
        and all branch outputs are concatenated and fed through ``Dense(32, relu)``
        and a final ``Dense(1, sigmoid)``.

        Args:
            input_shape: Shape handed to ``tf.keras.layers.Input``; only its first
                element is reused for the per-branch shape.

        Returns:
            A ``tf.keras`` model compiled with binary cross-entropy, Adam and the
            accuracy metric.
        """
        input_layer = tf.keras.layers.Input(input_shape)
        branch_outputs = []

        # Run one CNN stack per cluster (presumably groups of channel indices).
        for cluster_indices in config['cluster'].values():
            branch_shape = (input_shape[0], len(cluster_indices))

            # transpose -> lookup on the leading axis -> transpose back selects
            # the cluster's indices along the last axis of the input tensor.
            branch_input = tf.transpose(
                tf.nn.embedding_lookup(tf.transpose(input_layer), cluster_indices))
            branch_outputs.append(self._build_model(branch_shape, X=branch_input))

        # Merge the branch outputs, then one hidden dense layer and the output unit.
        x = tf.keras.layers.Concatenate()(branch_outputs)
        dense = tf.keras.layers.Dense(32, activation='relu')(x)
        output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(dense)
        model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
        model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(), metrics=['accuracy'])
        return model

    def _CNN_module(self, input_tensor, nb_filters=128, activation='linear', kernel_size=128):
        """Apply one Conv1D -> BatchNormalization -> ReLU block to ``input_tensor``.

        Args:
            input_tensor: Tensor the convolution is applied to.
            nb_filters: Number of convolution filters.
            activation: Activation passed to the Conv1D layer itself; the ReLU is
                applied separately after batch normalization.
            kernel_size: Width of the convolution kernel. Defaults to 128, the
                value that used to be hard-coded here.

        Returns:
            The output tensor of the ReLU activation.
        """
        # The convolution has no bias term; it is followed directly by BatchNormalization.
        x = tf.keras.layers.Conv1D(filters=nb_filters, kernel_size=kernel_size, padding='same',
                                   activation=activation, use_bias=False)(input_tensor)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation(activation='relu')(x)
        return x

    def _build_model(self, input_shape, X=None, depth=6):
        """Stack ``depth`` CNN modules and global-average-pool the result.

        Args:
            input_shape: Shape for a new ``Input`` layer; only used when
                ``config['split']`` is false.
            X: Input tensor of the branch; required when ``config['split']`` is
                true, ignored otherwise.
            depth: Number of ``_CNN_module`` blocks to stack. This default is
                independent of ``self.depth``.

        Returns:
            In split mode, the ``GlobalAveragePooling1D`` output tensor of the
            branch. Otherwise a ``tf.keras`` model with a single sigmoid output,
            compiled with binary cross-entropy, Adam and the accuracy metric.

        Raises:
            ValueError: If ``config['split']`` is true and no ``X`` was given.
        """
        split = config['split']
        if split:
            if X is None:
                raise ValueError("_build_model requires the branch input tensor X when config['split'] is set")
            input_layer = X
        else:
            input_layer = tf.keras.layers.Input(input_shape)

        x = input_layer
        for _ in range(depth):
            x = self._CNN_module(x)

        gap_layer = tf.keras.layers.GlobalAveragePooling1D()(x)
        if split:
            # The caller (split_model) adds the dense head and compiles the model.
            return gap_layer

        output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(gap_layer)
        model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
        model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(), metrics=['accuracy'])
        return model

    def fit(self, CNN_x, y):
        """Train ``self.model`` and return the object produced by ``model.fit``.

        Training holds out 20% of the data for validation, runs for at most 35
        epochs and stops early once ``val_accuracy`` has not improved for 20
        epochs. The model with the best ``val_accuracy`` is saved to
        ``<config['model_dir']>/<config['model']>_best_model.h5`` and the training
        log is appended to the CSV file named by ``config['batches_log']``
        (``;``-separated).

        Args:
            CNN_x: Training inputs passed to ``model.fit``.
            y: Training targets passed to ``model.fit``.
        """
        # All callbacks come from tf.keras, the same Keras implementation the
        # model was built with; mixing in standalone-keras callbacks can fail.
        csv_logger = tf.keras.callbacks.CSVLogger(config['batches_log'], append=True, separator=';')
        early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=20)
        ckpt_path = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
        ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_path, verbose=1, monitor='val_accuracy',
                                                  save_best_only=True, mode='auto')
        hist = self.model.fit(CNN_x, y, verbose=1, validation_split=0.2, epochs=35,
                              callbacks=[csv_logger, ckpt, early_stop])
        return hist