Commit fd99e767 authored by nstorni's avatar nstorni
Browse files

Implemented support for 2 channel training.

parent 088d3e8d
{
"exp_name" : "lwtvae_l3",
"TrainingLevel":3,
"vaesamples" : 0,
"interpolation":"False",
"interpolation_steps":6,
"exp_name" : "laVAE_latent32",
"TrainingLevel":5,
"vaesamples" : 4,
"interpolation":"True",
"interpolation_steps":22,
"interpolations_dir":"$HOME/data/interpolations/RGBA",
"trainFull":"False",
"trainFull":"True",
"replaceReparametrization":"False",
"learning_rate" : 0.0003,
"batch_size" : 256,
"reload" : "True",
"reload_dir": "$SCRATCH/experiments/shvae_temporalmask/models/lwtvae_l2_again_D20200111T235430",
"epochs" : 400,
"reload" : "False",
"reload_dir": "$SCRATCH/experiments/",
"epochs" : 500,
"loss_log_freq":2,
"img_log_freq":50,
"betavae":1,
"input_dim": 256,
"latent_dim": 64,
"early_stopping": "True",
"weight_decay": "True",
"latent_dim": 32,
"early_stopping": "False",
"weight_decay": "False",
"normalize":"False",
"input_ch":4,
"logdir" : "$SCRATCH/experiments/shvae_temporalmask",
"input_ch":2,
"logdir" : "$SCRATCH/experiments/laVAE_latents",
"dataset_dir":"$SCRATCH/data/mice_tempdiff_medium",
"train_dataset_dir": "$SCRATCH/data/mice_tempdiff_medium/train",
"val_dataset_dir":"$SCRATCH/data/mice_tempdiff_medium/val",
......
......@@ -12,22 +12,16 @@ from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
from models.lwtvae import VAE
from utils.misc import save_checkpoint
from utils.misc import LSIZE, RED_SIZE
## WARNING : THIS SHOULD BE REPLACE WITH PYTORCH 0.5
from utils.learning import EarlyStopping
from utils.learning import ReduceLROnPlateau
from data.loaders import RolloutObservationDataset
from utils.config import get_config_from_json, get_dataset_statistics
from PIL import Image
from datetime import datetime
import json
......@@ -76,8 +70,17 @@ def train(epoch):
batch_idx += 1
data = inputs.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar, config.TrainingLevel)
recon_batch, mu, logsigma = model(data)
# loss = loss_function(recon_batch, data, mu, logvar, config.TrainingLevel)
reconstructionLoss = F.mse_loss(recon_batch, data, reduction="sum")
loss = reconstructionLoss
writer.add_scalar('ReconstructionLoss/train', reconstructionLoss.item() / len(data), (epoch-1)*len(dataset_train) + batch_idx*len(data))
if config.TrainingLevel == 5 or config.trainFull == "True":
latentLoss = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
writer.add_scalar('LatentLoss/train', latentLoss.item() / len(data), (epoch-1)*len(dataset_train) + batch_idx*len(data))
loss = reconstructionLoss + config.betavae*latentLoss
loss.backward()
train_loss += loss.item()
optimizer.step()
......@@ -88,8 +91,8 @@ def train(epoch):
loss.item() / len(data)))
writer.add_scalar('Loss/train', loss.item() / len(data), (epoch-1)*len(dataset_train) + batch_idx*len(data))
if batch_idx % config.img_log_freq == 0:
# Separate RGBA handling
if inputs.shape[1] == 4:
# Separate handling for different numbers of channels
if inputs.shape[1] == 4: # RGBA
RGB = torch.stack((inputs[0][0:3],(recon_batch[0][0:3]).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i][0:3],(recon_batch[i][0:3]).cpu()),0)
......@@ -102,12 +105,30 @@ def train(epoch):
alpha = torch.cat((alpha,temp),0)
writer.add_images('Train/alpha',alpha.view(16,1,alpha.shape[1],alpha.shape[1]), (epoch-1)*len(dataset_train) + batch_idx*len(data))
RGB = torch.stack((inputs[0][0:3],(recon_batch[0][0:3]*((inputs[0][3]>0.0)*1.0)).cpu()),0)
RGB = torch.stack((inputs[0][0:3],(recon_batch[0][0:3].cpu()*((inputs[0][3]>0.0)*1.0)).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i][0:3],(recon_batch[i][0:3]*((inputs[i][3]>0.0)*1.0)).cpu()),0)
temp = torch.stack((inputs[i][0:3],(recon_batch[i][0:3].cpu()*((inputs[i][3]>0.0)*1.0)).cpu()),0)
RGB = torch.cat((RGB,temp),0)
writer.add_images('Train/masked',RGB, (epoch-1)*len(dataset_train) + batch_idx*len(data))
else:
if inputs.shape[1] == 2: # GrayscaleAlpha
RGB = torch.stack((inputs[0][0],(recon_batch[0][0]).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i][0],(recon_batch[i][0]).cpu()),0)
RGB = torch.cat((RGB,temp),0)
writer.add_images('Train',RGB.view(16,1,RGB.shape[1],RGB.shape[1]), (epoch-1)*len(dataset_train) + batch_idx*len(data))
# Log alpha channel separately
alpha = torch.stack((inputs[0][1],(recon_batch[0][1]).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i][1],(recon_batch[i][1]).cpu()),0)
alpha = torch.cat((alpha,temp),0)
writer.add_images('Train/alpha',alpha.view(16,1,alpha.shape[1],alpha.shape[1]), (epoch-1)*len(dataset_train) + batch_idx*len(data))
RGB = torch.stack((inputs[0][0],(recon_batch[0][0].cpu()*((inputs[0][1]>0.0)*1.0)).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i][0],(recon_batch[i][0].cpu()*((inputs[i][1]>0.0)*1.0)).cpu()),0)
RGB = torch.cat((RGB,temp),0)
writer.add_images('Train/masked',RGB.view(16,1,RGB.shape[1],RGB.shape[1]), (epoch-1)*len(dataset_train) + batch_idx*len(data))
if inputs.shape[1] == 3: # RGB
RGB = torch.stack((inputs[0],(recon_batch[0]).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i],(recon_batch[i]).cpu()),0)
......@@ -129,11 +150,19 @@ def test(epoch):
# for data in test_loader:
for inputs, labels in test_loader:
data = inputs.to(device)
recon_batch, mu, logvar = model(data)
test_loss += loss_function(recon_batch, data, mu, logvar,False).item()
recon_batch, mu, logsigma = model(data)
# test_loss += loss_function(recon_batch, data, mu, logvar,False).item()
reconstructionLoss = F.mse_loss(recon_batch, data, reduction="sum")
if config.TrainingLevel == 5 or config.trainFull == "True":
latentLoss = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
test_loss += reconstructionLoss + config.betavae*latentLoss
else:
test_loss += reconstructionLoss
if logtestimgs:
logtestimgs = False # Just log one set of images per test.
if inputs.shape[1] == 4:
if inputs.shape[1] == 4: # RGBA
RGB = torch.stack((inputs[0][0:3],(recon_batch[0][0:3]).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i][0:3],(recon_batch[i][0:3]).cpu()),0)
......@@ -145,7 +174,31 @@ def test(epoch):
temp = torch.stack((inputs[i][3],(recon_batch[i][3]).cpu()),0)
alpha = torch.cat((alpha,temp),0)
writer.add_images('Test/alpha',alpha.view(16,1,alpha.shape[1],alpha.shape[1]), (epoch-1)*len(dataset_train))
else:
RGB = torch.stack((inputs[0][0:3],(recon_batch[0][0:3].cpu()*((inputs[0][3]>0.0)*1.0)).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i][0:3],(recon_batch[i][0:3].cpu()*((inputs[i][3]>0.0)*1.0)).cpu()),0)
RGB = torch.cat((RGB,temp),0)
writer.add_images('Train/masked',RGB, (epoch-1)*len(dataset_train))
if inputs.shape[1] == 2: # GrayscaleAlpha
RGB = torch.stack((inputs[0][0],(recon_batch[0][0]).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i][0],(recon_batch[i][0]).cpu()),0)
RGB = torch.cat((RGB,temp),0)
writer.add_images('Test',RGB.view(16,1,RGB.shape[1],RGB.shape[1]), (epoch-1)*len(dataset_train))
# Log alpha channel separately
alpha = torch.stack((inputs[0][1],(recon_batch[0][1]).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i][1],(recon_batch[i][1]).cpu()),0)
alpha = torch.cat((alpha,temp),0)
writer.add_images('Test/alpha',alpha.view(16,1,alpha.shape[1],alpha.shape[1]), (epoch-1)*len(dataset_train))
RGB = torch.stack((inputs[0][0],(recon_batch[0][0].cpu()*((inputs[0][1]>0.0)*1.0)).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i][0],(recon_batch[i][0].cpu()*((inputs[i][1]>0.0)*1.0)).cpu()),0)
RGB = torch.cat((RGB,temp),0)
writer.add_images('Train/masked',RGB.view(16,1,alpha.shape[1],alpha.shape[1]), (epoch-1)*len(dataset_train))
if inputs.shape[1] == 3: # RGB
RGB = torch.stack((inputs[0],(recon_batch[0]).cpu()),0)
for i in range(1,8):
temp = torch.stack((inputs[i],(recon_batch[i]).cpu()),0)
......@@ -166,6 +219,11 @@ def rgba_loader(path):
with open(path, 'rb') as f:
img = Image.open(f).convert("RGBA")
return img
# Custom ImageFolder loader for 2-channel (luminance + alpha) images.
def la_loader(path):
    """Open the image at *path* and return it converted to PIL "LA" mode."""
    with open(path, 'rb') as image_file:
        # .convert() forces the pixel data to load before the file is closed.
        loaded = Image.open(image_file).convert("LA")
    return loaded
def rgb_loader(path):
with open(path, 'rb') as f:
......@@ -178,20 +236,20 @@ def rgb_loader(path):
############ Training ############
##################################
# Parse training configuration file, specified in modelconfig argument.
# config is an attribute-style view of the JSON; config_json keeps the raw
# dict so it can be logged verbatim to TensorBoard later.
parser = argparse.ArgumentParser(description='VAE Trainer')
parser.add_argument('--modelconfig', type=str, help='File path of training configuration')
args = parser.parse_args()
config, config_json = get_config_from_json(args.modelconfig)
# Enable cuda when a GPU is available.
cuda = torch.cuda.is_available()
# Fixed seed so weight init and data shuffling are reproducible across runs.
torch.manual_seed(123)
# Fix numeric divergence due to bug in Cudnn
torch.backends.cudnn.benchmark = True
# All model inputs are moved to this device before the forward pass.
device = torch.device("cuda" if cuda else "cpu")
# Dataloader
# Initialize transforms for dataloader.
if config.normalize == "False":
transforms = transforms.Compose([
transforms.Resize((config.input_dim,config.input_dim)),
......@@ -205,20 +263,27 @@ else:
transforms.ToTensor(),
transforms.Normalize(dataset_stats.means, dataset_stats.stds)
])
# Load correct dataset depending on image type.
# config.input_ch selects the PIL conversion applied by the ImageFolder
# loader: 1 -> grayscale, 2 -> LA (grayscale + alpha), 3 -> RGB (default
# loader), 4 -> RGBA.
# Fix: the original code tested each channel count with an independent
# `if` and contained the `input_ch == 4` branch twice, constructing the
# RGBA datasets a second time for no effect. The branches are mutually
# exclusive, so a single if/elif chain with one RGBA branch is equivalent.
if config.input_ch == 1:
    print("Grayscale images")
    dataset_train = datasets.ImageFolder(config.train_dataset_dir, transforms, loader=grayscale_loader)
    dataset_test = datasets.ImageFolder(config.val_dataset_dir, transforms, loader=grayscale_loader)
elif config.input_ch == 2:
    print("LA images")
    dataset_train = datasets.ImageFolder(config.train_dataset_dir, transforms, loader=la_loader)
    dataset_test = datasets.ImageFolder(config.val_dataset_dir, transforms, loader=la_loader)
elif config.input_ch == 3:
    print("RGB images")
    dataset_train = datasets.ImageFolder(config.train_dataset_dir, transforms)
    dataset_test = datasets.ImageFolder(config.val_dataset_dir, transforms)
elif config.input_ch == 4:
    print("RGBA images")
    dataset_train = datasets.ImageFolder(config.train_dataset_dir, transforms, loader=rgba_loader)
    dataset_test = datasets.ImageFolder(config.val_dataset_dir, transforms, loader=rgba_loader)

# Initialize Dataloader.
train_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=config.batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(
......@@ -297,16 +362,53 @@ if config.reload == "True" and exists(reload_file):
model.encoder.TrainingLevel = 5
else:
model = freezeLayers(model,config.TrainingLevel)
# Reload optimizer to update layers that need the gradient.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
elif config.reload == "True" and not exists(reload_file):
raise Exception('Reload file not found: {}'.format(reload_file))
cur_best = None
writer.add_text("Configs", json.dumps(config_json, indent=4, sort_keys=True))
# Train given number of epochs
for epoch in range(1, config.epochs + 1):
train_loss = train(epoch)
test_loss = test(epoch)
if config.early_stopping == "True":
earlystopping.step(test_loss)
if config.weight_decay == "True":
scheduler.step(test_loss)
# Decode a set of random latent vectors.
if config.vaesamples > 0 and config.TrainingLevel == 5:
with torch.no_grad():
sample = torch.randn(config.vaesamples, config.latent_dim).to(device)
sample = model.decoder(sample).cpu()
if sample.shape[1] == 4: # RGBA
RGB = (sample[0][0:3]).view(1,3,256,256)
for i in range(1,config.vaesamples-1):
RGB = torch.cat((RGB,(sample[i][0:3]).view(1,3,256,256)),0)
writer.add_images('Samples',RGB, (epoch)*len(dataset_train))
# Log alpha channel separately
alpha = (sample[0][3]).view(1,256,256)
for i in range(1,config.vaesamples-1):
alpha = torch.cat((alpha,(sample[i+1][3]).view(1,256,256)),0)
writer.add_images('Samples/alpha',alpha.view(-1,1,256,256), (epoch)*len(dataset_train))
if sample.shape[1] == 2: # GrayscaleAlpha
RGB = (sample[0][0]).view(1,1,256,256)
for i in range(1,config.vaesamples-1):
RGB = torch.cat((RGB,(sample[i][0]).view(1,1,256,256)),0)
writer.add_images('Samples',RGB, (epoch)*len(dataset_train))
# Log alpha channel separately
alpha = (sample[0][1]).view(1,256,256)
for i in range(1,config.vaesamples-1):
alpha = torch.cat((alpha,(sample[i+1][1]).view(1,256,256)),0)
writer.add_images('Samples/alpha',alpha.view(-1,1,256,256), (epoch)*len(dataset_train))
if sample.shape[1] == 3: # RGB
writer.add_images('Samples',sample, (epoch)*len(dataset_train))
# Interpolate latent space between given pictures.
# Loads two pictures from a subdirectory of config.interpolations_dir.
if config.interpolation == "True":
......@@ -323,11 +425,15 @@ for epoch in range(1, config.epochs + 1):
if config.input_ch == 3:
img1 = transforms(rgb_loader(join(interpolation_dir,img_filenames[0])))
img2 = transforms(rgb_loader(join(interpolation_dir,img_filenames[1])))
if config.input_ch == 2:
img1 = transforms(la_loader(join(interpolation_dir,img_filenames[0])))
img2 = transforms(la_loader(join(interpolation_dir,img_filenames[1])))
if config.input_ch == 1:
img1 = transforms(grayscale_loader(join(interpolation_dir,img_filenames[0])))
img2 = transforms(grayscale_loader(join(interpolation_dir,img_filenames[1])))
# Get latent representation with model encoder.
batch = torch.stack((img1,img2))
batch = batch.to(device)
mu, logsigma = model.encoder(batch)
sigma = logsigma.exp()
eps = torch.randn_like(sigma)
......@@ -342,25 +448,26 @@ for epoch in range(1, config.epochs + 1):
# Interpolate between the two latents and reconstruct latent to image with decoder.
for i in range(config.interpolation_steps):
inter_z += diff/config.interpolation_steps
recon_x = model.decoder(inter_z)
print(recon_x.shape)
temp = inter_z.to(device)
recon_x = model.decoder(temp)
interpolation = torch.cat((interpolation,recon_x.view(1,-1,256,256)),0)
# Add second image.
print(interpolation.shape)
interpolation = torch.cat((interpolation,batch[1].view(1,-1,256,256)),0)
# Add to tensorboard.
if interpolation.shape[1] == 4:
if interpolation.shape[1] == 4: # RGBA
writer.add_images('Interpolation/alpha',interpolation[:,3].view(-1,1,256,256), (epoch)*len(dataset_train))
writer.add_images('Interpolation',interpolation[:,0:3].view(-1,3,256,256), (epoch)*len(dataset_train))
print("done")
else:
if interpolation.shape[1] == 2: # GrayscaleAlpha
writer.add_images('Interpolation/alpha',interpolation[:,1].view(-1,1,256,256), (epoch)*len(dataset_train))
writer.add_images('Interpolation',interpolation[:,0].view(-1,3,256,256), (epoch)*len(dataset_train))
if interpolation.shape[1] == 3: # RGB
writer.add_images('Interpolation',interpolation, (epoch)*len(dataset_train))
train_loss = train(epoch)
test_loss = test(epoch)
#scheduler.step(test_loss)
#earlystopping.step(test_loss)
# checkpointing
best_filename = join(vae_dir, 'best.tar')
......@@ -378,24 +485,8 @@ for epoch in range(1, config.epochs + 1):
'earlystopping': earlystopping.state_dict()
}, is_best, filename, best_filename)
# Decode a set of random latent vectors.
if config.vaesamples > 0 and config.TrainingLevel == 5:
with torch.no_grad():
sample = torch.randn(config.vaesamples, config.latent_dim).to(device)
sample = model.decoder(sample).cpu()
if sample.shape[1] == 4:
RGB = (sample[0][0:3]).view(1,3,256,256)
for i in range(1,config.vaesamples-1):
RGB = torch.cat((RGB,(sample[i][0:3]).view(1,3,256,256)),0)
writer.add_images('Samples',RGB, (epoch)*len(dataset_train))
# Log alpha channel separately
alpha = (sample[0][3]).view(1,256,256)
for i in range(1,config.vaesamples-1):
alpha = torch.cat((alpha,(sample[i+1][3]).view(1,256,256)),0)
writer.add_images('Samples/alpha',alpha.view(config.vaesamples,1,alpha.shape[1],alpha.shape[1]), (epoch)*len(dataset_train))
else:
writer.add_images('Samples',sample, (epoch)*len(dataset_train))
if earlystopping.stop:
print("End of Training because of early stopping at epoch {}".format(epoch))
break
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment