Commit eb1ae055 authored by nstorni

Merge branch 'development' into 'lstm-adaptation'

# Conflicts:
#   README.md
parents 0cce655a f28e4f85
@@ -61,6 +61,9 @@ module load gcc/4.8.5 python_cpu/3.7.1
python $HOME/world-models/utils/dataset_statistics.py --imagetype=RGB
```
On the cluster:
```bash
bsub -o $HOME/job_logs -n 4 -R "rusage[mem=4096]" "python $HOME/world-models/utils/dataset_statistics.py --imagetype=RGBA"
```
Move some images from the train folder to the test folder (not really random: just take every file whose name ends in 1), which should amount to about 10% of the train set; see the sketch below.
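A minimal sketch of that selection, assuming PNG frames and the train/nolabel, test/nolabel layout used later in this README (paths and extension are assumptions, adjust to your dataset):
```bash
# Move every frame whose basename ends in 1 (roughly 10% of the files)
cd $SCRATCH/data/$DATASET_NAME
mv train/nolabel/*1.png test/nolabel/
```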
@@ -81,7 +84,7 @@ cp world-models/template_config.json training_configs/train_config1.json
Modify train_config1.json for your training run; you can then test the configuration by starting the training on the local Leonhard instance:
```bash
-$HOME/world-models/start_training.sh --modeldir train_shvae.py --modelconfigdir $HOME/training_configs/train_config1.json
+$HOME/world-models/start_training.sh --modeldir train_lwtvae.py --modelconfigdir $HOME/training_configs/train_config1.json
```
If there are no errors, interrupt the training with CTRL+C; you can now submit the job to the cluster, as sketched below.
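A submission sketch following the bsub pattern used elsewhere in this README (the config file name is an assumption):
```bash
bsub -o $HOME/job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "$HOME/world-models/start_training.sh --modeldir train_lwtvae.py --modelconfigdir $HOME/training_configs/train_config1.json"
```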
@@ -104,6 +107,11 @@ You can access tensorboard in your browser at localhost:6993
The training job will save the models, samples and logs in the $SCRATCH/experiments folder.
You can delete failed runs from the experiments directory with the following script; the first argument is the experiments subdirectory name and the second is the run name (read it from TensorBoard):
```bash
$HOME/world-models/utils/delete_failed_runs.sh shvae_temporalmask lwtvae_l3_D20200112T183921
```
## 5. Generate Moving MNIST dataset
You can generate a custom toy dataset of MNIST digits moving on a black frame and bouncing off the borders with the following script:
```bash
@@ -151,6 +159,39 @@ If you don't get any errors, you can abort the execution and send the job to be
bsub -o job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "$HOME/world-models/start_training.sh --modeldir trainlstm.py --modelconfigdir $HOME/training_configs/rnn_config1.json"
```
## 6. Generate interframe differences dataset
```bash
DATASET_NAME="interdiff_dataset"
mkdir $SCRATCH/data
mkdir $SCRATCH/data/$DATASET_NAME
mkdir $SCRATCH/data/$DATASET_NAME/video
mkdir $SCRATCH/data/$DATASET_NAME/train
mkdir $SCRATCH/data/$DATASET_NAME/train/nolabel
mkdir $SCRATCH/data/$DATASET_NAME/val
mkdir $SCRATCH/data/$DATASET_NAME/val/nolabel
mkdir $SCRATCH/data/$DATASET_NAME/test
mkdir $SCRATCH/data/$DATASET_NAME/test/nolabel
```
This structure is required for the dataset loader to load the pictures: it expects a folder containing one subfolder per class, and here there is only the single class "nolabel" (see the sketch below).
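As a sanity check, a minimal sketch of how such a directory is consumed, assuming the loader is torchvision's ImageFolder (the loader actually used in this repo may differ):
```python
import os
from torchvision import datasets, transforms

# ImageFolder treats each subfolder as a class; with only "nolabel"
# present, every image receives the same dummy label 0.
root = os.path.expandvars("$SCRATCH/data/interdiff_dataset/train")
dataset = datasets.ImageFolder(root, transform=transforms.ToTensor())
print(len(dataset), dataset.classes)  # -> N ['nolabel']
```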
Upload one video to $SCRATCH/data/$DATASET_NAME/video using WinSCP or similar, e.g. copy video files from a local directory (run this ON YOUR LOCAL MACHINE) to the target directory on the cluster:
```bash
scp *.mp4 nstorni@login.leonhard.ethz.ch:$SCRATCH/data/mini_mice_dataset/video
```
Generate interframe differences:
```bash
cd $SCRATCH/data/$DATASET_NAME
module purge
module load gcc/4.8.5 python_gpu/3.7.1
python $HOME/world-models/video_frame/gen_tempdiff_data.py
```
# PyTorch implementation of "World Models"
Paper: Ha and Schmidhuber, "World Models", 2018. https://doi.org/10.5281/zenodo.1207631. For a quick summary of the paper and some additional experiments, visit the [github page](https://ctallec.github.io/world-models/).
......
""" Training VAE """
# bsub -o $HOME/job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "python $HOME/world-models/generate_latents.py"
import argparse
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.utils.data
from torch import optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
from models.lwtvae import VAE
# from utils.misc import save_checkpoint
from PIL import Image
# Custom loader for images with mask in the alpha channel (RGBA)
def rgba_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f).convert("RGBA")
        return img

transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
# Cuda magic
cuda = torch.cuda.is_available()
torch.manual_seed(123)
# Fix numeric divergence due to bug in Cudnn
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if cuda else "cpu")
# Dataset parameters (hardcoded; no argument parsing yet)
n_videos = 9
n_frames = 13000-110
n_latents = 64
# Load model
model = VAE(4, 64).to(device)
reload_file = join(os.path.expandvars("$SCRATCH/experiments/shvae_temporalmask/models/vae_D20200111T092802"), 'best.tar')
state = torch.load(reload_file, map_location=torch.device('cpu'))
model.load_state_dict(state['state_dict'])
# Initialize latents tensor: (video, frame, [mu, logsigma], latent dim)
tensor = torch.ones(())
latents = tensor.new_empty((n_videos, n_frames, 2, n_latents))
print(latents.shape)

count = 0
with torch.no_grad():
    # Iterate over train images
    for imgfilename in os.listdir("train/nolabel"):
        # Extract video and frame numbers from the filename
        tmp = os.path.splitext(imgfilename)[0].split("_")
        video_n = int(tmp[0])
        frame_n = int(tmp[2]) - 110
        img = transforms(rgba_loader(join("train/nolabel", imgfilename)))
        # Run through the model encoder
        mu, logsigma = model.encoder(img.view(1, 4, 256, 256).to(device))
        del img
        latents[video_n-1][frame_n][0] = mu
        latents[video_n-1][frame_n][1] = logsigma
        # print(latents[video_n-1][frame_n])
        print("Processing frame {} of video number {}".format(frame_n, video_n))
        count += 1

    # Iterate over val images
    for imgfilename in os.listdir("val/nolabel"):
        # Extract video and frame numbers from the filename
        tmp = os.path.splitext(imgfilename)[0].split("_")
        video_n = int(tmp[0])
        frame_n = int(tmp[2]) - 110
        img = transforms(rgba_loader(join("val/nolabel", imgfilename)))
        # Run through the model encoder
        mu, logsigma = model.encoder(img.view(1, 4, 256, 256).to(device))
        latents[video_n-1][frame_n][0] = mu
        latents[video_n-1][frame_n][1] = logsigma
        del img
        # print(latents[video_n-1][frame_n])
        print("Processing frame {} of video number {}".format(frame_n, video_n))
        count += 1

print(count)
# Save latents to file
torch.save(latents, 'vae_D20200111T092802.pt')
# Reload latent
# latents = torch.load('tensor.pt')
\ No newline at end of file
""" Training VAE """
# bsub -o $HOME/job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "python $HOME/world-models/generate_latents.py"
import argparse
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.utils.data
from torch import optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
import numpy as np
# Dataset parameters (hardcoded; no argument parsing yet)
n_videos = 9
n_frames = 13000-110
n_latents = 64
# Initialize latents tensor: (video, frame, [mu, logsigma], latent dim)
tensor = torch.ones(())
latents = tensor.new_empty((n_videos, n_frames, 2, n_latents))
print(latents.shape)

# Sample points spanning 120 full sine periods over the 13000-110 frames
a = np.linspace(0, 120*2*np.pi, 13000-110)

# Fill every latent dimension with the same sine wave as mu, and a very
# negative logsigma so that sigma is effectively zero (deterministic latents)
for i in range(n_latents):
    latents[:, :, 0, i] = torch.sin(torch.FloatTensor(a))
    latents[:, :, 1, i] = -100000.0
print(latents[0,0:100,1,0])
# Compute z via the reparameterization trick: z = mu + sigma * eps.
mu = latents[:,:,0]
logsigma = latents[:,:,1]
print(logsigma[1])
sigma = logsigma.exp()
print("sigma")
print(sigma[4])
eps = torch.randn_like(sigma)
z = eps.mul(sigma).add_(mu)
print(z[0,0:200,0])
# Save latents to file
torch.save(latents, 'sin.pt')
# Reload latent
# latents = torch.load('tensor.pt')
\ No newline at end of file
""" Training VAE """
# bsub -o $HOME/job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "python $HOME/world-models/generate_latents.py"
import argparse
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.utils.data
from torch import optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
from models.lwtvae import VAE
# from utils.misc import save_checkpoint
from PIL import Image
import cv2
import numpy as np
# Cuda magic
cuda = torch.cuda.is_available()
torch.manual_seed(123)
# Fix numeric divergence due to bug in Cudnn
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if cuda else "cpu")
# Dataset parameters (hardcoded; no argument parsing yet)
n_videos = 9
n_frames = 13000-110
n_latents = 64
with torch.no_grad():
    # Load model
    model = VAE(4, 64).to(device)
    reload_file = join(os.path.expandvars("$SCRATCH/experiments/shvae_temporalmask/models/vae_D20200111T092802"), 'best.tar')
    state = torch.load(reload_file, map_location=torch.device('cpu'))
    model.load_state_dict(state['state_dict'])

    # Reload latents
    latents = torch.load('vae_D20200111T092802.pt')
    print(latents.shape)

    # Compute z via the reparameterization trick: z = mu + sigma * eps
    mu = latents[:, :, 0]
    logsigma = latents[:, :, 1]
    sigma = logsigma.exp()
    eps = torch.randn_like(sigma)
    z = eps.mul(sigma).add_(mu)

    # out = cv2.VideoWriter('test2.avi',cv2.VideoWriter_fourcc(*'DIVX'), 25, (256,256))
    # out = cv2.VideoWriter('decoded.mp4',cv2.VideoWriter_fourcc(*'MP4V'), 25, (256,256))
    tensor_2_image = transforms.ToPILImage()
    img_array = []
    count = 0
    writer = SummaryWriter("/cluster/scratch/nstorni/experiments/testing/tensorboard_logs")

    # Experiments with writing decoded frames to video (kept for reference):
    # for frame in z[0]:
    #     img = model.decoder(frame)
    #     print((img[0,0:3].numpy()))
    #     print(np.uint8(300*img[0,0:3].numpy()))
    #     # writer.add_image('Z',np.uint8(10*img[0,0:3].numpy()),0)
    #     # writer.close()
    #     # break
    #     count += 1
    #     print(count)
    #     out.write(np.uint8((200*img[0,0:3]).permute(1, 2, 0).numpy()))
    #     # if count > 200: break
    #     # img_array.append(img[0,0:3].numpy())
    #     # img_array.append(tensor_2_image(img[0,0:3]))
    #     # print(img_array.shape)
    # # for i in range(len(img_array)):
    # #     out.write(img_array[i])
    # out.release()

    # Visualize the latents' evolution over frames
    for i in range(9):
        writer.add_image('Z1', z[i].view(1, 12890, 64), i)
        writer.add_image('mu1', mu[i].view(1, 12890, 64), i)
        writer.add_image('logsigma1', logsigma[i].view(1, 12890, 64), i)
    writer.close()
import torch
from torch import nn
from torch.nn import functional as F
# Copyright 2018 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Borrowed from https://github.com/deepmind/sonnet and ported it to PyTorch
class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        embed = torch.randn(dim, n_embed)
        self.register_buffer('embed', embed)
        self.register_buffer('cluster_size', torch.zeros(n_embed))
        self.register_buffer('embed_avg', embed.clone())

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if self.training:
            self.cluster_size.data.mul_(self.decay).add_(
                1 - self.decay, embed_onehot.sum(0)
            )
            embed_sum = flatten.transpose(0, 1) @ embed_onehot
            self.embed_avg.data.mul_(self.decay).add_(1 - self.decay, embed_sum)
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        diff = (quantize.detach() - input).pow(2).mean()
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))
class ResBlock(nn.Module):
    def __init__(self, in_channel, channel):
        super().__init__()

        self.conv = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channel, channel, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, in_channel, 1),
        )

    def forward(self, input):
        out = self.conv(input)
        out += input

        return out
class Encoder(nn.Module):
    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
        super().__init__()

        if stride == 4:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, 3, padding=1),
            ]

        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 3, padding=1),
            ]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.ReLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)
class Decoder(nn.Module):
    def __init__(
        self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
    ):
        super().__init__()

        blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.ReLU(inplace=True))

        if stride == 4:
            blocks.extend(
                [
                    nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
                    nn.ReLU(inplace=True),
                    nn.ConvTranspose2d(
                        channel // 2, out_channel, 4, stride=2, padding=1
                    ),
                ]
            )

        elif stride == 2:
            blocks.append(
                nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
            )

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)
class VQVAE(nn.Module):
    def __init__(
        self,
        in_channel=1,
        channel=128,
        n_res_block=2,
        n_res_channel=32,
        embed_dim=64,
        n_embed=512,
        decay=0.99,
    ):
        super().__init__()

        self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
        self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
        self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1)
        self.quantize_t = Quantize(embed_dim, n_embed)
        self.dec_t = Decoder(
            embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2
        )
        self.quantize_conv_b = nn.Conv2d(embed_dim + channel, embed_dim, 1)
        self.quantize_b = Quantize(embed_dim, n_embed)
        self.upsample_t = nn.ConvTranspose2d(
            embed_dim, embed_dim, 4, stride=2, padding=1
        )
        self.dec = Decoder(
            embed_dim + embed_dim,
            in_channel,
            channel,
            n_res_block,
            n_res_channel,
            stride=4,
        )

    def forward(self, input):
        quant_t, quant_b, diff, _, _ = self.encode(input)
        print(quant_t.shape)
        print(quant_b.shape)
        dec = self.decode(quant_t, quant_b)

        return dec, diff

    def encode(self, input):
        enc_b = self.enc_b(input)
        enc_t = self.enc_t(enc_b)

        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
        quant_t, diff_t, id_t = self.quantize_t(quant_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        diff_t = diff_t.unsqueeze(0)

        dec_t = self.dec_t(quant_t)
        enc_b = torch.cat([dec_t, enc_b], 1)

        quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1)
        quant_b, diff_b, id_b = self.quantize_b(quant_b)
        quant_b = quant_b.permute(0, 3, 1, 2)
        diff_b = diff_b.unsqueeze(0)

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b

    def decode(self, quant_t, quant_b):
        upsample_t = self.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)
        dec = self.dec(quant)

        return dec

    def decode_code(self, code_t, code_b):
        quant_t = self.quantize_t.embed_code(code_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        quant_b = self.quantize_b.embed_code(code_b)
        quant_b = quant_b.permute(0, 3, 1, 2)

        dec = self.decode(quant_t, quant_b)

        return dec
\ No newline at end of file
{
-"exp_name" : "lwtvae_l3",
-"TrainingLevel":3,
-"vaesamples" : 0,
-"interpolation":"False",
-"interpolation_steps":6,
+"exp_name" : "laVAE_latent32",
+"TrainingLevel":5,
+"vaesamples" : 4,
+"interpolation":"True",
+"interpolation_steps":22,
"interpolations_dir":"$HOME/data/interpolations/RGBA",
-"trainFull":"False",
+"trainFull":"True",
"replaceReparametrization":"False",
"learning_rate" : 0.0003,
"batch_size" : 256,
-"reload" : "True",
-"reload_dir": "$SCRATCH/experiments/shvae_temporalmask/models/lwtvae_l2_again_D20200111T235430",
-"epochs" : 400,
+"reload" : "False",
+"reload_dir": "$SCRATCH/experiments/",
+"epochs" : 500,
"loss_log_freq":2,
"img_log_freq":50,
"betavae":1,
"input_dim": 256,
-"latent_dim": 64,
-"early_stopping": "True",
-"weight_decay": "True",
+"latent_dim": 32,
+"early_stopping": "False",
+"weight_decay": "False",
"normalize":"False",
-"input_ch":4,
-"logdir" : "$SCRATCH/experiments/shvae_temporalmask",
+"input_ch":2,
+"logdir" : "$SCRATCH/experiments/laVAE_latents",
"dataset_dir":"$SCRATCH/data/mice_tempdiff_medium",
"train_dataset_dir": "$SCRATCH/data/mice_tempdiff_medium/train",
"val_dataset_dir":"$SCRATCH/data/mice_tempdiff_medium/val",
......
@@ -12,22 +12,16 @@ from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
from models.lwtvae import VAE
from utils.misc import save_checkpoint
from utils.misc import LSIZE, RED_SIZE
## WARNING : THIS SHOULD BE REPLACED WITH PYTORCH 0.5
from utils.learning import EarlyStopping
from utils.learning import ReduceLROnPlateau
from data.loaders import RolloutObservationDataset
from utils.config import get_config_from_json, get_dataset_statistics
from PIL import Image
from datetime import datetime
import json