Commit 62a87f78 authored by PhilFischer's avatar PhilFischer
Browse files

Reorder decoder layers

parent 2ff4bc5a
......@@ -28,23 +28,22 @@ class Encoder(Layer):
class Decoder(Layer):
def __init__(self, output_shape, latent_dim=32, activation='swish', layers=1, filters=4, kernel_size=3, multiple=1):
super(Decoder, self).__init__(name='Decoder')
self._reshape_in = Reshape([1, 1, latent_dim])
input_size = output_shape[0] // 2**layers
self._input = Dense(input_size*input_size)
self._reshape_in = Reshape([input_size, input_size, 1])
self._ups = [Conv2DTranspose(filters=2**i*filters, kernel_size=kernel_size+2, activation=activation, strides=2, padding='same')
for i in range(layers)[::-1]]
self._convs = [MultiConvTranspose(filters=2**i*filters, kernel_size=kernel_size, activation=activation)
for i in range(layers)[::-1]]
self._flatten = Flatten()
self._out_layer = Dense(output_shape[0]*output_shape[1]*output_shape[2], activation='linear')
self._reshape_out = Reshape(output_shape)
self._out_layer = Conv2D(filters=output_shape[2], kernel_size=1, activation='linear')
def call(self, x, training=False):
x = self._input(x)
x = self._reshape_in(x)
for conv, up in zip(self._convs, self._ups):
x = up(x)
x = conv(x, training=training)
x = self._flatten(x)
x = self._out_layer(x)
x = self._reshape_out(x)
return x
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment