Commit 651f477e authored by Ard Kastrati's avatar Ard Kastrati
Browse files

Prepared CNN and made sure it works

parent 8886ba21
......@@ -11,6 +11,12 @@ class Classifier_CNN(ConvNet):
The Classifier_CNN is one of the simplest classifiers. It implements the class ConvNet, which is made of modules with a specific depth.
"""
def __init__(self, input_shape, kernel_size=64, epochs = 50, nb_filters=16, verbose=True, batch_size=64, use_residual=True, depth=12):
super(Classifier_CNN, self).__init__(input_shape, kernel_size=kernel_size, epochs=epochs, nb_filters=nb_filters,
verbose=verbose, batch_size=batch_size, use_residual=use_residual,
depth=depth)
def _module(self, input_tensor, current_depth):
"""
The module of CNN is made of a simple convolution with batch normalization and ReLu activation. Finally, MaxPooling is also used.
......@@ -19,5 +25,5 @@ class Classifier_CNN(ConvNet):
x = tf.keras.layers.Conv1D(filters=self.nb_filters, kernel_size=self.kernel_size, padding='same', use_bias=False)(input_tensor)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation(activation='relu')(x)
x = tf.keras.layers.MaxPool1D()(x)
x = tf.keras.layers.MaxPool1D(pool_size=2, strides=1, padding='same')(x)
return x
......@@ -42,7 +42,7 @@ Cluster can be set to clustering(), clustering2() or clustering3(), where differ
"""
# Choosing model
config['model'] = 'pyramidal_cnn'
config['model'] = 'cnn'
config['downsampled'] = False
config['split'] = False
config['cluster'] = clustering()
......
......@@ -24,9 +24,9 @@ def main():
trainX = np.transpose(trainX, (0, 2, 1))
logging.info(trainX.shape)
tune(trainX,trainY)
# tune(trainX,trainY)
# run(trainX,trainY)
run(trainX,trainY)
# select_best_model()
# comparison_plot()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment