From 34922029c51257aa86aeac2d58bbc5d071891fb1 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Thu, 27 May 2021 10:57:16 -0400 Subject: [PATCH 1/6] TESTED: rename variables and outputs to be more readable --- train_variational_autoencoder_jax.py | 158 ++++++++++++--------------- 1 file changed, 69 insertions(+), 89 deletions(-) diff --git a/train_variational_autoencoder_jax.py b/train_variational_autoencoder_jax.py index c0b64a8..e9e9b75 100644 --- a/train_variational_autoencoder_jax.py +++ b/train_variational_autoencoder_jax.py @@ -15,13 +15,14 @@ import optax import tensorflow_datasets as tfds from tensorflow_probability.substrates import jax as tfp -import distrax tfd = tfp.distributions +tfb = tfp.bijectors Batch = Mapping[str, np.ndarray] MNIST_IMAGE_SHAPE: Sequence[int] = (28, 28, 1) PRNGKey = jnp.ndarray +Array = jnp.ndarray def add_args(parser): @@ -31,7 +32,7 @@ def add_args(parser): parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--training_steps", type=int, default=30000) parser.add_argument("--log_interval", type=int, default=10000) - parser.add_argument("--num_eval_samples", type=int, default=128) + parser.add_argument("--num_importance_samples", type=int, default=1000) parser.add_argument("--gpu", default=False, action=argparse.BooleanOptionalAction) parser.add_argument("--random_seed", type=int, default=42) @@ -77,13 +78,18 @@ def __init__( ] ) - def __call__(self, x: jnp.ndarray, z: jnp.ndarray) -> Tuple[tfd.Distribution]: + def __call__(self, x: Array, z: Array) -> Array: + """Compute log probability""" p_z = tfd.Normal( loc=jnp.zeros(self._latent_size), scale=jnp.ones(self._latent_size) ) + # sum over latent dimensions + log_p_z = p_z.log_prob(z).sum(-1) logits = self.generative_network(z) p_x_given_z = tfd.Bernoulli(logits=logits) - return p_z, p_x_given_z + # sum over last three image dimensions (width, height, channels) + log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1)) + return log_p_z + log_p_x_given_z class VariationalMeanField(hk.Module): @@ -111,11 +117,15 @@ def condition(self, inputs): scale = jax.nn.softplus(scale_arg) return loc, scale - def __call__(self, x: jnp.ndarray) -> tfd.Distribution: + def __call__(self, x: Array, num_samples: int) -> Tuple[Array, Array]: + """Compute sample and log probability""" loc, scale = self.condition(x) # IMPORTANT: need to check in source code that reparameterization_type=tfd.FULLY_REPARAMETERIZED for this class q_z = tfd.Normal(loc=loc, scale=scale) - return q_z + z = q_z.sample(sample_shape=[num_samples], seed=hk.next_rng_key()) + # sum over latent dimension + log_q_z = q_z.log_prob(z).sum(-1) + return z, log_q_z def make_conditioner( @@ -138,63 +148,41 @@ def make_conditioner( ) -def make_flow( - event_shape: Sequence[int], - num_layers: int, - hidden_sizes: Sequence[int], - num_bins: int, -) -> distrax.Transformed: - """Creates the flow model.""" - # Alternating binary mask. - mask = jnp.arange(0, np.prod(event_shape)) % 2 - mask = jnp.reshape(mask, event_shape) - mask = mask.astype(bool) - - def bijector_fn(params: jnp.array): - return distrax.RationalQuadraticSpline(params, range_min=0.0, range_max=1.0) - - # Number of parameters for the rational-quadratic spline: - # - `num_bins` bin widths - # - `num_bins` bin heights - # - `num_bins + 1` knot slopes - # for a total of `3 * num_bins + 1` parameters. 
- num_bijector_params = 3 * num_bins + 1 - - layers = [] - for _ in range(num_layers): - layer = distrax.MaskedCoupling( - mask=mask, - bijector=bijector_fn, - conditioner=make_conditioner( - event_shape, hidden_sizes, num_bijector_params - ), - ) - layers.append(layer) - # Flip the mask after each layer. - mask = jnp.logical_not(mask) - - # We invert the flow so that the `forward` method is called with `log_prob`. - flow = distrax.Inverse(distrax.Chain(layers)) - base_distribution = distrax.MultivariateNormalDiag( - loc=jnp.zeros(event_shape), scale_diag=jnp.ones(event_shape) - ) - return distrax.Transformed(base_distribution, flow) +class FlowSequential(hk.Sequential): + def __call__(self, inputs, *args, **kwargs): + """Calls all layers sequentially to compute sample and log probability.""" + total_log_prob = jnp.zeros_like(inputs) + out = inputs + for i, layer in enumerate(self.layers): + if i == 0: + out, log_prob = layer(out, *args, **kwargs) + else: + out = layer(out) + total_log_prob += log_prob + return out, total_log_prob + + +class InverseAutoregressiveFlow(hk.Module): + """Uses masked autoregressive networks and a shift scale transform. + Follows Algorithm 1 from the Inverse Autoregressive Flow paper, Kingma et al. (2016) https://arxiv.org/abs/1606.04934. + """ -class VariationalFlow(hk.Module): def __init__(self, latent_size: int, hidden_size: int): super().__init__(name="variational") self._latent_size = latent_size self._hidden_size = hidden_size - - def __call__(self, x: jnp.ndarray) -> distrax.Distribution: - return make_flow( - event_shape=(self._latent_size,), - num_layers=2, - hidden_sizes=[self._hidden_size] * 2, - num_bins=4, + self.encoder = hk.MLP( + output_sizes=[hidden_size, hidden_size, latent_size * 3], + activation=jax.nn.relu, + activate_final=False, ) + def __call__(self, x: Array) -> Tuple[Array, Array]: + """Compute sample and log probability.""" + loc, scale_arg, h = jnp.split(self.encoder(x), 3, axis=-1) + return loc, scale_arg + def main(): start_time = time.time() @@ -202,45 +190,44 @@ def main(): add_args(parser) args = parser.parse_args() rng_seq = hk.PRNGSequence(args.random_seed) - model = hk.transform( - lambda x, z: Model(args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)(x, z) + p_log_prob = hk.transform( + lambda x, z: Model(args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)( + x=x, z=z + ) ) - # variational = hk.transform( - # lambda x: VariationalMeanField(args.latent_size, args.hidden_size)(x) - # ) - variational = hk.transform( - lambda x: VariationalFlow(args.latent_size, args.hidden_size)(x) + q_sample_and_log_prob = hk.transform( + lambda x, num_samples: VariationalMeanField(args.latent_size, args.hidden_size)( + x, num_samples + ) ) - p_params = model.init( + p_params = p_log_prob.init( next(rng_seq), - np.zeros((1, *MNIST_IMAGE_SHAPE)), - np.zeros((1, args.latent_size)), + z=np.zeros((1, args.latent_size)), + x=np.zeros((1, *MNIST_IMAGE_SHAPE)), + ) + q_params = q_sample_and_log_prob.init( + next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)), 1 ) - q_params = variational.init(next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE))) params = hk.data_structures.merge(p_params, q_params) optimizer = optax.rmsprop(args.learning_rate) opt_state = optimizer.init(params) - # @jax.jit - def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray: + @jax.jit + def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> Array: + """Objective function is negative ELBO.""" x = batch["image"] predicate = lambda 
module_name, name, value: "model" in module_name p_params, q_params = hk.data_structures.partition(predicate, params) - q_z = variational.apply(q_params, rng_key, x) - z, log_q_z = q_z.sample_and_log_prob(x, sample_shape=[1], seed=rng_key) - p_z, p_x_given_z = model.apply(p_params, rng_key, x, z) - # sum over last three image dimensions (width, height, channels) - log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1)) - # sum over latent dimension - log_p_z = p_z.log_prob(z).sum(axis=-1) - elbo = log_p_x_given_z + log_p_z - log_q_z + z, log_q_z = q_sample_and_log_prob.apply(q_params, rng_key, x=x, num_samples=1) + log_p_x_z = p_log_prob.apply(p_params, rng_key, x=x, z=z) + elbo = log_p_x_z - log_q_z # average elbo over number of samples elbo = elbo.mean(axis=0) # sum elbo over batch elbo = elbo.sum(axis=0) return -elbo - # @jax.jit + @jax.jit def train_step( params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch ) -> Tuple[hk.Params, optax.OptState]: @@ -250,24 +237,17 @@ def train_step( new_params = optax.apply_updates(params, updates) return new_params, new_opt_state - # @jax.jit + @jax.jit def importance_weighted_estimate( params: hk.Params, rng_key: PRNGKey, batch: Batch - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> Tuple[Array, Array]: """Estimate marginal log p(x) using importance sampling.""" x = batch["image"] - # out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x) predicate = lambda module_name, name, value: "model" in module_name p_params, q_params = hk.data_structures.partition(predicate, params) - q_z = variational.apply(q_params, rng_key, x) - z, log_q_z = q_z.sample_and_log_prob(sample_shape=[args.num_eval_samples], seed=rng_key) - p_z, p_x_given_z = model.apply(p_params, rng_key, x, z) - # log_q_z = q_z.log_prob(z).sum(axis=-1) - # sum over last three image dimensions (width, height, channels) - log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1)) - # sum over latent dimension - log_p_z = p_z.log_prob(z).sum(axis=-1) - elbo = log_p_x_given_z + log_p_z - log_q_z + z, log_q_z = q_sample_and_log_prob.apply(q_params, rng_key, x=x, num_samples=args.num_importance_samples) + log_p_x_z = p_log_prob.apply(p_params, rng_key, x, z) + elbo = log_p_x_z - log_q_z # importance sampling of approximate marginal likelihood with q(z) # as the proposal, and logsumexp in the sample dimension log_p_x = jax.nn.logsumexp(elbo, axis=0) - jnp.log(jnp.shape(elbo)[0]) From d86c60bf0952778d7d02eb88aaa38d225320d9c5 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Tue, 1 Jun 2021 15:46:36 -0400 Subject: [PATCH 2/6] remove partitioning of params tip from https://github.com/deepmind/dm-haiku/issues/128 --- .env | 10 +++-- train_variational_autoencoder_jax.py | 58 +++++++++++++++++++++------- 2 files changed, 51 insertions(+), 17 deletions(-) diff --git a/.env b/.env index 0c97e34..d203f2a 100644 --- a/.env +++ b/.env @@ -1,9 +1,13 @@ # dev.env - development configuration # suppress warnings for jax -JAX_PLATFORM_NAME=cpu +# JAX_PLATFORM_NAME=cpu # suppress tensorflow warnings -TF_CPP_MIN_LOG_LEVEL=2 +TF_CPP_MIN_LOG_LEVEL=3 -TFDS_DATA_DIR=/scratch/gpfs/altosaar/tensorflow_datasets \ No newline at end of file +# set tensorflow data directory +TFDS_DATA_DIR=/scratch/gpfs/altosaar/tensorflow_datasets + +# disable JIT for debugging +JAX_DISABLE_JIT=1 \ No newline at end of file diff --git a/train_variational_autoencoder_jax.py b/train_variational_autoencoder_jax.py index e9e9b75..e269675 100644 --- a/train_variational_autoencoder_jax.py 
+++ b/train_variational_autoencoder_jax.py @@ -7,6 +7,7 @@ import pathlib from calendar import c from typing import Generator, Mapping, NamedTuple, Sequence, Tuple +from distrax import Inverse import numpy as np import jax @@ -81,7 +82,7 @@ def __init__( def __call__(self, x: Array, z: Array) -> Array: """Compute log probability""" p_z = tfd.Normal( - loc=jnp.zeros(self._latent_size), scale=jnp.ones(self._latent_size) + loc=jnp.zeros(self._latent_size, dtype=jnp.float32), scale=jnp.ones(self._latent_size, dtype=jnp.float32) ) # sum over latent dimensions log_p_z = p_z.log_prob(z).sum(-1) @@ -162,7 +163,7 @@ def __call__(self, inputs, *args, **kwargs): return out, total_log_prob -class InverseAutoregressiveFlow(hk.Module): +class VariationalFlow(hk.Module): """Uses masked autoregressive networks and a shift scale transform. Follows Algorithm 1 from the Inverse Autoregressive Flow paper, Kingma et al. (2016) https://arxiv.org/abs/1606.04934. @@ -172,16 +173,41 @@ def __init__(self, latent_size: int, hidden_size: int): super().__init__(name="variational") self._latent_size = latent_size self._hidden_size = hidden_size - self.encoder = hk.MLP( + self.encoder = hk.nets.MLP( output_sizes=[hidden_size, hidden_size, latent_size * 3], activation=jax.nn.relu, activate_final=False, ) + self.first_block = InverseAutoregressiveFlow(latent_size, hidden_size) + self.second_block = InverseAutoregressiveFlow(latent_size, hidden_size) - def __call__(self, x: Array) -> Tuple[Array, Array]: + def __call__(self, x: Array, num_samples: int) -> Tuple[Array, Array]: """Compute sample and log probability.""" loc, scale_arg, h = jnp.split(self.encoder(x), 3, axis=-1) - return loc, scale_arg + q_z0 = tfd.Normal(loc=loc, scale=jax.nn.softplus(scale_arg)) + z0 = q_z0.sample(sample_shape=[num_samples], seed=hk.next_rng_key()) + log_q_z0 = q_z0.log_prob(z0).sum(-1) + z1, log_det_q_z1 = self.first_block(z0, context=h) + z2, log_det_q_z2 = self.second_block(z1, context=h) + return z2, log_q_z0 + log_det_q_z1 + log_det_q_z2 + + +class InverseAutoregressiveFlow(hk.Module): + def __init__(self, latent_size: int, hidden_size: int): + super().__init__() + self.made = tfb.AutoregressiveNetwork( + params=latent_size, + hidden_units=[hidden_size, hidden_size], + conditional=True, + conditional_event_shape=latent_size, + activation=jax.nn.relu, + ) + + def __call__(self, input: Array, context: Array): + m, s = jnp.split(self.made(input, conditional_input=context), 2, axis=-1) + sigmoid = jax.nn.sigmoid(s) + z = sigmoid * input + (1 - sigmoid) * m + return z, -jax.nn.log_sigmoid(s).sum(-1) def main(): @@ -189,6 +215,8 @@ def main(): parser = argparse.ArgumentParser() add_args(parser) args = parser.parse_args() + print(args) + print("jax_disable_jit: ", jax.config.read('jax_disable_jit')) rng_seq = hk.PRNGSequence(args.random_seed) p_log_prob = hk.transform( lambda x, z: Model(args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)( @@ -202,22 +230,23 @@ def main(): ) p_params = p_log_prob.init( next(rng_seq), - z=np.zeros((1, args.latent_size)), - x=np.zeros((1, *MNIST_IMAGE_SHAPE)), + z=np.zeros((1, args.latent_size), dtype=np.float32), + x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32), ) q_params = q_sample_and_log_prob.init( - next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)), 1 + next(rng_seq), + x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32), + num_samples=1 ) - params = hk.data_structures.merge(p_params, q_params) optimizer = optax.rmsprop(args.learning_rate) + params = (p_params, q_params) opt_state = 
optimizer.init(params) @jax.jit def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> Array: """Objective function is negative ELBO.""" x = batch["image"] - predicate = lambda module_name, name, value: "model" in module_name - p_params, q_params = hk.data_structures.partition(predicate, params) + p_params, q_params = params z, log_q_z = q_sample_and_log_prob.apply(q_params, rng_key, x=x, num_samples=1) log_p_x_z = p_log_prob.apply(p_params, rng_key, x=x, z=z) elbo = log_p_x_z - log_q_z @@ -243,9 +272,10 @@ def importance_weighted_estimate( ) -> Tuple[Array, Array]: """Estimate marginal log p(x) using importance sampling.""" x = batch["image"] - predicate = lambda module_name, name, value: "model" in module_name - p_params, q_params = hk.data_structures.partition(predicate, params) - z, log_q_z = q_sample_and_log_prob.apply(q_params, rng_key, x=x, num_samples=args.num_importance_samples) + p_params, q_params = params + z, log_q_z = q_sample_and_log_prob.apply( + q_params, rng_key, x=x, num_samples=args.num_importance_samples + ) log_p_x_z = p_log_prob.apply(p_params, rng_key, x, z) elbo = log_p_x_z - log_q_z # importance sampling of approximate marginal likelihood with q(z) From 7e4f8eea5cbaac5977e6abdf606ec307de248475 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Tue, 1 Jun 2021 17:55:36 -0400 Subject: [PATCH 3/6] WIP - bug in elbo for inverse autoregressive flow --- train_variational_autoencoder_jax.py | 185 +++++++++++++++++++-------- 1 file changed, 134 insertions(+), 51 deletions(-) diff --git a/train_variational_autoencoder_jax.py b/train_variational_autoencoder_jax.py index e269675..4413338 100644 --- a/train_variational_autoencoder_jax.py +++ b/train_variational_autoencoder_jax.py @@ -4,13 +4,11 @@ import time import argparse -import pathlib -from calendar import c -from typing import Generator, Mapping, NamedTuple, Sequence, Tuple -from distrax import Inverse +from typing import Generator, Mapping, Sequence, Tuple, Optional import numpy as np import jax +from jax import lax import haiku as hk import jax.numpy as jnp import optax @@ -23,7 +21,6 @@ Batch = Mapping[str, np.ndarray] MNIST_IMAGE_SHAPE: Sequence[int] = (28, 28, 1) PRNGKey = jnp.ndarray -Array = jnp.ndarray def add_args(parser): @@ -79,10 +76,11 @@ def __init__( ] ) - def __call__(self, x: Array, z: Array) -> Array: + def __call__(self, x: jnp.ndarray, z: jnp.ndarray) -> jnp.ndarray: """Compute log probability""" p_z = tfd.Normal( - loc=jnp.zeros(self._latent_size, dtype=jnp.float32), scale=jnp.ones(self._latent_size, dtype=jnp.float32) + loc=jnp.zeros(self._latent_size, dtype=jnp.float32), + scale=jnp.ones(self._latent_size, dtype=jnp.float32), ) # sum over latent dimensions log_p_z = p_z.log_prob(z).sum(-1) @@ -118,7 +116,9 @@ def condition(self, inputs): scale = jax.nn.softplus(scale_arg) return loc, scale - def __call__(self, x: Array, num_samples: int) -> Tuple[Array, Array]: + def __call__( + self, x: jnp.ndarray, num_samples: int + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Compute sample and log probability""" loc, scale = self.condition(x) # IMPORTANT: need to check in source code that reparameterization_type=tfd.FULLY_REPARAMETERIZED for this class @@ -129,26 +129,6 @@ def __call__(self, x: Array, num_samples: int) -> Tuple[Array, Array]: return z, log_q_z -def make_conditioner( - event_shape: Sequence[int], hidden_sizes: Sequence[int], num_bijector_params: int -) -> hk.Sequential: - """Creates an MLP conditioner for each layer of the flow.""" - return hk.Sequential( - [ - 
hk.Flatten(preserve_dims=-len(event_shape)), - hk.nets.MLP(hidden_sizes, activate_final=True), - # We initialize this linear layer to zero so that the flow is initialized - # to the identity function. - hk.Linear( - np.prod(event_shape) * num_bijector_params, - w_init=jnp.zeros, - b_init=jnp.zeros, - ), - hk.Reshape(tuple(event_shape) + (num_bijector_params,), preserve_dims=-1), - ] - ) - - class FlowSequential(hk.Sequential): def __call__(self, inputs, *args, **kwargs): """Calls all layers sequentially to compute sample and log probability.""" @@ -171,42 +151,145 @@ class VariationalFlow(hk.Module): def __init__(self, latent_size: int, hidden_size: int): super().__init__(name="variational") - self._latent_size = latent_size - self._hidden_size = hidden_size - self.encoder = hk.nets.MLP( - output_sizes=[hidden_size, hidden_size, latent_size * 3], - activation=jax.nn.relu, - activate_final=False, + self.encoder = hk.Sequential( + [ + hk.Flatten(), + hk.Linear(hidden_size), + jax.nn.relu, + hk.Linear(hidden_size), + jax.nn.relu, + hk.Linear(latent_size * 3), + ] ) self.first_block = InverseAutoregressiveFlow(latent_size, hidden_size) self.second_block = InverseAutoregressiveFlow(latent_size, hidden_size) - def __call__(self, x: Array, num_samples: int) -> Tuple[Array, Array]: + def __call__( + self, x: jnp.ndarray, num_samples: int + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Compute sample and log probability.""" loc, scale_arg, h = jnp.split(self.encoder(x), 3, axis=-1) q_z0 = tfd.Normal(loc=loc, scale=jax.nn.softplus(scale_arg)) z0 = q_z0.sample(sample_shape=[num_samples], seed=hk.next_rng_key()) + h = jnp.expand_dims(h, axis=0) # needed for the new sample dimension in z0 log_q_z0 = q_z0.log_prob(z0).sum(-1) z1, log_det_q_z1 = self.first_block(z0, context=h) z2, log_det_q_z2 = self.second_block(z1, context=h) return z2, log_q_z0 + log_det_q_z1 + log_det_q_z2 +class MaskedLinear(hk.Module): + """Masked Linear module. + + TODO: fix initialization according to number of inputs per unit + (can compute this from the mask). 
+ """ + + def __init__( + self, + mask: jnp.ndarray, + output_size: int, + with_bias: bool = True, + w_init: Optional[hk.initializers.Initializer] = None, + b_init: Optional[hk.initializers.Initializer] = None, + name: Optional[str] = None, + ): + super().__init__(name=name) + self.input_size = None + self.output_size = output_size + self.with_bias = with_bias + self.w_init = w_init + self.b_init = b_init or jnp.zeros + self._mask = mask + + def __call__( + self, + inputs: jnp.ndarray, + *, + precision: Optional[lax.Precision] = None, + ) -> jnp.ndarray: + """Computes a masked linear transform of the input.""" + if not inputs.shape: + raise ValueError("Input must not be scalar.") + + input_size = self.input_size = inputs.shape[-1] + output_size = self.output_size + dtype = inputs.dtype + + w_init = self.w_init + if w_init is None: + stddev = 1.0 / np.sqrt(self.input_size) + w_init = hk.initializers.TruncatedNormal(stddev=stddev) + w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init) + + out = jnp.dot(inputs, w * self._mask, precision=precision) + + if self.with_bias: + b = hk.get_parameter("b", [self.output_size], dtype, init=self.b_init) + b = jnp.broadcast_to(b, out.shape) + out = out + b + + return out + + +class MaskedAndConditionalLinear(hk.Module): + """Assumes the conditional inputs have same size as inputs.""" + + def __init__(self, mask: jnp.ndarray, output_size: int): + super().__init__() + self.masked_linear = MaskedLinear(mask, output_size) + self.conditional_linear = hk.Linear(output_size, with_bias=False) + + def __call__( + self, inputs: jnp.ndarray, conditional_inputs: jnp.ndarray + ) -> jnp.ndarray: + return self.masked_linear(inputs) + self.conditional_linear(conditional_inputs) + + +class MADE(hk.Module): + """Masked Autoregressive Distribution Estimator. + + From https://arxiv.org/abs/1502.03509 + + conditional_input specifies whether every layer of the network will be + conditioned on an additional input. 
+ The additional input is conditioned on using a linear transformation + (that does not use a mask) + """ + + def __init__(self, input_size: int, hidden_size: int, num_outputs_per_input: int): + super().__init__() + masks = tfb.masked_autoregressive._make_dense_autoregressive_masks( + params=num_outputs_per_input, # shift and log scale are "parameters"; non-standard naming + event_size=input_size, + hidden_units=[hidden_size, hidden_size], + input_order="left-to-right", + hidden_degrees="equal", + ) + self._input_size = input_size + self.first_net = MaskedAndConditionalLinear(masks[0], hidden_size) + self.second_net = MaskedAndConditionalLinear(masks[1], hidden_size) + # multiply by two for the shift and log scale + self.final_net = MaskedAndConditionalLinear(masks[2], input_size * 2) + + def __call__(self, inputs, conditional_inputs): + outputs = jax.nn.relu(self.first_net(inputs, conditional_inputs)) + outputs = jax.nn.relu(self.second_net(outputs, conditional_inputs)) + return self.final_net(outputs, conditional_inputs) + + class InverseAutoregressiveFlow(hk.Module): def __init__(self, latent_size: int, hidden_size: int): super().__init__() - self.made = tfb.AutoregressiveNetwork( - params=latent_size, - hidden_units=[hidden_size, hidden_size], - conditional=True, - conditional_event_shape=latent_size, - activation=jax.nn.relu, + # two outputs per latent input: shift and log scale parameter + self.made = MADE( + input_size=latent_size, hidden_size=hidden_size, num_outputs_per_input=2 ) - def __call__(self, input: Array, context: Array): - m, s = jnp.split(self.made(input, conditional_input=context), 2, axis=-1) + def __call__(self, inputs: jnp.ndarray, context: jnp.ndarray): + m, s = jnp.split(self.made(inputs, conditional_inputs=context), 2, axis=-1) sigmoid = jax.nn.sigmoid(s) - z = sigmoid * input + (1 - sigmoid) * m + z = sigmoid * inputs + (1 - sigmoid) * m return z, -jax.nn.log_sigmoid(s).sum(-1) @@ -216,7 +299,7 @@ def main(): add_args(parser) args = parser.parse_args() print(args) - print("jax_disable_jit: ", jax.config.read('jax_disable_jit')) + print("jax_disable_jit: ", jax.config.read("jax_disable_jit")) rng_seq = hk.PRNGSequence(args.random_seed) p_log_prob = hk.transform( lambda x, z: Model(args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)( @@ -224,7 +307,7 @@ def main(): ) ) q_sample_and_log_prob = hk.transform( - lambda x, num_samples: VariationalMeanField(args.latent_size, args.hidden_size)( + lambda x, num_samples: VariationalFlow(args.latent_size, args.hidden_size)( x, num_samples ) ) @@ -234,16 +317,16 @@ def main(): x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32), ) q_params = q_sample_and_log_prob.init( - next(rng_seq), - x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32), - num_samples=1 + next(rng_seq), + x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32), + num_samples=1, ) optimizer = optax.rmsprop(args.learning_rate) params = (p_params, q_params) opt_state = optimizer.init(params) @jax.jit - def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> Array: + def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray: """Objective function is negative ELBO.""" x = batch["image"] p_params, q_params = params @@ -269,7 +352,7 @@ def train_step( @jax.jit def importance_weighted_estimate( params: hk.Params, rng_key: PRNGKey, batch: Batch - ) -> Tuple[Array, Array]: + ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Estimate marginal log p(x) using importance sampling.""" x = batch["image"] p_params, q_params = params From 
6fc12c802293e067f5a51235d4eb735f1c76f3c8 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Wed, 2 Jun 2021 12:55:45 -0400 Subject: [PATCH 4/6] TESTED: jax matches pytorch for IAF and VAE --- README.md | 16 +- flow.py | 247 +++++++++++------------ masks.py | 181 +++++++++++++++++ train_variational_autoencoder_jax.py | 74 +++---- train_variational_autoencoder_pytorch.py | 4 +- 5 files changed, 353 insertions(+), 169 deletions(-) create mode 100644 masks.py diff --git a/README.md b/README.md index c7c651f..deaab34 100644 --- a/README.md +++ b/README.md @@ -44,10 +44,22 @@ step: 30000 valid elbo: -103.76 valid log p(x): -97.71 Using jax (anaconda environment is in `environment-jax.yml`), to get a 3x speedup over pytorch: ``` -$ python train_variational_autoencoder_jax.py --gpu +$ python train_variational_autoencoder_jax.py --variational mean-field Step 0 Train ELBO estimate: -566.059 Validation ELBO estimate: -565.755 Validation log p(x) estimate: -557.914 Speed: 2.56e+11 examples/s Step 10000 Train ELBO estimate: -98.560 Validation ELBO estimate: -105.725 Validation log p(x) estimate: -98.973 Speed: 7.03e+04 examples/s Step 20000 Train ELBO estimate: -109.794 Validation ELBO estimate: -105.756 Validation log p(x) estimate: -97.914 Speed: 4.26e+04 examples/s Step 29999 Test ELBO estimate: -104.867 Test log p(x) estimate: -96.716 Total time: 0.810 minutes -``` \ No newline at end of file +``` + +Inverse autoregressive flow in jax: +``` +$ python train_variational_autoencoder_jax.py --variational flow +Step 0 Train ELBO estimate: -727.404 Validation ELBO estimate: -726.977 Validation log p(x) estimate: -713.389 Speed: 2.56e+11 examples/s +Step 10000 Train ELBO estimate: -100.093 Validation ELBO estimate: -106.985 Validation log p(x) estimate: -99.565 Speed: 2.57e+04 examples/s +Step 20000 Train ELBO estimate: -113.073 Validation ELBO estimate: -108.057 Validation log p(x) estimate: -98.841 Speed: 3.37e+04 examples/s +Step 29999 Test ELBO estimate: -106.803 Test log p(x) estimate: -97.620 +Total time: 2.350 minutes +``` + +(The difference between a mean field and inverse autoregressive flow may be due to several factors, chief being the lack of convolutions in the implementation. Residual blocks are used in https://arxiv.org/pdf/1606.04934.pdf to get the ELBO closer to -80 nats.) \ No newline at end of file diff --git a/flow.py b/flow.py index b3b33e7..16751b2 100644 --- a/flow.py +++ b/flow.py @@ -4,149 +4,134 @@ import torch.nn as nn from torch.nn import functional as F +import masks + class InverseAutoregressiveFlow(nn.Module): - """Inverse Autoregressive Flows with LSTM-type update. One block. - - Eq 11-14 of https://arxiv.org/abs/1606.04934 - """ - def __init__(self, num_input, num_hidden, num_context): - super().__init__() - self.made = MADE(num_input=num_input, num_output=num_input * 2, - num_hidden=num_hidden, num_context=num_context) - # init such that sigmoid(s) is close to 1 for stability - self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2) - self.sigmoid = nn.Sigmoid() - self.log_sigmoid = nn.LogSigmoid() - - def forward(self, input, context=None): - m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1) - s = s + self.sigmoid_arg_bias - sigmoid = self.sigmoid(s) - z = sigmoid * input + (1 - sigmoid) * m - return z, -self.log_sigmoid(s) + """Inverse Autoregressive Flows with LSTM-type update. One block. 
+ + Eq 11-14 of https://arxiv.org/abs/1606.04934 + """ + + def __init__(self, num_input, num_hidden, num_context): + super().__init__() + self.made = MADE( + num_input=num_input, + num_outputs_per_input=2, + num_hidden=num_hidden, + num_context=num_context, + ) + # init such that sigmoid(s) is close to 1 for stability + self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2) + self.sigmoid = nn.Sigmoid() + self.log_sigmoid = nn.LogSigmoid() + + def forward(self, input, context=None): + m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1) + s = s + self.sigmoid_arg_bias + sigmoid = self.sigmoid(s) + z = sigmoid * input + (1 - sigmoid) * m + return z, -self.log_sigmoid(s) class FlowSequential(nn.Sequential): - """Forward pass.""" + """Forward pass.""" - def forward(self, input, context=None): - total_log_prob = torch.zeros_like(input, device=input.device) - for block in self._modules.values(): - input, log_prob = block(input, context) - total_log_prob += log_prob - return input, total_log_prob + def forward(self, input, context=None): + total_log_prob = torch.zeros_like(input, device=input.device) + for block in self._modules.values(): + input, log_prob = block(input, context) + total_log_prob += log_prob + return input, total_log_prob class MaskedLinear(nn.Module): - """Linear layer with some input-output connections masked.""" - def __init__(self, in_features, out_features, mask, context_features=None, bias=True): - super().__init__() - self.linear = nn.Linear(in_features, out_features, bias) - self.register_buffer("mask", mask) - if context_features is not None: - self.cond_linear = nn.Linear(context_features, out_features, bias=False) - - def forward(self, input, context=None): - output = F.linear(input, self.mask * self.linear.weight, self.linear.bias) - if context is None: - return output - else: - return output + self.cond_linear(context) + """Linear layer with some input-output connections masked.""" + def __init__( + self, in_features, out_features, mask, context_features=None, bias=True + ): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias) + self.register_buffer("mask", mask) + if context_features is not None: + self.cond_linear = nn.Linear(context_features, out_features, bias=False) -class MADE(nn.Module): - """Implements MADE: Masked Autoencoder for Distribution Estimation. - - Follows https://arxiv.org/abs/1502.03509 - - This is used to build MAF: Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057). 
- """ - def __init__(self, num_input, num_output, num_hidden, num_context): - super().__init__() - # m corresponds to m(k), the maximum degree of a node in the MADE paper - self._m = [] - self._masks = [] - self._build_masks(num_input, num_output, num_hidden, num_layers=3) - self._check_masks() - modules = [] - self.input_context_net = MaskedLinear(num_input, num_hidden, self._masks[0], num_context) - modules.append(nn.ReLU()) - modules.append(MaskedLinear(num_hidden, num_hidden, self._masks[1], context_features=None)) - modules.append(nn.ReLU()) - modules.append(MaskedLinear(num_hidden, num_output, self._masks[2], context_features=None)) - self.net = nn.Sequential(*modules) - - def _build_masks(self, num_input, num_output, num_hidden, num_layers): - """Build the masks according to Eq 12 and 13 in the MADE paper.""" - rng = np.random.RandomState(0) - # assign input units a number between 1 and D - self._m.append(np.arange(1, num_input + 1)) - for i in range(1, num_layers + 1): - # randomly assign maximum number of input nodes to connect to - if i == num_layers: - # assign output layer units a number between 1 and D - m = np.arange(1, num_input + 1) - assert num_output % num_input == 0, "num_output must be multiple of num_input" - self._m.append(np.hstack([m for _ in range(num_output // num_input)])) - else: - # assign hidden layer units a number between 1 and D-1 - self._m.append(rng.randint(1, num_input, size=num_hidden)) - #self._m.append(np.arange(1, num_hidden + 1) % (num_input - 1) + 1) - if i == num_layers: - mask = self._m[i][None, :] > self._m[i - 1][:, None] - else: - # input to hidden & hidden to hidden - mask = self._m[i][None, :] >= self._m[i - 1][:, None] - # need to transpose for torch linear layer, shape (num_output, num_input) - self._masks.append(torch.from_numpy(mask.astype(np.float32).T)) - - def _check_masks(self): - """Check that the connectivity matrix between layers is lower triangular.""" - # (num_input, num_hidden) - prev = self._masks[0].t() - for i in range(1, len(self._masks)): - # num_hidden is second axis - prev = prev @ self._masks[i].t() - final = prev.numpy() - num_input = self._masks[0].shape[1] - num_output = self._masks[-1].shape[0] - assert final.shape == (num_input, num_output) - if num_output == num_input: - assert np.triu(final).all() == 0 - else: - for submat in np.split(final, - indices_or_sections=num_output // num_input, - axis=1): - assert np.triu(submat).all() == 0 - - def forward(self, input, context=None): - # first hidden layer receives input and context - hidden = self.input_context_net(input, context) - # rest of the network is conditioned on both input and context - return self.net(hidden) + def forward(self, input, context=None): + output = F.linear(input, self.mask * self.linear.weight, self.linear.bias) + if context is None: + return output + else: + return output + self.cond_linear(context) +class MADE(nn.Module): + """Implements MADE: Masked Autoencoder for Distribution Estimation. + + Follows https://arxiv.org/abs/1502.03509 + + This is used to build MAF: Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057). 
+ """ + + def __init__(self, num_input, num_outputs_per_input, num_hidden, num_context): + super().__init__() + # m corresponds to m(k), the maximum degree of a node in the MADE paper + self._m = [] + degrees = masks.create_degrees( + input_size=num_input, + hidden_units=[num_hidden] * 2, + input_order="left-to-right", + hidden_degrees="equal", + ) + self._masks = masks.create_masks(degrees) + self._masks[-1] = np.hstack( + [self._masks[-1] for _ in range(num_outputs_per_input)] + ) + self._masks = [torch.from_numpy(m.T) for m in self._masks] + modules = [] + self.input_context_net = MaskedLinear( + num_input, num_hidden, self._masks[0], num_context + ) + self.net = nn.Sequential( + nn.ReLU(), + MaskedLinear(num_hidden, num_hidden, self._masks[1], context_features=None), + nn.ReLU(), + MaskedLinear( + num_hidden, + num_outputs_per_input * num_input, + self._masks[2], + context_features=None, + ), + ) + + def forward(self, input, context=None): + # first hidden layer receives input and context + hidden = self.input_context_net(input, context) + # rest of the network is conditioned on both input and context + return self.net(hidden) + class Reverse(nn.Module): - """ An implementation of a reversing layer from - Density estimation using Real NVP - (https://arxiv.org/abs/1605.08803). - - From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py - """ - - def __init__(self, num_input): - super(Reverse, self).__init__() - self.perm = np.array(np.arange(0, num_input)[::-1]) - self.inv_perm = np.argsort(self.perm) - - def forward(self, inputs, context=None, mode='forward'): - if mode == "forward": - return inputs[:, :, self.perm], torch.zeros_like(inputs, device=inputs.device) - elif mode == "inverse": - return inputs[:, :, self.inv_perm], torch.zeros_like(inputs, device=inputs.device) - else: - raise ValueError("Mode must be one of {forward, inverse}.") - - + """An implementation of a reversing layer from + Density estimation using Real NVP + (https://arxiv.org/abs/1605.08803). 
+ + From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py + """ + + def __init__(self, num_input): + super(Reverse, self).__init__() + self.perm = np.array(np.arange(0, num_input)[::-1]) + self.inv_perm = np.argsort(self.perm) + + def forward(self, inputs, context=None, mode="forward"): + if mode == "forward": + return inputs[:, :, self.perm], torch.zeros_like( + inputs, device=inputs.device + ) + elif mode == "inverse": + return inputs[:, :, self.inv_perm], torch.zeros_like( + inputs, device=inputs.device + ) + else: + raise ValueError("Mode must be one of {forward, inverse}.") diff --git a/masks.py b/masks.py new file mode 100644 index 0000000..e8c7b17 --- /dev/null +++ b/masks.py @@ -0,0 +1,181 @@ +import numpy as np + +"""Use utility functions from https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/masked_autoregressive.py +""" + + +def create_input_order(input_size, input_order="left-to-right"): + """Returns a degree vectors for the input.""" + if input_order == "left-to-right": + return np.arange(start=1, stop=input_size + 1) + elif input_order == "right-to-left": + return np.arange(start=input_size, stop=0, step=-1) + elif input_order == "random": + ret = np.arange(start=1, stop=input_size + 1) + np.random.shuffle(ret) + return ret + + +def create_degrees( + input_size, hidden_units, input_order="left-to-right", hidden_degrees="equal" +): + input_order = create_input_order(input_size, input_order) + degrees = [input_order] + for units in hidden_units: + if hidden_degrees == "random": + # samples from: [low, high) + degrees.append( + np.random.randint( + low=min(np.min(degrees[-1]), input_size - 1), + high=input_size, + size=units, + ) + ) + elif hidden_degrees == "equal": + min_degree = min(np.min(degrees[-1]), input_size - 1) + degrees.append( + np.maximum( + min_degree, + # Evenly divide the range `[1, input_size - 1]` in to `units + 1` + # segments, and pick the boundaries between the segments as degrees. + np.ceil( + np.arange(1, units + 1) * (input_size - 1) / float(units + 1) + ).astype(np.int32), + ) + ) + return degrees + + +def create_masks(degrees): + """Returns a list of binary mask matrices enforcing autoregressivity.""" + return [ + # Create input->hidden and hidden->hidden masks. + inp[:, np.newaxis] <= out + for inp, out in zip(degrees[:-1], degrees[1:]) + ] + [ + # Create hidden->output mask. 
+ degrees[-1][:, np.newaxis] + < degrees[0] + ] + + +def check_masks(masks): + """Check that the connectivity matrix between layers is lower triangular.""" + # (num_input, num_hidden) + prev = masks[0].t() + for i in range(1, len(masks)): + # num_hidden is second axis + prev = prev @ masks[i].t() + final = prev.numpy() + num_input = masks[0].shape[1] + num_output = masks[-1].shape[0] + assert final.shape == (num_input, num_output) + if num_output == num_input: + assert np.triu(final).all() == 0 + else: + for submat in np.split( + final, indices_or_sections=num_output // num_input, axis=1 + ): + assert np.triu(submat).all() == 0 + + +def build_random_masks(num_input, num_output, num_hidden, num_layers): + """Build the masks according to Eq 12 and 13 in the MADE paper.""" + # assign input units a number between 1 and D + rng = np.random.RandomState(0) + m_list, masks = [], [] + m_list.append(np.arange(1, num_input + 1)) + for i in range(1, num_layers + 1): + if i == num_layers: + # assign output layer units a number between 1 and D + m = np.arange(1, num_input + 1) + assert ( + num_output % num_input == 0 + ), "num_output must be multiple of num_input" + m_list.append(np.hstack([m for _ in range(num_output // num_input)])) + else: + # assign hidden layer units a number between 1 and D-1 + # i.e. randomly assign maximum number of input nodes to connect to + m_list.append(rng.randint(1, num_input, size=num_hidden)) + if i == num_layers: + mask = m_list[i][None, :] > m_list[i - 1][:, None] + else: + # input to hidden & hidden to hidden + mask = m_list[i][None, :] >= m_list[i - 1][:, None] + # need to transpose for torch linear layer, shape (num_output, num_input) + masks.append(mask.astype(np.float32).T) + return masks + + +def _compute_neighborhood(system_size): + """Compute (system_size, neighborhood_size) array.""" + num_variables = system_size ** 2 + arange = np.arange(num_variables) + grid = arange.reshape((system_size, system_size)) + self_and_neighbors = np.zeros((system_size, system_size, 5), dtype=int) + # four nearest-neighbors + self_and_neighbors = np.zeros((system_size, system_size, 5), dtype=int) + self_and_neighbors[..., 0] = grid + neighbor_index = 1 + for axis in [0, 1]: + for shift in [-1, 1]: + self_and_neighbors[..., neighbor_index] = np.roll( + grid, shift=shift, axis=axis + ) + neighbor_index += 1 + # reshape to (num_latent, num_neighbors) + self_and_neighbors = self_and_neighbors.reshape(num_variables, -1) + return self_and_neighbors + + +def build_neighborhood_indicator(system_size): + """Boolean indicator of (num_variables, num_variables) for whether nodes are neighbors.""" + neighborhood = _compute_neighborhood(system_size) + num_variables = system_size ** 2 + mask = np.zeros((num_variables, num_variables), dtype=bool) + for i in range(len(mask)): + mask[i, neighborhood[i]] = True + return mask + + +def build_deterministic_mask(num_variables, num_input, num_output, mask_type): + if mask_type == "input": + in_degrees = np.arange(num_input) % num_variables + else: + in_degrees = np.arange(num_input) % (num_variables - 1) + + if mask_type == "output": + out_degrees = np.arange(num_output) % num_variables + mask = np.expand_dims(out_degrees, -1) > np.expand_dims(in_degrees, 0) + else: + out_degrees = np.arange(num_output) % (num_variables - 1) + mask = np.expand_dims(out_degrees, -1) >= np.expand_dims(in_degrees, 0) + + return mask, in_degrees, out_degrees + + +def build_masks(num_variables, num_input, num_output, num_hidden, mask_fn): + input_mask, _, _ = 
mask_fn(num_variables, num_input, num_hidden, "input") + hidden_mask, _, _ = mask_fn(num_variables, num_hidden, num_hidden, "hidden") + output_mask, _, _ = mask_fn(num_variables, num_hidden, num_output, "output") + masks = [input_mask, hidden_mask, output_mask] + masks = [torch.from_numpy(x.astype(np.float32)) for x in masks] + return masks + + +def build_neighborhood_mask(num_variables, num_input, num_output, mask_type): + system_size = int(np.sqrt(num_variables)) + # return context mask for input, with same assignment of m(k) maximum node degree + mask, in_degrees, out_degrees = build_deterministic_mask( + system_size ** 2, num_input, num_output, mask_type + ) + neighborhood = _compute_neighborhood(system_size) + neighborhood_mask = np.zeros_like(mask) # shape len(out_degrees), len(in_degrees) + for i in range(len(neighborhood_mask)): + neighborhood_indicator = np.isin(in_degrees, neighborhood[out_degrees[i]]) + neighborhood_mask[i, neighborhood_indicator] = True + return mask * neighborhood_mask, in_degrees, out_degrees + + +def checkerboard(shape): + return (np.indices(shape).sum(0) % 2).astype(np.float32) diff --git a/train_variational_autoencoder_jax.py b/train_variational_autoencoder_jax.py index 4413338..7c023c8 100644 --- a/train_variational_autoencoder_jax.py +++ b/train_variational_autoencoder_jax.py @@ -15,6 +15,8 @@ import tensorflow_datasets as tfds from tensorflow_probability.substrates import jax as tfp +import masks + tfd = tfp.distributions tfb = tfp.bijectors @@ -24,6 +26,7 @@ def add_args(parser): + parser.add_argument("--variational", choices=["flow", "mean-field"]) parser.add_argument("--latent_size", type=int, default=128) parser.add_argument("--hidden_size", type=int, default=512) parser.add_argument("--learning_rate", type=float, default=0.001) @@ -31,7 +34,6 @@ def add_args(parser): parser.add_argument("--training_steps", type=int, default=30000) parser.add_argument("--log_interval", type=int, default=10000) parser.add_argument("--num_importance_samples", type=int, default=1000) - parser.add_argument("--gpu", default=False, action=argparse.BooleanOptionalAction) parser.add_argument("--random_seed", type=int, default=42) @@ -129,20 +131,6 @@ def __call__( return z, log_q_z -class FlowSequential(hk.Sequential): - def __call__(self, inputs, *args, **kwargs): - """Calls all layers sequentially to compute sample and log probability.""" - total_log_prob = jnp.zeros_like(inputs) - out = inputs - for i, layer in enumerate(self.layers): - if i == 0: - out, log_prob = layer(out, *args, **kwargs) - else: - out = layer(out) - total_log_prob += log_prob - return out, total_log_prob - - class VariationalFlow(hk.Module): """Uses masked autoregressive networks and a shift scale transform. 
@@ -158,7 +146,7 @@ def __init__(self, latent_size: int, hidden_size: int): jax.nn.relu, hk.Linear(hidden_size), jax.nn.relu, - hk.Linear(latent_size * 3), + hk.Linear(latent_size * 3, w_init=jnp.zeros, b_init=jnp.zeros), ] ) self.first_block = InverseAutoregressiveFlow(latent_size, hidden_size) @@ -235,10 +223,10 @@ def __call__( class MaskedAndConditionalLinear(hk.Module): """Assumes the conditional inputs have same size as inputs.""" - def __init__(self, mask: jnp.ndarray, output_size: int): + def __init__(self, mask: jnp.ndarray, output_size: int, **kwargs): super().__init__() - self.masked_linear = MaskedLinear(mask, output_size) - self.conditional_linear = hk.Linear(output_size, with_bias=False) + self.masked_linear = MaskedLinear(mask, output_size, **kwargs) + self.conditional_linear = hk.Linear(output_size, with_bias=False, **kwargs) def __call__( self, inputs: jnp.ndarray, conditional_inputs: jnp.ndarray @@ -259,36 +247,50 @@ class MADE(hk.Module): def __init__(self, input_size: int, hidden_size: int, num_outputs_per_input: int): super().__init__() - masks = tfb.masked_autoregressive._make_dense_autoregressive_masks( - params=num_outputs_per_input, # shift and log scale are "parameters"; non-standard naming - event_size=input_size, - hidden_units=[hidden_size, hidden_size], + self._num_outputs_per_input = num_outputs_per_input + degrees = masks.create_degrees( + input_size=input_size, + hidden_units=[hidden_size] * 2, input_order="left-to-right", hidden_degrees="equal", ) + self._masks = masks.create_masks(degrees) + self._masks[-1] = np.hstack( + [self._masks[-1] for _ in range(num_outputs_per_input)] + ) self._input_size = input_size - self.first_net = MaskedAndConditionalLinear(masks[0], hidden_size) - self.second_net = MaskedAndConditionalLinear(masks[1], hidden_size) + self._first_net = MaskedAndConditionalLinear(self._masks[0], hidden_size) + self._second_net = MaskedAndConditionalLinear(self._masks[1], hidden_size) # multiply by two for the shift and log scale - self.final_net = MaskedAndConditionalLinear(masks[2], input_size * 2) + # initialize weights and biases to zero to init to the identity function + self._final_net = MaskedAndConditionalLinear( + self._masks[2], + input_size * num_outputs_per_input, + w_init=jnp.zeros, + b_init=jnp.zeros, + ) def __call__(self, inputs, conditional_inputs): - outputs = jax.nn.relu(self.first_net(inputs, conditional_inputs)) - outputs = jax.nn.relu(self.second_net(outputs, conditional_inputs)) - return self.final_net(outputs, conditional_inputs) + outputs = jax.nn.relu(self._first_net(inputs, conditional_inputs)) + outputs = outputs[::-1] # reverse + outputs = jax.nn.relu(self._second_net(outputs, conditional_inputs)) + outputs = outputs[::-1] # reverse + outputs = self._final_net(outputs, conditional_inputs) + return jnp.split(outputs, self._num_outputs_per_input, axis=-1) class InverseAutoregressiveFlow(hk.Module): def __init__(self, latent_size: int, hidden_size: int): super().__init__() # two outputs per latent input: shift and log scale parameter - self.made = MADE( + self._made = MADE( input_size=latent_size, hidden_size=hidden_size, num_outputs_per_input=2 ) def __call__(self, inputs: jnp.ndarray, context: jnp.ndarray): - m, s = jnp.split(self.made(inputs, conditional_inputs=context), 2, axis=-1) - sigmoid = jax.nn.sigmoid(s) + m, s = self._made(inputs, conditional_inputs=context) + # initialize sigmoid argument bias so the output is close to 1 + sigmoid = jax.nn.sigmoid(s + 2.0) z = sigmoid * inputs + (1 - sigmoid) * m return 
z, -jax.nn.log_sigmoid(s).sum(-1) @@ -299,15 +301,19 @@ def main(): add_args(parser) args = parser.parse_args() print(args) - print("jax_disable_jit: ", jax.config.read("jax_disable_jit")) + print("Is jax using @jit decorators?", not jax.config.read("jax_disable_jit")) rng_seq = hk.PRNGSequence(args.random_seed) p_log_prob = hk.transform( lambda x, z: Model(args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)( x=x, z=z ) ) + if args.variational == "mean-field": + variational = VariationalMeanField + elif args.variational == "flow": + variational = VariationalFlow q_sample_and_log_prob = hk.transform( - lambda x, num_samples: VariationalFlow(args.latent_size, args.hidden_size)( + lambda x, num_samples: variational(args.latent_size, args.hidden_size)( x, num_samples ) ) diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index 3489835..be6d4e2 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -23,9 +23,9 @@ def add_args(parser): parser.add_argument("--learning_rate", type=float, default=0.001) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--test_batch_size", type=int, default=512) - parser.add_argument("--max_iterations", type=int, default=100000) + parser.add_argument("--max_iterations", type=int, default=30000) parser.add_argument("--log_interval", type=int, default=10000) - parser.add_argument("--n_samples", type=int, default=128) + parser.add_argument("--n_samples", type=int, default=1000) parser.add_argument("--use_gpu", action="store_true") parser.add_argument("--seed", type=int, default=582838) parser.add_argument("--train_dir", type=pathlib.Path, default="/tmp") From 2dd2a786edf8018ffe7bf5c2d5713e802cfbc198 Mon Sep 17 00:00:00 2001 From: Jaan Altosaar Date: Thu, 11 Nov 2021 12:51:34 -0500 Subject: [PATCH 5/6] Update README.md --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index deaab34..cf47eef 100644 --- a/README.md +++ b/README.md @@ -62,4 +62,10 @@ Step 29999 Test ELBO estimate: -106.803 Test log p(x) estimate: -97.620 Total time: 2.350 minutes ``` -(The difference between a mean field and inverse autoregressive flow may be due to several factors, chief being the lack of convolutions in the implementation. Residual blocks are used in https://arxiv.org/pdf/1606.04934.pdf to get the ELBO closer to -80 nats.) \ No newline at end of file +(The difference between a mean field and inverse autoregressive flow may be due to several factors, chief being the lack of convolutions in the implementation. Residual blocks are used in https://arxiv.org/pdf/1606.04934.pdf to get the ELBO closer to -80 nats.) + +# Generating the GIFs + +1. Run `python train_variational_autoencoder_tensorflow.py` +2. Install imagemagick (homebrew for Mac: https://formulae.brew.sh/formula/imagemagick or Chocolatey in Windows: https://community.chocolatey.org/packages/imagemagick.app) +3. 
Go to the directory where the jpg files are saved, and run the imagemagick command to generate the .gif: `convert -delay 20 -loop 0 *.jpg latent-space.gif` From 7e8c661d1d91d415c0222a96f791ee64b7ce7a23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jaan=20L=C4=B1=20=E6=9D=8E=20PhD?= Date: Wed, 24 Apr 2024 07:45:21 -0600 Subject: [PATCH 6/6] Update README.md --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index cf47eef..012cef1 100644 --- a/README.md +++ b/README.md @@ -69,3 +69,8 @@ Total time: 2.350 minutes 1. Run `python train_variational_autoencoder_tensorflow.py` 2. Install imagemagick (homebrew for Mac: https://formulae.brew.sh/formula/imagemagick or Chocolatey in Windows: https://community.chocolatey.org/packages/imagemagick.app) 3. Go to the directory where the jpg files are saved, and run the imagemagick command to generate the .gif: `convert -delay 20 -loop 0 *.jpg latent-space.gif` +4. + +## TODO (help needed - feel free to send a PR!) +- add multiple GPU / TPU option +- add jaxtyping support for PyTorch and Jax implementations :) for runtime static type checking (using @beartype decorators)
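A rough illustration of the last TODO item above (runtime shape checking with `jaxtyping` + `beartype`) — a minimal sketch only, not part of any patch in this series, assuming recent releases of both libraries; the function name and shape labels are hypothetical:

```python
# Hypothetical sketch for the jaxtyping/beartype TODO item; not part of the
# patch series above. Assumes jaxtyping >= 0.2.24 and beartype are installed.
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=beartype)
def elbo_per_example(
    log_p_x_z: Float[Array, "num_samples batch"],
    log_q_z: Float[Array, "num_samples batch"],
) -> Float[Array, "batch"]:
    """Average the per-example ELBO over the sample dimension."""
    return (log_p_x_z - log_q_z).mean(axis=0)


# Shapes are checked at call time; mismatched sample or batch dimensions raise an error.
per_example_elbo = elbo_per_example(jnp.zeros((8, 128)), jnp.zeros((8, 128)))  # shape (128,)
```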