""" Training VAE """
# bsub -o $HOME/job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "python $HOME/world-models/generate_latents.py"
import os
from os.path import join

import torch
from torchvision import transforms
from PIL import Image

from models.lwtvae import VAE

# Custom loader for images with mask in the alpha channel (RGBA)
def rgba_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f).convert("RGBA")
        return img

# Resize to the encoder's 256x256 input and convert to a float tensor in [0, 1];
# named "transform" to avoid shadowing the torchvision.transforms module
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
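
# For reference, the transform maps an RGBA PIL image to a float tensor of shape
# (4, 256, 256) with values in [0, 1], e.g. (hypothetical filename):
# x = transform(rgba_loader("train/nolabel/1_frame_110.png"))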

# Device setup
cuda = torch.cuda.is_available()
torch.manual_seed(123)
# Let cuDNN auto-tune convolution algorithms for the fixed input size (speed, not determinism)
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if cuda else "cpu")



# Dataset dimensions (hard-coded; no argument parsing)
n_videos = 9
n_frames = 13000 - 110  # frame numbering starts at 110
n_latents = 64

# Load trained model: 4 input channels (RGBA) and 64 latent dimensions
model = VAE(4, 64).to(device)

reload_file = join(os.path.expandvars("$SCRATCH/experiments/shvae_temporalmask/models/vae_D20200111T092802"), 'best.tar')
state = torch.load(reload_file, map_location=torch.device('cpu'))

model.load_state_dict(state['state_dict'])
model.eval()  # inference only: disable dropout / use batch-norm running stats

# Pre-allocate the output tensor, indexed as (video, frame, [mu|logsigma], latent)
latents = torch.empty((n_videos, n_frames, 2, n_latents))
print(latents.shape)
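# Size check: 9 videos x 12890 frames x 2 x 64 float32 values is roughly 60 MB,
# so the full tensor comfortably fits in host memory.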

count = 0
with torch.no_grad():
    # Encode every frame of both splits; filenames encode indices as "<video>_..._<frame>"
    for split in ("train/nolabel", "val/nolabel"):
        for imgfilename in os.listdir(split):
            # Extract video and frame numbers from the filename
            tmp = os.path.splitext(imgfilename)[0].split("_")
            video_n = int(tmp[0])
            frame_n = int(tmp[2]) - 110
            img = transform(rgba_loader(join(split, imgfilename)))

            # Run through the model encoder and store the posterior parameters
            mu, logsigma = model.encoder(img.view(1, 4, 256, 256).to(device))
            latents[video_n - 1, frame_n, 0] = mu.squeeze(0)
            latents[video_n - 1, frame_n, 1] = logsigma.squeeze(0)
            del img
            print("Processing frame {} of video number {}".format(frame_n, video_n))
            count += 1
print("Encoded {} frames in total".format(count))
# Save latents to file
torch.save(latents, 'vae_D20200111T092802.pt')
# Reload later with:
# latents = torch.load('vae_D20200111T092802.pt')
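
# A minimal sketch (not part of this script) of how a downstream consumer might
# draw actual latent samples z from the stored posterior parameters via the
# reparameterization trick, assuming the (video, frame, [mu|logsigma], latent)
# layout used above:
# mu, logsigma = latents[:, :, 0], latents[:, :, 1]
# z = mu + logsigma.exp() * torch.randn_like(mu)  # (n_videos, n_frames, n_latents)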