Commit 088d3e8d authored by nstorni's avatar nstorni
Browse files

Messy code for generating latent vectors

parent e5821656
""" Training VAE """
# bsub -o $HOME/job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "python $HOME/world-models/generate_latents.py"
import argparse
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.utils.data
from torch import optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
from models.lwtvae import VAE
# from utils.misc import save_checkpoint
from PIL import Image
# Custom loader for images with mask in the alpha channel (RGBA)
def rgba_loader(path):
with open(path, 'rb') as f:
img = Image.open(f).convert("RGBA")
return img
transforms = transforms.Compose([
transforms.Resize((256,256)),
transforms.ToTensor()
])
# Cuda magic
cuda = torch.cuda.is_available()
torch.manual_seed(123)
# Fix numeric divergence due to bug in Cudnn
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if cuda else "cpu")
# Parse arguments
n_videos = 9
n_frames = 13000-110
n_latents = 64
# Load model
model = VAE(4, 64).to(device)
reload_file = join(os.path.expandvars("$SCRATCH/experiments/shvae_temporalmask/models/vae_D20200111T092802"), 'best.tar')
state = torch.load(reload_file, map_location=torch.device('cpu'))
model.load_state_dict(state['state_dict'])
# Initiliaze tensor
tensor = torch.ones(())
latents = tensor.new_empty((n_videos, n_frames, 2,n_latents))
print(latents.shape)
count = 0
with torch.no_grad():
# Iterate train images
for imgfilename in os.listdir("train/nolabel"):
# Extract video number
tmp = os.path.splitext(imgfilename)[0].split("_")
video_n = int(tmp[0])
frame_n = int(tmp[2]) - 110
img = transforms(rgba_loader(join("train/nolabel",imgfilename)))
# Run throug model encoder
mu, logsigma = model.encoder(img.view(1,4,256,256).to(device))
del img
latents[video_n-1][frame_n][0] = mu
latents[video_n-1][frame_n][1] = logsigma
# print(latents[video_n-1][frame_n])
print("Processing frame {} of video number {}".format(video_n,frame_n))
count += 1
# Iterate val images
for imgfilename in os.listdir("val/nolabel"):
# Extract video number
tmp = os.path.splitext(imgfilename)[0].split("_")
video_n = int(tmp[0])
frame_n = int(tmp[2]) - 110
img = transforms(rgba_loader(join("val/nolabel",imgfilename)))
# Run throug model encoder
mu, logsigma = model.encoder(img.view(1,4,256,256).to(device))
latents[video_n-1][frame_n][0] = mu
latents[video_n-1][frame_n][1] = logsigma
del img
# print(latents[video_n-1][frame_n])
print("Processing frame {} of video number {}".format(video_n,frame_n))
count += 1
print(count)
# Save latents to file
torch.save(latents, 'vae_D20200111T092802.pt')
# Reload latent
# latents = torch.load('tensor.pt')
\ No newline at end of file
""" Training VAE """
# bsub -o $HOME/job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "python $HOME/world-models/generate_latents.py"
import argparse
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.utils.data
from torch import optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
import numpy as np
# Dataset dimensions: 9 videos, 12890 frames each, 64-dimensional latents.
n_videos = 9
n_frames = 13000 - 110
n_latents = 64

# Pre-allocate (video, frame, {mu, logsigma}, latent dim); every entry is
# overwritten below, so uninitialized memory never leaks into the output.
tensor = torch.ones(())
latents = tensor.new_empty((n_videos, n_frames, 2, n_latents))
print(latents.shape)

# n_frames evenly spaced points spanning 120 full sine periods.
a = np.linspace(0, 120 * 2 * np.pi, n_frames)
# Synthetic latents: every mu channel of every video follows the same sine
# wave, and logsigma is a huge negative constant so sigma = exp(logsigma)
# underflows to 0 and the sampled z collapses exactly onto mu.
sine = torch.sin(torch.FloatTensor(a))  # computed once, broadcast to all channels
latents[:, :, 0, :] = sine.unsqueeze(-1)
latents[:, :, 1, :] = -100000.0
print(latents[0, 0:100, 1, 0])

# Compute the z vector via the reparameterization z = mu + eps * sigma.
mu = latents[:, :, 0]
logsigma = latents[:, :, 1]
print(logsigma[1])
sigma = logsigma.exp()
print("sigma")
print(sigma[4])
eps = torch.randn_like(sigma)
z = eps.mul(sigma).add_(mu)
print(z[0, 0:200, 0])

# Save latents to file.
torch.save(latents, 'sin.pt')
# Reload latent
# latents = torch.load('tensor.pt')
\ No newline at end of file
""" Training VAE """
# bsub -o $HOME/job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "python $HOME/world-models/generate_latents.py"
import argparse
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.utils.data
from torch import optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
from models.lwtvae import VAE
# from utils.misc import save_checkpoint
from PIL import Image
import cv2
import numpy as np
# Cuda magic
cuda = torch.cuda.is_available()
torch.manual_seed(123)
# Let cudnn auto-tune convolution algorithms for fixed-size inputs.
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if cuda else "cpu")

# Dataset dimensions: 9 videos, 12890 frames each, 64-dimensional latents.
n_videos = 9
n_frames = 13000 - 110
n_latents = 64

with torch.no_grad():
    # Load the trained VAE checkpoint (mapped to CPU so it loads anywhere).
    model = VAE(4, 64).to(device)
    reload_file = join(os.path.expandvars("$SCRATCH/experiments/shvae_temporalmask/models/vae_D20200111T092802"), 'best.tar')
    state = torch.load(reload_file, map_location=torch.device('cpu'))
    model.load_state_dict(state['state_dict'])
    model.eval()  # inference only: disable dropout / freeze batch-norm stats

    # Reload the per-frame latents produced by the encoding script.
    latents = torch.load('vae_D20200111T092802.pt')
    print(latents.shape)

    # Reparameterization: z = mu + eps * sigma with sigma = exp(logsigma).
    mu = latents[:, :, 0]
    logsigma = latents[:, :, 1]
    sigma = logsigma.exp()
    eps = torch.randn_like(sigma)
    z = eps.mul(sigma).add_(mu)

    # Visualize the latents' evolution: log each video's latent trajectory
    # to TensorBoard as a (1, frames, latent-dim) single-channel image.
    writer = SummaryWriter("/cluster/scratch/nstorni/experiments/testing/tensorboard_logs")
    for i in range(n_videos):
        writer.add_image('Z1', z[i].view(1, n_frames, n_latents), i)
        writer.add_image('mu1', mu[i].view(1, n_frames, n_latents), i)
        writer.add_image('logsigma1', logsigma[i].view(1, n_frames, n_latents), i)
    writer.close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment