Commit 88ce5d78 authored by Rafael Dätwyler
Browse files

Modified the VAE to have five conv layers and to use the image size of our mouse images

parent 29ec633f
......@@ -14,19 +14,22 @@ class Decoder(nn.Module):
self.latent_size = latent_size
self.img_channels = img_channels
self.fc1 = nn.Linear(latent_size, 1024)
self.deconv1 = nn.ConvTranspose2d(1024, 128, 5, stride=2)
self.deconv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
self.deconv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
self.deconv4 = nn.ConvTranspose2d(32, img_channels, 6, stride=2)
self.fc1 = nn.Linear(latent_size, 29*18*512)
self.devonv1 = nn.ConvTranspose2d(512, 256, 5, stride=2)
self.devonv2 = nn.ConvTranspose2d(256, 128, 5, stride=2)
self.deconv3 = nn.ConvTranspose2d(128, 64, 5, stride=2)
self.deconv4 = nn.ConvTranspose2d(64, 32, 6, stride=2)
self.deconv5 = nn.ConvTranspose2d(32, img_channels, 6, stride=2)
def forward(self, x): # pylint: disable=arguments-differ
    """Decode a batch of latent vectors into reconstructed images.

    The latent batch is projected by ``fc1``, reshaped into 512 feature
    maps of size 18 x 29, and passed through the five transposed
    convolutions ``deconv1`` .. ``deconv5``.  A ReLU follows every layer
    except the last one, whose output is squashed by a sigmoid.

    :param x: batch of latent vectors accepted by ``self.fc1``
    :returns: the sigmoid-activated output of ``self.deconv5``
    """
    # NOTE(review): the constructor lines shown in this commit register the
    # first two layers as `devonv1` / `devonv2`; they must be named
    # `deconv1` / `deconv2` for the lookups below to succeed.
    x = F.relu(self.fc1(x))
    # fc1 emits 29*18*512 features per sample; fold them back into feature
    # maps so the transposed convolutions can consume them.
    x = x.view(-1, 512, 18, 29)
    x = F.relu(self.deconv1(x))
    x = F.relu(self.deconv2(x))
    x = F.relu(self.deconv3(x))
    x = F.relu(self.deconv4(x))
    # NOTE(review): if the layers are built as in this commit (kernels
    # 5, 5, 5, 6, 6, stride 2, no padding) an 18 x 29 map grows to
    # 672 x 1024 -- confirm this matches the expected image size.
    reconstruction = F.sigmoid(self.deconv5(x))
    return reconstruction
class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
......@@ -37,13 +40,20 @@ class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
#self.img_size = img_size
self.img_channels = img_channels
# 928 x 576
self.conv1 = nn.Conv2d(img_channels, 32, 4, stride=2)
# 464 x 288
self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
# 232 x 144
self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
# 116 x 72
self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
# 58 x 36
self.conv5 = nn.Conv2d(256, 512, 4, stride=2)
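# 29 x 18 (holds only if every conv uses padding=1; the unpadded kernel-4, stride-2 convs above yield 27 x 16 from a 928 x 576 input)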
self.fc_mu = nn.Linear(2*2*256, latent_size)
self.fc_logsigma = nn.Linear(2*2*256, latent_size)
self.fc_mu = nn.Linear(29*18*512, latent_size)
self.fc_logsigma = nn.Linear(29*18*512, latent_size)
def forward(self, x): # pylint: disable=arguments-differ
......@@ -51,6 +61,7 @@ class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = F.relu(self.conv4(x))
x = F.relu(self.conv5(x))
x = x.view(x.size(0), -1)
mu = self.fc_mu(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment