Commit f28e4f85 authored by nstorni

VQVAE, training not updated

parent a6ffc631
import torch
from torch import nn
from torch.nn import functional as F
# Copyright 2018 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Borrowed from https://github.com/deepmind/sonnet and ported to PyTorch
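# Vector-quantization layer: maps each input vector to its nearest codebook
# entry. The codebook lives in buffers (not parameters) and is updated with an
# exponential moving average of the assigned encoder outputs, as in Sonnet's
# VectorQuantizerEMA. Returns the quantized tensor, the commitment loss and
# the code indices.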
class Quantize(nn.Module):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
super().__init__()
self.dim = dim
self.n_embed = n_embed
self.decay = decay
self.eps = eps
embed = torch.randn(dim, n_embed)
self.register_buffer('embed', embed)
self.register_buffer('cluster_size', torch.zeros(n_embed))
self.register_buffer('embed_avg', embed.clone())
def forward(self, input):
flatten = input.reshape(-1, self.dim)
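        # squared Euclidean distance from every input vector to every codebook
        # entry, expanded as ||x||^2 - 2 x.e + ||e||^2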
dist = (
flatten.pow(2).sum(1, keepdim=True)
- 2 * flatten @ self.embed
+ self.embed.pow(2).sum(0, keepdim=True)
)
_, embed_ind = (-dist).max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = self.embed_code(embed_ind)
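        # Sonnet-style exponential-moving-average codebook update (training only)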
if self.training:
            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot.sum(0), alpha=1 - self.decay
            )
embed_sum = flatten.transpose(0, 1) @ embed_onehot
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
n = self.cluster_size.sum()
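            # Laplace smoothing of the cluster sizes, so codes with no
            # assignments do not produce a division by zero below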
cluster_size = (
(self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
self.embed.data.copy_(embed_normalized)
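        # commitment loss, then the straight-through estimator: the forward
        # pass uses the quantized values while gradients flow back to the encoder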
diff = (quantize.detach() - input).pow(2).mean()
quantize = input + (quantize - input).detach()
return quantize, diff, embed_ind
def embed_code(self, embed_id):
return F.embedding(embed_id, self.embed.transpose(0, 1))
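# Pre-activation residual block: ReLU -> 3x3 conv -> ReLU -> 1x1 conv, with
# the result added back onto the input.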
class ResBlock(nn.Module):
def __init__(self, in_channel, channel):
super().__init__()
self.conv = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(in_channel, channel, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel, in_channel, 1),
)
def forward(self, input):
out = self.conv(input)
out += input
return out
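# Convolutional encoder that downsamples by a factor of `stride` (4 or 2)
# using strided 4x4 convolutions, followed by n_res_block residual blocks.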
class Encoder(nn.Module):
def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
super().__init__()
if stride == 4:
blocks = [
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel, channel, 3, padding=1),
]
elif stride == 2:
blocks = [
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 3, padding=1),
]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.ReLU(inplace=True))
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
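# Convolutional decoder mirroring the encoder: residual blocks followed by
# transposed convolutions that upsample by a factor of `stride` (4 or 2).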
class Decoder(nn.Module):
def __init__(
self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
):
super().__init__()
blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
for i in range(n_res_block):
blocks.append(ResBlock(channel, n_res_channel))
blocks.append(nn.ReLU(inplace=True))
if stride == 4:
blocks.extend(
[
nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(
channel // 2, out_channel, 4, stride=2, padding=1
),
]
)
elif stride == 2:
blocks.append(
nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
)
self.blocks = nn.Sequential(*blocks)
def forward(self, input):
return self.blocks(input)
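# Two-level VQ-VAE (VQ-VAE-2 style): the bottom encoder downsamples the input
# 4x and the top encoder a further 2x; each level has its own codebook, and
# the bottom level is conditioned on the decoded top level.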
class VQVAE(nn.Module):
def __init__(
self,
in_channel=1,
channel=128,
n_res_block=2,
n_res_channel=32,
embed_dim=64,
n_embed=512,
decay=0.99,
):
super().__init__()
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1)
self.quantize_t = Quantize(embed_dim, n_embed)
self.dec_t = Decoder(
embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2
)
self.quantize_conv_b = nn.Conv2d(embed_dim + channel, embed_dim, 1)
self.quantize_b = Quantize(embed_dim, n_embed)
self.upsample_t = nn.ConvTranspose2d(
embed_dim, embed_dim, 4, stride=2, padding=1
)
self.dec = Decoder(
embed_dim + embed_dim,
in_channel,
channel,
n_res_block,
n_res_channel,
stride=4,
)
def forward(self, input):
quant_t, quant_b, diff, _, _ = self.encode(input)
dec = self.decode(quant_t, quant_b)
return dec, diff
def encode(self, input):
enc_b = self.enc_b(input)
enc_t = self.enc_t(enc_b)
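        # top level: 1x1-project to the embedding dim, move channels last for
        # Quantize, then restore NCHW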
quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
quant_t, diff_t, id_t = self.quantize_t(quant_t)
quant_t = quant_t.permute(0, 3, 1, 2)
diff_t = diff_t.unsqueeze(0)
dec_t = self.dec_t(quant_t)
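        # condition the bottom level on the decoded top level via channel concat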
enc_b = torch.cat([dec_t, enc_b], 1)
quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1)
quant_b, diff_b, id_b = self.quantize_b(quant_b)
quant_b = quant_b.permute(0, 3, 1, 2)
diff_b = diff_b.unsqueeze(0)
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
def decode(self, quant_t, quant_b):
upsample_t = self.upsample_t(quant_t)
quant = torch.cat([upsample_t, quant_b], 1)
dec = self.dec(quant)
return dec
def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute(0, 3, 1, 2)
quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute(0, 3, 1, 2)
dec = self.decode(quant_t, quant_b)
return dec
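
# --- Minimal smoke test (an illustrative sketch, not part of the training flow) ---
# Assumes a 64x64 grayscale input: with the stride-4 bottom encoder and the
# extra stride-2 top encoder, the code maps come out at 16x16 and 8x8.
if __name__ == "__main__":
    model = VQVAE()
    x = torch.randn(1, 1, 64, 64)
    dec, diff = model(x)
    print(dec.shape)                     # torch.Size([1, 1, 64, 64])
    quant_t, quant_b, _, id_t, id_b = model.encode(x)
    print(quant_t.shape, quant_b.shape)  # (1, 64, 8, 8) and (1, 64, 16, 16)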
""" Training VAE """
import argparse
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.utils.data
from torch import optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
from models.vqvae import VQVAE
from utils.misc import save_checkpoint
from utils.misc import LSIZE, RED_SIZE
## WARNING : THIS SHOULD BE REPLACED WITH PYTORCH 0.5
from utils.learning import EarlyStopping
from utils.learning import ReduceLROnPlateau
from utils.config import get_config_from_json, get_dataset_statistics
from PIL import Image
def pil_loader_custom(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f).convert("L")
return img
parser = argparse.ArgumentParser(description='VQ-VAE Trainer')
parser.add_argument('--modelconfig', type=str, help='File path of training configuration')
args = parser.parse_args()
config, config_json = get_config_from_json(args.modelconfig)
cuda = torch.cuda.is_available()
torch.manual_seed(123)
# cudnn benchmarking: autotune convolution algorithms for the fixed input size
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if cuda else "cpu")
INPUT_WIDTH = config.input_dim[0]
INPUT_HEIGHT = config.input_dim[1]
dataset_stats = get_dataset_statistics(config.dataset_dir+"/statistics.json")
# transform_test = transforms.Compose([
# transforms.Resize((INPUT_HEIGHT, INPUT_WIDTH)),
# transforms.ToTensor(),
# transforms.Normalize(dataset_stats.means, dataset_stats.stds)
# ])
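# Resize the shorter image side to INPUT_WIDTH, then center-crop a square
# INPUT_WIDTH x INPUT_WIDTH patch, so every image fed to the model has the
# same spatial size (and hence fixed-size code maps).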
transform_test = transforms.Compose(
[
transforms.Resize(INPUT_WIDTH),
transforms.CenterCrop(INPUT_WIDTH),
transforms.ToTensor(),
transforms.Normalize(dataset_stats.means, dataset_stats.stds)
]
)
dataset_train = datasets.ImageFolder(config.train_dataset_dir, transform_test,loader=pil_loader_custom)
dataset_test = datasets.ImageFolder(config.val_dataset_dir, transform_test,loader=pil_loader_custom)
train_loader = torch.utils.data.DataLoader(
dataset_train, batch_size=config.batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(
dataset_test, batch_size=config.batch_size, shuffle=True, num_workers=2)
model = VQVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
earlystopping = EarlyStopping('min', patience=5)
# Reconstruction loss summed over all elements and batch (kept from the VAE
# trainer; the VQ-VAE training below uses `criterion` plus the latent loss instead)
def loss_function(recon_x, x):
    """ VAE reconstruction loss (unused below) """
    recon_loss = F.mse_loss(recon_x, x, reduction="sum")
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # KLD = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
    return recon_loss
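# Loss actually used for training: per-pixel MSE reconstruction plus the
# weighted commitment ("latent") loss returned by the quantizers (weight 0.25,
# as in the VQ-VAE paper).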
latent_loss_weight = 0.25
criterion = torch.nn.MSELoss()
def train(epoch):
""" One training epoch """
model.train()
# dataset_test.load_next_buffer()
train_loss = 0
# for batch_idx, data in enumerate(train_loader):
# data = data.to(device)
batch_idx = 0
for inputs, labels in train_loader:
# writer.add_image('input', inputs[0], batch_idx * len(inputs))
batch_idx += 1
data = inputs.to(device)
optimizer.zero_grad()
# writer.add_image('reconstruction', recon_batch[0], batch_idx * len(data))
# writer.add_images('train',torch.stack((inputs[0],recon_batch[0]),0), (epoch-1)*len(train_loader) + batch_idx*len(data))
out, latent_loss = model(data)
recon_loss = criterion(out, data)
latent_loss = latent_loss.mean()
loss = recon_loss + latent_loss_weight * latent_loss
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % 2 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))
            writer.add_images('Train', torch.stack((inputs[0], out[0].detach().cpu()), 0), (epoch-1)*len(dataset_train) + batch_idx*len(data))
writer.add_scalar('Loss/train', loss.item() / len(data), (epoch-1)*len(dataset_train) + batch_idx*len(data))
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))
return train_loss / len(train_loader.dataset)
def test(epoch):
""" One test epoch """
model.eval()
# dataset_test.load_next_buffer()
test_loss = 0
testlogscnt = 0
with torch.no_grad():
# for data in test_loader:
for inputs, labels in test_loader:
data = inputs.to(device)
recon_batch, latent_loss = model(data)
recon_loss = criterion(recon_batch, data)
latent_loss = latent_loss.mean()
            test_loss += (recon_loss + latent_loss_weight * latent_loss).item()
            if testlogscnt < config.testlogs:
                writer.add_images('Test', torch.stack((inputs[0], (recon_batch[0][0:3]).cpu()), 0), (epoch)*len(dataset_train))
                testlogscnt += 1
    test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))
writer.add_scalar('Loss/test', test_loss, (epoch)*len(dataset_train))
return test_loss
# Create folders for storing models and training logs.
if not exists(config.logdir):
mkdir(config.logdir)
if not exists(join(config.logdir, 'models')):
mkdir(join(config.logdir, 'models'))
# check vae dir exists, if not, create it
runs_dir = join(config.logdir, 'tensorboard_logs' )
if not exists(runs_dir):
mkdir(runs_dir)
# Stamp training experiment name with datetime stamp for unique naming.
datetimestamp = (datetime.now()).strftime("D%Y%m%dT%H%M%S")
vae_dir = join(config.logdir, 'models/' + config.exp_name + datetimestamp)
if not exists(vae_dir):
mkdir(vae_dir)
mkdir(join(vae_dir, 'samples'))
# Save configuration file
with open(join(vae_dir,'config.json'), 'w') as outfile:
json.dump(config_json, outfile)
writer = SummaryWriter(runs_dir+"/" + config.exp_name + "_" + datetimestamp)
reload_file = join(config.reload_dir, 'best.tar')
if not config.noreload and exists(reload_file):
state = torch.load(reload_file)
print("Reloading model at epoch {}"
", with test error {}".format(
state['epoch'],
state['precision']))
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
scheduler.load_state_dict(state['scheduler'])
earlystopping.load_state_dict(state['earlystopping'])
cur_best = None
for epoch in range(1, config.epochs + 1):
train_loss = train(epoch)
test_loss = test(epoch)
scheduler.step(test_loss)
earlystopping.step(test_loss)
# checkpointing
best_filename = join(vae_dir, 'best.tar')
filename = join(vae_dir, 'checkpoint.tar')
    is_best = cur_best is None or test_loss < cur_best
if is_best:
cur_best = test_loss
save_checkpoint({
'epoch': epoch,
'state_dict': model.state_dict(),
'precision': test_loss,
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'earlystopping': earlystopping.state_dict()
}, is_best, filename, best_filename)
    if config.nsamples > 0:
        with torch.no_grad():
            # The VQ-VAE has no Gaussian latent to sample (model.decoder does
            # not exist); without a learned prior over the codes, the closest
            # analogue is to draw uniform random code indices and decode them.
            # Samples will look noisy until a prior over the codes is trained.
            code_t = torch.randint(0, model.quantize_t.n_embed,
                                   (config.nsamples, INPUT_WIDTH // 8, INPUT_WIDTH // 8), device=device)
            code_b = torch.randint(0, model.quantize_b.n_embed,
                                   (config.nsamples, INPUT_WIDTH // 4, INPUT_WIDTH // 4), device=device)
            sample = model.decode_code(code_t, code_b).cpu()
            writer.add_images('Samples', sample, epoch)
            # save_image(sample.view(1, 3, INPUT_HEIGHT, INPUT_WIDTH),
            #            join(vae_dir, 'samples/sample_' + str(epoch) + '.png'))
if earlystopping.stop:
print("End of Training because of early stopping at epoch {}".format(epoch))
break