CNN.py 1.32 KB
Newer Older
1
import tensorflow as tf
Ard Kastrati's avatar
Ard Kastrati committed
2
from config import config
3
from utils.utils import *
Ard Kastrati's avatar
Ard Kastrati committed
4
import logging
5
6
from ConvNet import ConvNet
from tensorflow.keras.constraints import max_norm
7
8

def run(trainX, trainY):
    """
    Train a Classifier_CNN and persist its training artefacts.

    Builds the classifier from config['cnn']['input_shape'] and fits it on
    (trainX, trainY). It then plots the loss and accuracy curves and saves the
    training logs, all under config['model_dir'] for model config['model'].

    Args:
        trainX: training inputs, passed straight to Classifier_CNN.fit.
        trainY: training targets, passed straight to Classifier_CNN.fit.
    """
    logging.info("Starting CNN.")
    cnn = Classifier_CNN(input_shape=config['cnn']['input_shape'])
    history = cnn.fit(trainX, trainY)

    # Output location and model name are read only after training, as before.
    out_dir, model_name = config['model_dir'], config['model']
    for plotter in (plot_loss, plot_acc):
        plotter(history, out_dir, model_name, True)
    save_logs(history, out_dir, model_name, pytorch=False)
Ard Kastrati's avatar
Ard Kastrati committed
18

19
20
21
22
class Classifier_CNN(ConvNet):
    """
    One of the simplest classifiers: a plain convolutional network.

    Subclasses ConvNet and only supplies `_module`, the building block that
    ConvNet (described as being made of modules with a specific depth) is
    assembled from.
    """

    def _module(self, input_tensor, current_depth):
        """
        Build one CNN module: Conv1D (no bias) -> BatchNormalization -> ReLU -> MaxPool1D.

        Args:
            input_tensor: Keras tensor fed into this module.
            current_depth: depth index of this module; not used by this
                implementation (presumably required by ConvNet's hook
                signature — confirm against ConvNet).

        Returns:
            The tensor produced by the max-pooling layer.
        """
        layers = tf.keras.layers
        # self.nb_filters / self.kernel_size are presumably set by ConvNet — confirm.
        # Bias is disabled, presumably because the BatchNormalization that
        # follows applies its own learned offset.
        conv_out = layers.Conv1D(
            filters=self.nb_filters,
            kernel_size=self.kernel_size,
            padding='same',
            use_bias=False,
        )(input_tensor)
        normalized = layers.BatchNormalization()(conv_out)
        activated = layers.Activation(activation='relu')(normalized)
        pooled = layers.MaxPool1D()(activated)
        return pooled