Commit 4b301b74 authored by PhilFischer's avatar PhilFischer
Browse files

Added output layer activation as hyperparameter

parent 1dc237e5
......@@ -26,7 +26,7 @@ class Encoder(Layer):
class Decoder(Layer):
def __init__(self, output_shape, latent_dim=32, activation='swish', layers=1, filters=4, kernel_size=3, multiple=1):
def __init__(self, output_shape, latent_dim=32, activation='swish', output_activation='linear', layers=1, filters=4, kernel_size=3, multiple=1):
super(Decoder, self).__init__(name='Decoder')
input_size = output_shape[0] // 2**layers
self._input = Dense(input_size*input_size)
......@@ -35,7 +35,7 @@ class Decoder(Layer):
for i in range(layers)[::-1]]
self._convs = [MultiConvTranspose(filters=2**i*filters, kernel_size=kernel_size, activation=activation)
for i in range(layers)[::-1]]
self._out_layer = Conv2D(filters=output_shape[2], kernel_size=1, activation='linear')
self._out_layer = Conv2D(filters=output_shape[2], kernel_size=1, activation=output_activation)
def call(self, x, training=False):
x = self._input(x)
......
......@@ -23,7 +23,7 @@ def make_kumaraswamy_distr(cdim, reinterpreted_batch_ndims=1):
class VAE(Model):
def __init__(self, shape, n_styles, latent_dim=32, activation='swish', likelihood='sigma', log_beta=0.,
def __init__(self, shape, n_styles, latent_dim=32, activation='swish', output_activation='linear', likelihood='sigma', log_beta=0.,
encoder_layers=1, encoder_filters=4, encoder_kernel_size=3, encoder_multiple=1,
decoder_layers=2, decoder_filters=4, decoder_kernel_size=3, decoder_multiple=1,
feature_layers=3, feature_filters=4, feature_kernel_size=3, feature_multiple=1,
......@@ -56,7 +56,7 @@ class VAE(Model):
cdim = shape[-1]
param_shape = (*shape[:-1], 2*cdim)
self._connect = DistributionLambda(make_kumaraswamy_distr(cdim, reinterpreted_batch_ndims=len(shape)))
self._decoder = Decoder(param_shape, latent_dim=latent_dim+n_styles, activation=activation,
self._decoder = Decoder(param_shape, latent_dim=latent_dim+n_styles, activation=activation, output_activation=output_activation,
layers=decoder_layers, filters=decoder_filters, kernel_size=decoder_kernel_size, multiple=decoder_multiple)
def call(self, inputs, training=False):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment