Commit a155cad2 authored by PhilFischer

Switch dataset and add beta hyperparameter

parent 6f2670bc
@@ -23,7 +23,7 @@ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 SEED = 163
 RANDOM = random.Random(SEED)
-DATAPATH = '/cluster/scratch/fischphi/wikiart/wikiart_small'
+DATAPATH = '/cluster/scratch/fischphi/wikiart/painters'
 ########################################################
@@ -36,7 +36,7 @@ if not os.path.exists(DATAPATH) or not os.path.isdir(DATAPATH):
 if not os.path.exists(os.path.join(DATAPATH, 'dataset.npz')):
-    create_dataset(DATAPATH, size=32)
+    create_dataset(DATAPATH, size=128)
 data, labels, styles = load_data(DATAPATH)
 train, test, s_train, s_test = train_test_split(data, labels, test_size=0.2, stratify=labels, random_state=SEED)
 assert train.shape[0] == s_train.shape[0] and test.shape[0] == s_test.shape[0]
@@ -68,20 +68,21 @@ hparams = [
     hp.HParam('architecture', hp.Discrete(['σ-VAE-1'])),
     hp.HParam('likelihood', hp.Discrete(['sigma'])),
     hp.HParam('log_lr', hp.RealInterval(-5.6, -3.6)),
-    hp.HParam('latent_dim', hp.IntInterval(8, 8)),
-    hp.HParam('feature_layers', hp.IntInterval(2, 3)),
+    hp.HParam('log_beta', hp.RealInterval(1.2, 1.2)),
+    hp.HParam('latent_dim', hp.IntInterval(12, 12)),
+    hp.HParam('feature_layers', hp.IntInterval(2, 2)),
     hp.HParam('style_layers', hp.IntInterval(1, 1)),
-    hp.HParam('encoder_layers', hp.IntInterval(2, 3)),
-    hp.HParam('decoder_layers', hp.IntInterval(2, 4)),
-    hp.HParam('feature_filters', hp.IntInterval(12, 32)),
-    hp.HParam('encoder_filters', hp.IntInterval(8, 20)),
-    hp.HParam('decoder_filters', hp.IntInterval(16, 64)),
+    hp.HParam('encoder_layers', hp.IntInterval(2, 2)),
+    hp.HParam('decoder_layers', hp.IntInterval(4, 5)),
+    hp.HParam('feature_filters', hp.IntInterval(8, 32)),
+    hp.HParam('encoder_filters', hp.IntInterval(8, 32)),
+    hp.HParam('decoder_filters', hp.IntInterval(8, 32)),
     hp.HParam('feature_kernel_size', hp.IntInterval(3, 3)),
     hp.HParam('encoder_kernel_size', hp.IntInterval(3, 3)),
     hp.HParam('decoder_kernel_size', hp.IntInterval(3, 3)),
-    hp.HParam('feature_multiple', hp.IntInterval(1, 2)),
-    hp.HParam('encoder_multiple', hp.IntInterval(1, 2)),
-    hp.HParam('decoder_multiple', hp.IntInterval(1, 2))
+    hp.HParam('feature_multiple', hp.IntInterval(2, 3)),
+    hp.HParam('encoder_multiple', hp.IntInterval(2, 4)),
+    hp.HParam('decoder_multiple', hp.IntInterval(2, 3))
 ]
 metrics = [
     hp.Metric('epoch_loss', group='train', display_name='Train ELBO'),
@@ -91,7 +92,7 @@ metrics = [
 ]
 log_dir = 'logs'
-ex_name = os.path.join('ex03', datetime.now().strftime("%y%m%d-%H%M%S"))
+ex_name = os.path.join('pex01', datetime.now().strftime("%y%m%d-%H%M%S"))
 tuner = Tuner((train, s_train), (test, s_test), define_model, hparams, metrics, log_dir, seed=SEED)
 tuner.tune(ex_name, runs=10, epochs=1000)
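Note on the tuning setup: hp.RealInterval and hp.IntInterval describe the range each hyperparameter is drawn from, so degenerate intervals such as IntInterval(12, 12) or RealInterval(1.2, 1.2) pin a value to a single choice while keeping it visible on TensorBoard's HParams dashboard. A minimal sketch of how one run's configuration could be drawn is shown below; the repo's actual Tuner internals are not part of this diff, so the sample_uniform usage is an assumption.

# Sketch only: a random-search draw from intervals like those above,
# assuming each hp.HParam domain is sampled uniformly (not shown in this diff).
import random
from tensorboard.plugins.hparams import api as hp

sketch_hparams = [
    hp.HParam('log_beta', hp.RealInterval(1.2, 1.2)),   # pinned to 1.2
    hp.HParam('decoder_layers', hp.IntInterval(4, 5)),  # drawn from {4, 5}
]
rng = random.Random(163)  # SEED used by the script
params = {h.name: h.domain.sample_uniform(rng) for h in sketch_hparams}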
@@ -23,10 +23,10 @@ def make_kumaraswamy_distr(cdim, reinterpreted_batch_ndims=1):
 class VAE(Model):
-    def __init__(self, shape, n_styles, latent_dim=32, activation='swish', likelihood='sigma',
+    def __init__(self, shape, n_styles, latent_dim=32, activation='swish', likelihood='sigma', log_beta=0.,
                  encoder_layers=1, encoder_filters=4, encoder_kernel_size=3, encoder_multiple=1,
                  decoder_layers=2, decoder_filters=4, decoder_kernel_size=3, decoder_multiple=1,
-                 feature_layers=3, feature_filters=4, feature_kernel_size=3, feature_multiple=1,
+                 feature_layers=3, feature_filters=4, feature_kernel_size=3, feature_multiple=1,
                  style_layers=1, **kwargs):
         super(VAE, self).__init__()
@@ -38,7 +38,7 @@ class VAE(Model):
         # Define encoder and posterior
         self._prior = Independent(Normal(tf.zeros(latent_dim), scale=1), reinterpreted_batch_ndims=1)
-        self._sampler = IndependentNormal(latent_dim, activity_regularizer=KLDivergenceRegularizer(self._prior))
+        self._sampler = IndependentNormal(latent_dim, activity_regularizer=KLDivergenceRegularizer(self._prior, weight=latent_dim*10**log_beta))
         self._encoder = Encoder(latent_dim=self._sampler.params_size(latent_dim), activation=activation,
                                 layers=encoder_layers, filters=encoder_filters, kernel_size=encoder_kernel_size, multiple=encoder_multiple)
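The added weight turns the plain KL regularizer into a β-VAE-style objective: the loss becomes the reconstruction term plus latent_dim * 10**log_beta times KL(q(z|x) ‖ p(z)), so log_beta trades reconstruction quality against pressure on the latent code. A self-contained sketch of the same construction with explicit tensorflow_probability names follows; the surrounding Encoder/Decoder classes are omitted, and latent_dim/log_beta values are taken from the hyperparameters above.

# Sketch of the beta-weighted KL term added in this commit, written against
# tensorflow_probability directly (not the repo's exact imports).
import tensorflow as tf
import tensorflow_probability as tfp

latent_dim, log_beta = 12, 1.2
prior = tfp.distributions.Independent(
    tfp.distributions.Normal(loc=tf.zeros(latent_dim), scale=1.),
    reinterpreted_batch_ndims=1)
# Each forward pass adds weight * KL(posterior || prior) (a sampled estimate
# by default) to the layer losses, with weight = latent_dim * 10**log_beta;
# log_beta=0. recovers weight=latent_dim.
sampler = tfp.layers.IndependentNormal(
    latent_dim,
    activity_regularizer=tfp.layers.KLDivergenceRegularizer(
        prior, weight=latent_dim * 10**log_beta))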
@@ -55,7 +55,7 @@ class Tuner():
         model = self.model_fn(self.shape, self.n_styles, params)
         checkpoint = ModelCheckpoint(weights_dir, save_best_only=True, save_weights_only=True)
-        stopping = EarlyStopping(patience=30)
+        stopping = EarlyStopping(patience=100)
         tensorboard = TensorBoard(log_dir=run_dir)
         hpboard = hp.KerasCallback(run_dir, params)