Commit e5821656 authored by nstorni

Layer Wise Training implementation of VAE

parent cb020edf
"""
Variational encoder model, used as a visual model
for our model of the world.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
""" VAE decoder """
def __init__(self, img_channels, latent_size):
super(Decoder, self).__init__()
self.latent_size = latent_size
self.img_channels = img_channels
self.TrainingLevel = 1
self.fc1 = nn.Linear(latent_size, 50176)
self.deconv1 = nn.ConvTranspose2d(256, 128, 4, stride=2)
self.deconv2 = nn.ConvTranspose2d(128, 64, 4, stride=2)
self.deconv3 = nn.ConvTranspose2d(64, 32, 5, stride=2)
self.deconv4 = nn.ConvTranspose2d(32, img_channels, 4, stride=2)
def forward(self, x): # pylint: disable=arguments-differ
if self.TrainingLevel > 4:
x = F.relu(self.fc1(x))
# x = x.unsqueeze(-1).unsqueeze(-1)
x = x.view(-1,256,14,14)
if self.TrainingLevel > 3:
x = F.relu(self.deconv1(x))
if self.TrainingLevel > 2:
x = F.relu(self.deconv2(x))
if self.TrainingLevel > 1:
x = F.relu(self.deconv3(x))
reconstruction = self.deconv4(x)
return reconstruction
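
# Shape sketch (assuming 256x256 inputs, as in the training configuration
# below): fc1's 50176 outputs are 256*14*14 and are reshaped to
# (B, 256, 14, 14); the deconvolutions then upsample the spatial size
# 14 -> 30 -> 62 -> 127 -> 256, mirroring the encoder's downsampling path.
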
class Encoder(nn.Module):  # pylint: disable=too-many-instance-attributes
    """ VAE encoder """
    def __init__(self, img_channels, latent_size):
        super(Encoder, self).__init__()
        self.latent_size = latent_size
        self.TrainingLevel = 1
        self.img_channels = img_channels

        self.conv1 = nn.Conv2d(img_channels, 32, 4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)

        self.fc_mu = nn.Linear(50176, latent_size)
        self.fc_logsigma = nn.Linear(50176, latent_size)

    def forward(self, x):  # pylint: disable=arguments-differ
        # At TrainingLevel k only the first k convolutions run; the
        # intermediate feature map is returned until the reparametrization
        # layers are reached at level 5.
        x = F.relu(self.conv1(x))
        if self.TrainingLevel > 1:
            x = F.relu(self.conv2(x))
        if self.TrainingLevel > 2:
            x = F.relu(self.conv3(x))
        if self.TrainingLevel > 3:
            x = F.relu(self.conv4(x))
        if self.TrainingLevel > 4:
            x = x.view(x.size(0), -1)
            mu = self.fc_mu(x)
            logsigma = self.fc_logsigma(x)
            return mu, logsigma
        return x
class VAE(nn.Module):
    """ Variational Autoencoder """
    def __init__(self, img_channels, latent_size):
        super(VAE, self).__init__()
        self.TrainingLevel = 5
        self.encoder = Encoder(img_channels, latent_size)
        self.encoder.TrainingLevel = self.TrainingLevel
        self.decoder = Decoder(img_channels, latent_size)
        self.decoder.TrainingLevel = self.TrainingLevel

    def forward(self, x):  # pylint: disable=arguments-differ
        if self.TrainingLevel == 5:
            mu, logsigma = self.encoder(x)
            sigma = logsigma.exp()
            eps = torch.randn_like(sigma)
            z = eps.mul(sigma).add_(mu)
            recon_x = self.decoder(z)
            return recon_x, mu, logsigma
        # Below level 5 there is no reparametrization: the encoder's feature
        # map is fed straight to the decoder. Dummy values stand in for mu and
        # logsigma so the return signature stays uniform.
        mu = 1.0
        logsigma = 1.0
        z = self.encoder(x)
        recon_x = self.decoder(z)
        return recon_x, mu, logsigma
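
# Minimal usage sketch (illustrative, not part of the module): the point of
# the layer-wise scheme is that every training level reconstructs a full-size
# image. Assumes 256x256 RGBA inputs and a latent size of 64, as in the
# training configuration below.
if __name__ == "__main__":
    vae = VAE(img_channels=4, latent_size=64)
    dummy = torch.randn(2, 4, 256, 256)
    for level in range(1, 6):
        vae.TrainingLevel = level
        vae.encoder.TrainingLevel = level
        vae.decoder.TrainingLevel = level
        recon_x, mu, logsigma = vae(dummy)
        assert recon_x.shape == dummy.shape  # full-size output at every level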
{
    "exp_name": "lwtvae_l3",
    "TrainingLevel": 3,
    "vaesamples": 0,
    "interpolation": "False",
    "interpolation_steps": 6,
    "interpolations_dir": "$HOME/data/interpolations/RGBA",
    "trainFull": "False",
    "replaceReparametrization": "False",
    "learning_rate": 0.0003,
    "batch_size": 256,
    "reload": "True",
    "reload_dir": "$SCRATCH/experiments/shvae_temporalmask/models/lwtvae_l2_again_D20200111T235430",
    "epochs": 400,
    "epochsamples": 1,
    "loss_log_freq": 2,
    "img_log_freq": 50,
    "betavae": 1,
    "input_dim": 256,
    "latent_dim": 64,
    "weight_decay": "True",
    "normalize": "False",
    "input_ch": 4,
    "logdir": "$SCRATCH/experiments/shvae_temporalmask",
    "dataset_dir": "$SCRATCH/data/mice_tempdiff_medium",
    "train_dataset_dir": "$SCRATCH/data/mice_tempdiff_medium/train",
    "val_dataset_dir": "$SCRATCH/data/mice_tempdiff_medium/val",
    ...
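
A rough sketch of how a config like the one above reaches the trainer below.
The real utils.config.get_config_from_json may differ; SimpleNamespace here is
an assumption that merely reproduces the attribute access the script relies on
(e.g. config.batch_size). "$SCRATCH"-style variables stay unexpanded in the
strings and are resolved later with os.path.expandvars.

import json
from types import SimpleNamespace

def get_config_from_json_sketch(path):
    with open(path) as f:
        config_json = json.load(f)  # plain dict, also dumped back to disk
    # Attribute-style view of the dict, plus the raw dict for logging.
    return SimpleNamespace(**config_json), config_json
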
""" Training VAE """
import argparse
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.utils.data
from torch import optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
from models.lwtvae import VAE
from utils.misc import save_checkpoint
from utils.misc import LSIZE, RED_SIZE
## WARNING : THIS SHOULD BE REPLACE WITH PYTORCH 0.5
from utils.learning import EarlyStopping
from utils.learning import ReduceLROnPlateau
from data.loaders import RolloutObservationDataset
from utils.config import get_config_from_json, get_dataset_statistics
from PIL import Image
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logsigma, TrainingLevel):
    """ VAE loss function """
    BCE = F.mse_loss(recon_x, x, reduction="sum")
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    if TrainingLevel == 5:
        KLD = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
        return BCE + config.betavae * KLD
    return BCE
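
# Worked substitution for the formula above: the network predicts
# logsigma = log(sigma), so log(sigma^2) = 2 * logsigma and
# sigma^2 = exp(2 * logsigma), giving
#     KLD = -0.5 * sum(1 + 2 * logsigma - mu^2 - exp(2 * logsigma)),
# which is exactly the expression in the code. Below level 5 no latent
# distribution is produced, so only the reconstruction term is returned.
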
def freezeLayers(model, TrainingLevel):
    """ Keep only the encoder/decoder layer pair of TrainingLevel trainable.

    requires_grad_() is called on the submodules so the flag propagates to
    their parameters; assigning a bare `requires_grad` attribute to a module
    would not freeze anything.
    """
    model.encoder.conv1.requires_grad_(TrainingLevel == 1)
    model.encoder.conv2.requires_grad_(TrainingLevel == 2)
    model.encoder.conv3.requires_grad_(TrainingLevel == 3)
    model.encoder.conv4.requires_grad_(TrainingLevel == 4)
    model.encoder.fc_mu.requires_grad_(TrainingLevel == 5)
    model.encoder.fc_logsigma.requires_grad_(TrainingLevel == 5)
    model.decoder.deconv4.requires_grad_(TrainingLevel == 1)
    model.decoder.deconv3.requires_grad_(TrainingLevel == 2)
    model.decoder.deconv2.requires_grad_(TrainingLevel == 3)
    model.decoder.deconv1.requires_grad_(TrainingLevel == 4)
    model.decoder.fc1.requires_grad_(TrainingLevel == 5)
    model.encoder.TrainingLevel = TrainingLevel
    model.decoder.TrainingLevel = TrainingLevel
    model.TrainingLevel = TrainingLevel
    return model
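
# Sanity-check sketch (hypothetical helper, not called by the trainer): after
# freezing for level 2, only encoder.conv2 and decoder.deconv3 parameters
# should report requires_grad=True, so the Adam filter below picks up exactly
# that layer pair.
def check_frozen(model, level=2):
    model = freezeLayers(model, level)
    trainable = [name for name, p in model.named_parameters() if p.requires_grad]
    print(trainable)  # expect only encoder.conv2.* and decoder.deconv3.*
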
def train(epoch):
    """ One training epoch """
    model.train()
    print(epoch)
    train_loss = 0
    batch_idx = 0
    for inputs, labels in train_loader:
        batch_idx += 1
        data = inputs.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar, config.TrainingLevel)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % config.loss_log_freq == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
            writer.add_scalar('Loss/train', loss.item() / len(data),
                              (epoch - 1) * len(dataset_train) + batch_idx * len(data))
        if batch_idx % config.img_log_freq == 0:
            # Separate RGBA handling: log the first eight input/reconstruction
            # pairs (assumes batch_size >= 8).
            if inputs.shape[1] == 4:
                RGB = torch.stack((inputs[0][0:3], (recon_batch[0][0:3]).cpu()), 0)
                for i in range(1, 8):
                    temp = torch.stack((inputs[i][0:3], (recon_batch[i][0:3]).cpu()), 0)
                    RGB = torch.cat((RGB, temp), 0)
                writer.add_images('Train', RGB, (epoch - 1) * len(dataset_train) + batch_idx * len(data))
                # Log alpha channel separately
                alpha = torch.stack((inputs[0][3], (recon_batch[0][3]).cpu()), 0)
                for i in range(1, 8):
                    temp = torch.stack((inputs[i][3], (recon_batch[i][3]).cpu()), 0)
                    alpha = torch.cat((alpha, temp), 0)
                writer.add_images('Train/alpha', alpha.view(16, 1, alpha.shape[1], alpha.shape[1]),
                                  (epoch - 1) * len(dataset_train) + batch_idx * len(data))
                # Reconstructions masked by the input alpha channel. The
                # reconstruction is moved to the CPU before multiplying with
                # the CPU-side mask to avoid a device mismatch.
                RGB = torch.stack((inputs[0][0:3], recon_batch[0][0:3].cpu() * ((inputs[0][3] > 0.0) * 1.0)), 0)
                for i in range(1, 8):
                    temp = torch.stack((inputs[i][0:3], recon_batch[i][0:3].cpu() * ((inputs[i][3] > 0.0) * 1.0)), 0)
                    RGB = torch.cat((RGB, temp), 0)
                writer.add_images('Train/masked', RGB, (epoch - 1) * len(dataset_train) + batch_idx * len(data))
            else:
                RGB = torch.stack((inputs[0], (recon_batch[0]).cpu()), 0)
                for i in range(1, 8):
                    temp = torch.stack((inputs[i], (recon_batch[i]).cpu()), 0)
                    RGB = torch.cat((RGB, temp), 0)
                writer.add_images('Train', RGB, (epoch - 1) * len(dataset_train) + batch_idx * len(data))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
    return train_loss / len(train_loader.dataset)
def test(epoch):
    """ One test epoch """
    model.eval()
    test_loss = 0
    logtestimgs = True
    with torch.no_grad():
        for inputs, labels in test_loader:
            data = inputs.to(device)
            recon_batch, mu, logvar = model(data)
            # Passing False (== 0) as TrainingLevel skips the KL term, so the
            # test loss is the reconstruction error only.
            test_loss += loss_function(recon_batch, data, mu, logvar, False).item()
            if logtestimgs:
                logtestimgs = False  # Just log one set of images per test.
                if inputs.shape[1] == 4:
                    RGB = torch.stack((inputs[0][0:3], (recon_batch[0][0:3]).cpu()), 0)
                    for i in range(1, 8):
                        temp = torch.stack((inputs[i][0:3], (recon_batch[i][0:3]).cpu()), 0)
                        RGB = torch.cat((RGB, temp), 0)
                    writer.add_images('Test', RGB, (epoch - 1) * len(dataset_train))
                    # Log alpha channel separately
                    alpha = torch.stack((inputs[0][3], (recon_batch[0][3]).cpu()), 0)
                    for i in range(1, 8):
                        temp = torch.stack((inputs[i][3], (recon_batch[i][3]).cpu()), 0)
                        alpha = torch.cat((alpha, temp), 0)
                    writer.add_images('Test/alpha', alpha.view(16, 1, alpha.shape[1], alpha.shape[1]), (epoch - 1) * len(dataset_train))
                else:
                    RGB = torch.stack((inputs[0], (recon_batch[0]).cpu()), 0)
                    for i in range(1, 8):
                        temp = torch.stack((inputs[i], (recon_batch[i]).cpu()), 0)
                        RGB = torch.cat((RGB, temp), 0)
                    writer.add_images('Test', RGB, (epoch - 1) * len(dataset_train))

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    writer.add_scalar('Loss/test', test_loss, epoch * len(dataset_train))
    return test_loss
# Custom loader for grayscale images
def grayscale_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f).convert("L")
    return img

# Custom loader for images with mask in the alpha channel (RGBA)
def rgba_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f).convert("RGBA")
    return img

# Custom loader for RGB images
def rgb_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
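
# Usage note (illustrative): torchvision's ImageFolder calls the given loader
# on every file path, so e.g.
#     datasets.ImageFolder(root, transform, loader=rgba_loader)
# yields 4-channel tensors once ToTensor() runs, which is how the RGBA
# datasets below are built.
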
##################################
############ Training ############
##################################
parser = argparse.ArgumentParser(description='VAE Trainer')
parser.add_argument('--modelconfig', type=str, help='File path of training configuration')
args = parser.parse_args()

config, config_json = get_config_from_json(args.modelconfig)

cuda = torch.cuda.is_available()
torch.manual_seed(123)
# Fix numeric divergence due to bug in Cudnn
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if cuda else "cpu")

# Input pipeline (named `transform` so the torchvision.transforms module is
# not shadowed).
if config.normalize == "False":
    transform = transforms.Compose([
        transforms.Resize((config.input_dim, config.input_dim)),
        transforms.ToTensor()
    ])
else:
    dataset_stats = get_dataset_statistics(config.dataset_dir + "/statistics.json")
    transform = transforms.Compose([
        transforms.Resize((config.input_dim, config.input_dim)),
        transforms.ToTensor(),
        transforms.Normalize(dataset_stats.means, dataset_stats.stds)
    ])
# Load correct dataset depending on image type.
if config.input_ch == 1:
    print("Grayscale images")
    dataset_train = datasets.ImageFolder(config.train_dataset_dir, transform, loader=grayscale_loader)
    dataset_test = datasets.ImageFolder(config.val_dataset_dir, transform, loader=grayscale_loader)
if config.input_ch == 4:
    print("RGBA images")
    dataset_train = datasets.ImageFolder(config.train_dataset_dir, transform, loader=rgba_loader)
    dataset_test = datasets.ImageFolder(config.val_dataset_dir, transform, loader=rgba_loader)
if config.input_ch == 3:
    print("RGB images")
    dataset_train = datasets.ImageFolder(config.train_dataset_dir, transform)
    dataset_test = datasets.ImageFolder(config.val_dataset_dir, transform)

train_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=config.batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(
    dataset_test, batch_size=config.batch_size, shuffle=True, num_workers=2)
# Load model
model = VAE(config.input_ch, config.latent_dim).to(device)

# Check if the full model is trained or only one layer pair.
if config.trainFull == "True":
    model.TrainingLevel = 5
    model.decoder.TrainingLevel = 5
    model.encoder.TrainingLevel = 5
else:
    model = freezeLayers(model, config.TrainingLevel)

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
earlystopping = EarlyStopping('min', patience=5)
# Create folders for storing models and training logs.
if not exists(config.logdir):
    mkdir(config.logdir)
if not exists(join(config.logdir, 'models')):
    mkdir(join(config.logdir, 'models'))

# Check that the tensorboard logs dir exists, if not, create it.
runs_dir = join(config.logdir, 'tensorboard_logs')
if not exists(runs_dir):
    mkdir(runs_dir)

# Stamp training experiment name with datetime stamp for unique naming.
datetimestamp = datetime.now().strftime("D%Y%m%dT%H%M%S")
vae_dir = join(config.logdir, 'models/' + config.exp_name + "_" + datetimestamp)
if not exists(vae_dir):
    mkdir(vae_dir)
    mkdir(join(vae_dir, 'samples'))

# Save configuration file
with open(join(vae_dir, 'config.json'), 'w') as outfile:
    json.dump(config_json, outfile)

# Tensorboard writer initialization.
writer = SummaryWriter(runs_dir + "/" + config.exp_name + "_" + datetimestamp)
# If required, load model weights from file.
reload_file = join(os.path.expandvars(config.reload_dir), 'best.tar')
if config.reload == "True" and exists(reload_file):
    state = torch.load(reload_file, map_location=torch.device('cpu'))
    print("Reloading model at epoch {}"
          ", with test error {}".format(
              state['epoch'],
              state['precision']))
    # replaceReparametrization allows loading weights from a model with a
    # different latent dimension: the latent-dependent layers are dropped
    # from the checkpoint and keep their fresh random initialization.
    if config.replaceReparametrization == "True":
        model_dict = model.state_dict()
        state['state_dict'].pop("encoder.fc_logsigma.bias")
        state['state_dict'].pop("encoder.fc_logsigma.weight")
        state['state_dict'].pop("encoder.fc_mu.weight")
        state['state_dict'].pop("encoder.fc_mu.bias")
        state['state_dict'].pop("decoder.fc1.weight")
        state['state_dict'].pop("decoder.fc1.bias")
        model_dict.update(state['state_dict'])
        model.load_state_dict(model_dict)
    else:
        model.load_state_dict(state['state_dict'])
    # optimizer.load_state_dict(state['optimizer'])
    # scheduler.load_state_dict(state['scheduler'])
    # earlystopping.load_state_dict(state['earlystopping'])

    # Check if the full model is trained or only one layer pair.
    if config.trainFull == "True":
        model.TrainingLevel = 5
        model.decoder.TrainingLevel = 5
        model.encoder.TrainingLevel = 5
    else:
        model = freezeLayers(model, config.TrainingLevel)
    # Rebuild the optimizer so it only updates the layers that need gradients.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
elif config.reload == "True" and not exists(reload_file):
    raise Exception('Reload file not found: {}'.format(reload_file))
cur_best = None
writer.add_text("Configs", json.dumps(config_json, indent=4, sort_keys=True))

# Train for the given number of epochs.
for epoch in range(1, config.epochs + 1):
    # Interpolate in latent space between pairs of pictures. Loads two
    # pictures from each subdirectory of config.interpolations_dir; this
    # assumes a full level-5 model, since the encoder must return mu and
    # logsigma.
    if config.interpolation == "True":
        with torch.no_grad():
            # Load folders with interpolation images
            for interpolation_name in os.listdir(os.path.expandvars(config.interpolations_dir)):
                interpolation_dir = join(os.path.expandvars(config.interpolations_dir), interpolation_name)
                if os.path.isdir(interpolation_dir):
                    img_filenames = os.listdir(interpolation_dir)
                    # Use correct loader depending on number of image channels.
                    if config.input_ch == 4:
                        img1 = transform(rgba_loader(join(interpolation_dir, img_filenames[0])))
                        img2 = transform(rgba_loader(join(interpolation_dir, img_filenames[1])))
                    if config.input_ch == 3:
                        img1 = transform(rgb_loader(join(interpolation_dir, img_filenames[0])))
                        img2 = transform(rgb_loader(join(interpolation_dir, img_filenames[1])))
                    if config.input_ch == 1:
                        img1 = transform(grayscale_loader(join(interpolation_dir, img_filenames[0])))
                        img2 = transform(grayscale_loader(join(interpolation_dir, img_filenames[1])))
                    # Get latent representations with the model encoder; the
                    # batch is moved to the model's device first.
                    batch = torch.stack((img1, img2)).to(device)
                    mu, logsigma = model.encoder(batch)
                    sigma = logsigma.exp()
                    eps = torch.randn_like(sigma)
                    z = eps.mul(sigma).add_(mu)
                    # Compute the difference between the latent representations.
                    diff = z[1] - z[0]
                    # Add first image.
                    interpolation = batch[0].view(1, -1, 256, 256)
                    inter_z = z[0]
                    # Step between the two latents, reconstructing each step
                    # with the decoder.
                    for i in range(config.interpolation_steps):
                        inter_z += diff / config.interpolation_steps
                        recon_x = model.decoder(inter_z)
                        interpolation = torch.cat((interpolation, recon_x.view(1, -1, 256, 256)), 0)
                    # Add second image.
                    interpolation = torch.cat((interpolation, batch[1].view(1, -1, 256, 256)), 0)
                    # Add to tensorboard.
                    if interpolation.shape[1] == 4:
                        writer.add_images('Interpolation/alpha', interpolation[:, 3].view(-1, 1, 256, 256), epoch * len(dataset_train))
                        writer.add_images('Interpolation', interpolation[:, 0:3].view(-1, 3, 256, 256), epoch * len(dataset_train))
                    else:
                        writer.add_images('Interpolation', interpolation, epoch * len(dataset_train))
    train_loss = train(epoch)
    test_loss = test(epoch)
    # scheduler.step(test_loss)
    # earlystopping.step(test_loss)

    # Checkpointing
    best_filename = join(vae_dir, 'best.tar')
    filename = join(vae_dir, 'checkpoint.tar')
    is_best = not cur_best or test_loss < cur_best
    if is_best:
        cur_best = test_loss

    save_checkpoint({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'precision': test_loss,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'earlystopping': earlystopping.state_dict()
    }, is_best, filename, best_filename)

    # Decode a set of random latent vectors (only meaningful once the full
    # model, including the reparametrization layers, is being trained).
    if config.vaesamples > 0 and config.TrainingLevel == 5:
        with torch.no_grad():
            sample = torch.randn(config.vaesamples, config.latent_dim).to(device)
            sample = model.decoder(sample).cpu()
            if sample.shape[1] == 4:
                RGB = (sample[0][0:3]).view(1, 3, 256, 256)
                for i in range(1, config.vaesamples):
                    RGB = torch.cat((RGB, (sample[i][0:3]).view(1, 3, 256, 256)), 0)
                writer.add_images('Samples', RGB, epoch * len(dataset_train))
                # Log alpha channel separately; the loop covers all vaesamples
                # images so the view below matches the tensor size.
                alpha = (sample[0][3]).view(1, 256, 256)
                for i in range(1, config.vaesamples):
                    alpha = torch.cat((alpha, (sample[i][3]).view(1, 256, 256)), 0)
                writer.add_images('Samples/alpha', alpha.view(config.vaesamples, 1, alpha.shape[1], alpha.shape[1]), epoch * len(dataset_train))
            else:
                writer.add_images('Samples', sample, epoch * len(dataset_train))

    if earlystopping.stop:
        print("End of Training because of early stopping at epoch {}".format(epoch))
        break
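
# Typical invocation (the only CLI argument is the one defined above; the
# config path is an example):
#     python trainvae.py --modelconfig $HOME/configs/lwtvae_l3.json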