Commit 29ec633f authored by nstorni

VAE training with Pytorch ImageFolder dataset loader

parent b943b978
@@ -7,3 +7,6 @@ Dockerfile
 frames/
 dock_config.json
 .mypy_cache
+exp_*
+lsf.*
+lsd.*
@@ -7,8 +7,10 @@ import torch
 import torch.utils.data
 from torch import optim
 from torch.nn import functional as F
-from torchvision import transforms
+from torchvision import transforms, datasets
 from torchvision.utils import save_image
+from torch.utils.tensorboard import SummaryWriter
 from models.vae import VAE
@@ -42,6 +44,7 @@ torch.backends.cudnn.benchmark = True
 device = torch.device("cuda" if cuda else "cpu")

 transform_train = transforms.Compose([
     transforms.ToPILImage(),
     transforms.Resize((RED_SIZE, RED_SIZE)),
@@ -50,25 +53,29 @@ transform_train = transforms.Compose([
 ])

 transform_test = transforms.Compose([
-    transforms.ToPILImage(),
+    # transforms.ToPILImage(),
     transforms.Resize((RED_SIZE, RED_SIZE)),
     transforms.ToTensor(),
 ])

-dataset_train = RolloutObservationDataset('datasets/carracing',
-                                          transform_train, train=True)
-dataset_test = RolloutObservationDataset('datasets/carracing',
-                                         transform_test, train=False)
+train_dir = '../data/hymenoptera_data/train'
+test_dir = '../data/hymenoptera_data/val'
+dataset_train = datasets.ImageFolder(train_dir, transform_test)
+dataset_test = datasets.ImageFolder(test_dir, transform_test)
+# dataset_train = RolloutObservationDataset('datasets/carracing',
+#                                           transform_train, train=True)
+# dataset_test = RolloutObservationDataset('datasets/carracing',
+#                                          transform_test, train=False)
 train_loader = torch.utils.data.DataLoader(
     dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=2)
 test_loader = torch.utils.data.DataLoader(
     dataset_test, batch_size=args.batch_size, shuffle=True, num_workers=2)

 model = VAE(3, LSIZE).to(device)
 optimizer = optim.Adam(model.parameters())
 scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
-earlystopping = EarlyStopping('min', patience=30)
+earlystopping = EarlyStopping('min', patience=3000)

 # Reconstruction + KL divergence losses summed over all elements and batch
 def loss_function(recon_x, x, mu, logsigma):
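The switch from RolloutObservationDataset to torchvision's ImageFolder is the heart of this commit. ImageFolder expects one subdirectory per class under the root and yields (image, class_index) pairs of already-decoded PIL images, which is also why ToPILImage is commented out of the transform. A minimal sketch of the layout and the sample type (directory and file names are illustrative; RED_SIZE is assumed to be 64 here):

```python
from torchvision import datasets, transforms

# ImageFolder layout:  root/<class_name>/<file>, e.g.
#   ../data/hymenoptera_data/train/ants/img_001.jpg
#   ../data/hymenoptera_data/train/bees/img_002.jpg
transform = transforms.Compose([
    transforms.Resize((64, 64)),   # stands in for (RED_SIZE, RED_SIZE)
    transforms.ToTensor(),         # PIL image -> float tensor in [0, 1]
])

dataset = datasets.ImageFolder('../data/hymenoptera_data/train', transform)
image, label = dataset[0]          # each sample is (transformed image, class index)
print(image.shape, label)          # e.g. torch.Size([3, 64, 64]) 0
print(dataset.classes)             # folder names, e.g. ['ants', 'bees']
```

These labelled pairs are why the training and test loops below change from `for data in loader` to `for inputs, labels in loader`; the VAE only ever uses `inputs`.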
@@ -86,10 +93,14 @@ def loss_function(recon_x, x, mu, logsigma):
 def train(epoch):
     """ One training epoch """
     model.train()
-    dataset_train.load_next_buffer()
+    # dataset_test.load_next_buffer()
     train_loss = 0
-    for batch_idx, data in enumerate(train_loader):
-        data = data.to(device)
+    # for batch_idx, data in enumerate(train_loader):
+    # data = data.to(device)
+    batch_idx = 0
+    for inputs, labels in train_loader:
+        batch_idx += len(labels)
+        data = inputs.to(device)
         optimizer.zero_grad()
         recon_batch, mu, logvar = model(data)
         loss = loss_function(recon_batch, data, mu, logvar)
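The body of `loss_function` sits outside the context shown here. In the upstream world-models script it is the standard VAE objective: a summed reconstruction error plus the KL divergence between the diagonal-Gaussian posterior and a unit Gaussian. A sketch of that usual formulation, matching the `(recon_x, x, mu, logsigma)` signature above (the training loop names the third model output `logvar`, but the function treats it as log-sigma):

```python
import torch
from torch.nn import functional as F

def loss_function(recon_x, x, mu, logsigma):
    """ Reconstruction + KL divergence, summed over all elements and batch """
    # Reconstruction term: pixel-wise MSE, summed rather than averaged
    mse = F.mse_loss(recon_x, x, reduction='sum')
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian with log-std logsigma:
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), with log(sigma^2) = 2*logsigma
    kld = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
    return mse + kld
```

Note also that the rewritten loop advances `batch_idx` by `len(labels)`, so it now counts images rather than batches; any progress printout keyed on `batch_idx` reports sample counts.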
@@ -106,19 +117,22 @@ def train(epoch):
         epoch, train_loss / len(train_loader.dataset)))


-def test():
+def test(epoch):
     """ One test epoch """
     model.eval()
-    dataset_test.load_next_buffer()
+    # dataset_test.load_next_buffer()
     test_loss = 0
     with torch.no_grad():
-        for data in test_loader:
-            data = data.to(device)
+        # for data in test_loader:
+        for inputs, labels in test_loader:
+            data = inputs.to(device)
             recon_batch, mu, logvar = model(data)
             test_loss += loss_function(recon_batch, data, mu, logvar).item()

     test_loss /= len(test_loader.dataset)
     print('====> Test set loss: {:.4f}'.format(test_loss))
+    # writer.add_scalar('Loss/test', test_loss, epoch)
     return test_loss


 # check vae dir exists, if not, create it
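Since the VAE never uses the class labels ImageFolder produces, the unused `labels` variable can be discarded at unpacking time; the new `epoch` argument exists only to feed the commented-out `writer.add_scalar` call. A slightly tidier but equivalent version of the new `test()` (a sketch, not part of the commit):

```python
def test(epoch):
    """ One test epoch """
    model.eval()                               # switch off dropout / batch-norm updates
    test_loss = 0
    with torch.no_grad():                      # no gradients needed for evaluation
        for inputs, _ in test_loader:          # class label is irrelevant to the VAE
            data = inputs.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
    test_loss /= len(test_loader.dataset)      # summed loss -> average per image
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss
```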
@@ -126,6 +140,10 @@ vae_dir = join(args.logdir, 'vae')
 if not exists(vae_dir):
     mkdir(vae_dir)
     mkdir(join(vae_dir, 'samples'))
+    mkdir(join(vae_dir, 'runs'))
+
+# writer = SummaryWriter(log_dir=join(vae_dir, 'runs'))

 reload_file = join(vae_dir, 'best.tar')
 if not args.noreload and exists(reload_file):
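Both the SummaryWriter construction here and the `add_scalar` call in `test()` land in this commit commented out, so TensorBoard logging is wired up but inactive. A minimal, runnable sketch of what enabling it would look like, reusing the `runs` directory created above (`exp_dir/vae` and the loop bodies are stand-ins):

```python
from os.path import join
from torch.utils.tensorboard import SummaryWriter

vae_dir = 'exp_dir/vae'                       # assumed log directory for this sketch
writer = SummaryWriter(log_dir=join(vae_dir, 'runs'))

for epoch in range(1, 4):                     # stand-in for the real epoch loop
    test_loss = 1.0 / epoch                   # stand-in for test(epoch)
    writer.add_scalar('Loss/test', test_loss, epoch)

writer.close()                                # flush event files to disk
```

The curves can then be inspected with `tensorboard --logdir exp_dir/vae/runs`.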
@@ -144,9 +162,9 @@ cur_best = None
 for epoch in range(1, args.epochs + 1):
     train(epoch)
-    test_loss = test()
+    test_loss = test(epoch)
     scheduler.step(test_loss)
-    earlystopping.step(test_loss)
+    # earlystopping.step(test_loss)

     # checkpointing
     best_filename = join(vae_dir, 'best.tar')
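The diff cuts off right at the checkpointing block. In the upstream script, `cur_best` tracks the lowest test loss seen so far, and every epoch writes a rolling `checkpoint.tar` plus `best.tar` when the loss improves; `best.tar` is the file the reload logic above looks for. A sketch of that pattern, assuming the usual state-dict checkpoint layout:

```python
    # inside the epoch loop, after test_loss = test(epoch)
    filename = join(vae_dir, 'checkpoint.tar')
    is_best = cur_best is None or test_loss < cur_best
    if is_best:
        cur_best = test_loss

    checkpoint = {
        'epoch': epoch,
        'precision': test_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }
    torch.save(checkpoint, filename)           # rolling checkpoint, every epoch
    if is_best:
        torch.save(checkpoint, best_filename)  # best.tar, picked up on restart
```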