Notes on using tensorflow-probability in bVAE

These notes are a little bit outdated now. I need to review.

Tensorflow-probability (tfp) provides a few tools that simplify writing a VAE.

  • In the VAE using vanilla tensorflow, the input to our decoder uses a trick that mimics drawing a sample from a distribution parameterized by our latent vector. With tfp, we can make our latent vector an actual distribution. No trick needed!
  • tfp provides the ability to apply KL-regularization directly on our latent distributions.
  • We can make the reconstructed signal a distribution, and the reconstruction loss is the negative log-likelihood of the input given the reconstruction distribution.
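
For reference, the "trick" in the first bullet is usually called the reparameterization trick. Below is a minimal plain-Python sketch of it (no TensorFlow; the function name and the numbers are made up for illustration): the sample is written as a deterministic function of the distribution parameters plus external noise, which is what lets gradients reach the parameters.

```python
import random

def reparameterized_sample(mu, sigma, rng):
    """Draw z ~ Normal(mu, sigma) as mu + sigma * eps with eps ~ Normal(0, 1).

    The randomness lives entirely in eps, so z is a differentiable
    function of mu and sigma.
    """
    eps = rng.gauss(0.0, 1.0)
    return mu + sigma * eps

# Hypothetical latent parameters for a single latent dimension.
rng = random.Random(0)
samples = [reparameterized_sample(2.0, 0.5, rng) for _ in range(20000)]
sample_mean = sum(samples) / len(samples)
sample_var = sum((s - sample_mean) ** 2 for s in samples) / len(samples)
```

With a tfp distribution as the latent, this sampling step happens inside the distribution object instead of being written by hand.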

Resources

TF Probability homepage (with link to VAE blog post).

TF Probability example of disentangled VAE from Li and Mandt, ICML 2018.

A Note About TFP Distribution Shapes

  • Event shape describes the shape of a single draw from the distribution; it may be dependent across dimensions. For scalar distributions, the event shape is []. For a bivariate distribution, it is [2], and for a 5-dimensional MultivariateNormal, the event shape is [5].
  • Batch shape describes independent, not identically distributed draws, aka a "batch" of distributions.
  • Sample shape describes independent, identically distributed draws of batches from the distribution family.
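
The three shapes combine in a fixed order. The sketch below is plain Python bookkeeping, not TFP code, and the helper names are hypothetical. It shows the shape I would expect from sample() and log_prob(), and what wrapping a distribution in tfd.Independent does to the batch/event split.

```python
def sample_shape_of(sample_shape, batch_shape, event_shape):
    """Shape of dist.sample(sample_shape): sample dims, then batch, then event."""
    return list(sample_shape) + list(batch_shape) + list(event_shape)

def log_prob_shape_of(sample_shape, batch_shape):
    """Shape of dist.log_prob(draws): the event dims are reduced away."""
    return list(sample_shape) + list(batch_shape)

def reinterpret_as_event(batch_shape, event_shape, ndims):
    """Shape effect of tfd.Independent(dist, reinterpreted_batch_ndims=ndims):
    the rightmost ndims batch dims are moved into the event shape."""
    split = len(batch_shape) - ndims
    return list(batch_shape[:split]), list(batch_shape[split:]) + list(event_shape)

# A batch of 3 five-dimensional multivariate normals, sampled 10 times.
draws = sample_shape_of([10], [3], [5])    # [10, 3, 5]
log_probs = log_prob_shape_of([10], [3])   # [10, 3]

# A Normal with batch shape [3, 5] treated as 3 independent 5-d events.
new_batch, new_event = reinterpret_as_event([3, 5], [], 1)  # ([3], [5])
```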

Define the latent prior

What distribution do we assume the latent variables should follow? In the case of variational autoencoders we typically assume that the latents are (a) Gaussian, (b) have mean=0, and (c) have diagonal covariance. The following demonstrates multiple ways to create such a distribution, with fixed parameters.

Prior - Off-diagonal covariance?

I guess that depends on if the latent posterior variables should be regularized against having off-diagonals. If the goal is for the latents to describe an orthonormal space, at least as much as possible without sacrificing model quality, then the prior should not have any off-diagonal covariances. Indeed, I have never seen a prior that was not set to be diagonal. If the posterior is not allowed off-diagonals then definitely do not put off-diagonals in the prior.

Prior - Fixed or Learnable?

Do we want to enforce these static priors? Or do we want to allow the priors to update as the model trains? The answer depends primarily on if we think that we have a good prior. There is a small discussion about this in the tfp regression tutorial in Case 3:

Note that in this example we are training both P(w) (prior) and Q(w) (posterior). This training corresponds to using Empirical Bayes or Type-II Maximum Likelihood. We used this method so that we wouldn’t need to specify the location of the prior for the slope and intercept parameters, which can be tough to get right if we do not have prior knowledge about the problem. Moreover, if you set the priors very far from their true values, then the posterior may be unduly affected by this choice. A caveat of using Type-II Maximum Likelihood is that you lose some of the regularization benefits over the weights. If you wanted to do a proper Bayesian treatment of uncertainty (if you had some prior knowledge, or a more sophisticated prior), you could use a non-trainable prior (see Appendix B).

The prior is fixed in many of the TFP introductory examples. This is because the intro examples meet conditions under which it doesn't matter if the prior is trainable. See here for more info.


The next cell, in the last example, implements a learnable multivariate normal distribution.

tfp.util.TransformedVariable: Variable tracking object which applies function upon convert_to_tensor

tfp.bijectors.FillScaleTriL: Transforms unconstrained vectors to TriL matrices with positive diagonal.

I have also seen this implemented as a custom class.

Or as a callable that returns another callable to be passed to tfpl.DenseVariational's make_prior_fn argument.

Encoder

The input is transformed into some lower dimensional latent variable. In this case we use a Bidirectional(LSTM). This may not be the right layer for you. We use it here because it is similar to what is used in the disentangled_vae example in the tfp source code.

Then the latent variables are used to parameterize a distribution. It's arguable whether the distribution is part of the encoder, the decoder, or something in between, but we will put it in the encoder.

Off-diagonal covariance in latent?

In the tfp example VAE scripts, all the latents were MultivariateNormalDiag, i.e. no off-diagonal covariances. However, in the VAE with TFP blog post, the encoded latents used MultivariateNormalTriL, and thus were allowed off-diagonals (though the prior was not). Allowing off-diagonals also increases the number of parameters in the model, which might increase the amount of data needed to fit it.
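
To get a feel for the size of that increase, here is a small plain-Python count of the distribution parameters the encoder has to produce per input (the helper names are hypothetical): a diagonal Gaussian needs a loc and a scale per dimension, while a lower-triangular scale factor needs d * (d + 1) / 2 entries on top of the loc.

```python
def diag_params(d):
    """loc (d values) + scale_diag (d values)."""
    return 2 * d

def tril_params(d):
    """loc (d values) + lower-triangular scale factor (d * (d + 1) / 2 values)."""
    return d + d * (d + 1) // 2

for latent_dim in (2, 8, 32):
    print(latent_dim, diag_params(latent_dim), tril_params(latent_dim))
```

The triangular count grows quadratically with the latent dimension, so the difference is negligible for a handful of latents and substantial for dozens.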

Mixture of distributions or single distributions?

While this applies to different distribution families as well, we are using Normal distributions. Each independent latent can be modeled as a single Normal or a mixture of Normals. When using a mixture, analytic KL divergence won't work, and more data is required to fit the additional parameters.

I have never used mixture of Gaussians, but the below snippet is a demonstration of how that might work:

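# Assumes: import tensorflow as tf; import tensorflow_probability as tfp; tfd = tfp.distributions
# mixture_components and latent_size are ints defined elsewhere.
tfd.MixtureSameFamily(
      components_distribution=tfd.MultivariateNormalDiag(
          loc=tf.Variable(tf.random.normal([mixture_components, latent_size])),
          scale_diag=tf.nn.softplus(
              tf.Variable(tf.zeros([mixture_components, latent_size])))),
      mixture_distribution=tfd.Categorical(
          logits=tf.Variable(tf.zeros([mixture_components]))),
      name="prior")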

tfp.layers.MixtureNormal(num_components, event_shape=(LATENT_DIM,))

Additional transform on covariance

For the Normal distributions, the initialized value for the mean (loc) is typically centered on 0.0, and the value for the std (scale) is typically centered on 1.0. When these values are changing with training (from previous layer or tf.Variable in the case of a trainable prior), care should be taken so that the learnable variables are centered around their expected initial values and are of similar magnitude. I believe the training machinery works better under these conditions.

For the loc there is nothing to do because it is already centered at 0 and there are no requirements for it to be positive.

For the scale, we want the loss to update values that are by default centered on 0, but when the distribution is sampled, the stddev should be centered around 1. Also, we have to be careful that the stddev doesn't go negative. Scale can be transformed to meet these requirements by adding a bias and passing the result through softplus. Thus, the inputs to the distribution's scale argument are around 0 (at least initially), then shifted by np.log(np.exp(1) - 1) (= softplus_inverse(1.0) ~ 0.54), then softplus transformed, and finally shifted by 1e-5 (to force > 0) to yield the value that will parameterize the dist stddev.

_loc = tfkl.Dense(LATENT_DIM)(_features)
_scale_diag = tfkl.Dense(LATENT_DIM)(_features)
_scale_diag = _scale_diag + np.log(np.exp(1) - 1)
_scale_diag = tf.math.softplus(_scale_diag) + 1e-5
_static_sample = tfpl.DistributionLambda(
    make_distribution_fn=lambda t: tfd.Independent(
        tfd.Normal(loc=t[0], scale=t[1])
    ),
)([_loc, _scale_diag])
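
A quick numeric check of the shift-then-softplus transform used in the cell above, written in plain Python so it runs without TensorFlow (raw_to_scale is a made-up name for the transform): a raw value of 0 should map to a stddev of about 1, and very negative raw values should still give a positive stddev.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

# softplus_inverse(1.0), the bias added before the softplus (about 0.5413).
SHIFT = math.log(math.exp(1.0) - 1.0)

def raw_to_scale(raw):
    """Map an unconstrained layer output to a strictly positive stddev."""
    return softplus(raw + SHIFT) + 1e-5

scale_at_zero = raw_to_scale(0.0)           # about 1.00001
scale_very_negative = raw_to_scale(-50.0)   # tiny, but still > 0
```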

In the case of a trainable prior, we can initialize the tfp.util.TransformedVariable for the scale to be around 1 and use a bijector equivalent to tf.nn.softplus(_ + np.log(np.exp(1) - 1)) + 1e-5 before sampling. It's a little confusing that the TransformedVariable is initialized with its transformed value, while the stored variable value (i.e., the one subject to training) is obtained by inverse-transforming the initialization value through the bijector. See the make_mvn_prior function definition for an example.
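
The initialization behaviour described above can be illustrated with a plain-Python stand-in for the bijector (forward and inverse are hypothetical names, and this is not the TFP API): the value handed to the constructor is the constrained scale, and the value that actually gets trained is its inverse transform.

```python
import math

SHIFT = math.log(math.exp(1.0) - 1.0)  # softplus_inverse(1.0)

def forward(raw):
    """Stored (trainable) value -> constrained scale."""
    return math.log1p(math.exp(raw + SHIFT)) + 1e-5

def inverse(scale):
    """Constrained scale -> stored value; only valid for scale > 1e-5."""
    return math.log(math.expm1(scale - 1e-5)) - SHIFT

initial_scale = 1.0               # what we pass as the initial value
stored = inverse(initial_scale)   # what is actually kept and trained, close to 0
round_trip = forward(stored)      # what the distribution sees, back to 1.0
```

So an initial scale of 1 corresponds to a stored value very close to 0, which is the "centered on 0" condition discussed earlier.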

Number of samples from distributions?

For any given pass through the model, the distributions can be sampled multiple times. For example, on the output distribution, we can get N_SAMPLES different reconstructions, but then we must either calculate the error for each sample (e.g., using 'mse') and take the average error, or calculate the log-probability for each sample and take the log of the average probability: tf.reduce_logsumexp(elbo) - tf.math.log(n_samples).

_static_sample = tfpl.DistributionLambda(
    make_distribution_fn=lambda t: tfd.Independent(tfd.Normal(loc=t[0], scale=t[1])),
    convert_to_tensor_fn=lambda s: s.sample(N_SAMPLES)
)([_loc, _scale_diag])
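
The logsumexp expression mentioned above is the log of the mean probability across samples, computed without leaving log space. A minimal plain-Python sketch (log_mean_exp is a made-up helper and the log-probabilities are hypothetical values) that contrasts it with simply averaging the log-probabilities:

```python
import math

def log_mean_exp(log_values):
    """log(mean(exp(v))) computed stably; equals logsumexp(v) - log(n)."""
    m = max(log_values)
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))

# Hypothetical per-sample log-probabilities for one input.
per_sample_log_prob = [-1050.0, -1052.0, -1049.5]

avg_log_prob = sum(per_sample_log_prob) / len(per_sample_log_prob)
log_avg_prob = log_mean_exp(per_sample_log_prob)
```

By Jensen's inequality the log of the average probability is never smaller than the average log-probability, so the two choices give different loss values. Subtracting the maximum first matters because exp(-1050) underflows to 0 in floating point.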

KL Divergence

The latent posterior distribution is regularized to resemble the prior distribution by penalizing the KL divergence between the posterior and the prior. There are several ways to do this.

  1. Add the latent distribution to the model outputs, then use loss functions for each output to penalize KL divergence from the prior. While the reconstruction's loss function will remain -recon_dist.log_prob(expected_recon), the latent dist can use
    1. Analytic KL: lambda _, latent_dist: tfd.kl_divergence(latent_dist, prior)
    2. Quantitative KL (i.e., a sampled, Monte Carlo estimate): Need a func that accepts (true_x, latent_dist), samples latent_dist to get latent_sample, and returns latent_dist.log_prob(latent_sample) - prior.log_prob(latent_sample).
  2. Add KL regularizer directly to latent distribution (my preferred approach):
    posterior = tfpl.SomeLayer(...,
                               activity_regularizer=tfpl.KLDivergenceRegularizer(
                                   prior, weight=KL_WEIGHT)
                              )
    
  3. Using tfpl.KLDivergenceAddLoss(prior).
    • This currently does not work; see here.
  4. Add KL divergence loss manually
    • kl_loss = tfd.kl_divergence(latent, prior)
    • model.add_loss(weight * tf.reduce_mean(kl_loss, axis=0)) # add_loss takes no axis argument, so reduce over the batch first if the shapes require it.
  5. Calculate KL loss in custom training calculation.
    • See custom training in this notebook.

Analytic KL (or kl_exact) works only when a closed-form KL is available for the pair of latent and prior dists, which in practice usually means they are of the same type. It does not work for mixture models.
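
To make the difference between the analytic and the sampled KL concrete, here is a plain-Python sketch for diagonal Gaussians. The function names and parameter values are made up for illustration. The closed form is the standard per-dimension Gaussian KL, and the sampled version averages log q(z) - log p(z) over draws from q.

```python
import math
import random

def normal_log_prob(x, loc, scale):
    return -0.5 * math.log(2.0 * math.pi * scale ** 2) - 0.5 * ((x - loc) / scale) ** 2

def analytic_kl_diag(q_loc, q_scale, p_loc, p_scale):
    """KL(q || p) for diagonal Gaussians: sum of per-dimension closed forms."""
    kl = 0.0
    for ql, qs, pl, ps in zip(q_loc, q_scale, p_loc, p_scale):
        kl += math.log(ps / qs) + (qs ** 2 + (ql - pl) ** 2) / (2.0 * ps ** 2) - 0.5
    return kl

def sampled_kl_diag(q_loc, q_scale, p_loc, p_scale, n_samples, rng):
    """Monte Carlo estimate: mean of log q(z) - log p(z) with z drawn from q."""
    total = 0.0
    for _ in range(n_samples):
        for ql, qs, pl, ps in zip(q_loc, q_scale, p_loc, p_scale):
            z = rng.gauss(ql, qs)
            total += normal_log_prob(z, ql, qs) - normal_log_prob(z, pl, ps)
    return total / n_samples

# Hypothetical 2-d posterior against a standard-normal prior.
q_loc, q_scale = [0.5, -1.0], [0.8, 1.5]
p_loc, p_scale = [0.0, 0.0], [1.0, 1.0]

kl_exact = analytic_kl_diag(q_loc, q_scale, p_loc, p_scale)
kl_mc = sampled_kl_diag(q_loc, q_scale, p_loc, p_scale, 50000, random.Random(0))
```

The sampled estimate only needs log_prob and sample, which is why it remains usable for mixtures where no closed form exists, at the cost of extra noise in the loss.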

In this notebook I demonstrate 1, 2, and 5. 1 has weight=0 so it isn't actually used, but nevertheless it is still convenient to have Keras print out its values during training. 5 is coded but the custom training loop is commented out. Ultimately, 2 is what is used to update the latent.

Weighting by number of samples

If we allow the KL divergence loss to be weighted too heavily then the model will prioritize matching the prior more than solving the output objective. This is especially problematic when we do not have a learnable prior. I looked to available examples to see what the conventions were. But this left me more confused.

  • In Regression blog post and accompanying google colab: rv_y.variational_loss(y, kl_weight=np.array(batch_size, x.dtype) / x.shape[0])
  • In vae blog post: activity_regularizer=tfpl.KLDivergenceRegularizer(prior, weight=1.0))
  • In vae example script:
    • weighted_elbo = tf.reduce_logsumexp(elbo) - tf.math.log(n_samples)
    • But the above is not used. If it were: loss = -tf.reduce_mean(weighted_elbo)
  • In kernel_divergence_fn kwarg for tfpl.DenseFlipout in logistic_regression example: kernel_divergence_fn=lambda q, p, _: tfd.kl_divergence(q, p) / tf.cast(num_samples, dtype=tf.float32)
  • In API docs for tfpl.KLDivergenceRegularizer or tfpl.KLDivergenceAddLoss, example code sets weight=num_train_samples. Isn't this the opposite of the other examples?
  • In disentangled vae example: Not done!
  • In LFADS: "normalized only by batch size, not by dimension or by time steps" - implicit in tf.reduce_mean().

In my confusion I posted a question to the TensorFlow Probability mailing list. Someone answered pointing me to other similar conversations. As best as I can understand, it seems that the conventional scaling to apply to the KL divergence term is (batch_size / number_of_samples). Upon further inspection, I think that depends on which of the above methods of adding the KL divergence loss is used.
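
One way I can reconcile some of the conventions listed above, offered as an assumption rather than a statement of what each example intends: the scale factor depends on whether the reconstruction term is summed or averaged over the batch, and on whether the KL term is a single global quantity (as for weight posteriors in the regression examples) or one value per example (as for VAE latents). The plain-Python arithmetic below uses hypothetical function names and numbers.

```python
def loss_sum_over_batch(nll_per_example, kl_global, n_train):
    """Batch NLL summed; the single global KL is scaled by batch_size / n_train."""
    batch_size = len(nll_per_example)
    return sum(nll_per_example) + (batch_size / n_train) * kl_global

def loss_mean_over_batch(nll_per_example, kl_global, n_train):
    """Batch NLL averaged; the same global KL is scaled by 1 / n_train."""
    batch_size = len(nll_per_example)
    return sum(nll_per_example) / batch_size + kl_global / n_train

def vae_loss_mean_over_batch(nll_per_example, kl_per_example, kl_weight=1.0):
    """Per-example latent KL (as in a VAE): both terms are averaged the same way."""
    batch_size = len(nll_per_example)
    return (sum(nll_per_example) + kl_weight * sum(kl_per_example)) / batch_size

nll = [1.0, 2.0, 3.0, 4.0]   # hypothetical per-example negative log-likelihoods
summed = loss_sum_over_batch(nll, kl_global=10.0, n_train=100)
averaged = loss_mean_over_batch(nll, kl_global=10.0, n_train=100)
```

The first two losses differ only by a factor of batch_size, so batch_size / n_train and 1 / n_train describe the same objective under different reductions. Under this reading, a per-example latent KL with weight 1.0, as in the VAE blog post, needs no dataset-size factor at all.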

KL annealing

Sometimes units can become inactive and stay inactive if they become trapped in a local minimum where the KL divergence is near 0. Therefore, it is often beneficial to drastically deprioritize the prior at training outset, then gradually increase the KL weight as training progresses. This is well described in section 2.3 of the Ladder VAE paper.

We can also make the beta term cyclical, which provides other benefits as described here.

For this to work with the Keras training syntax, we need the weight to be a non-trainable variable that changes during a callback. You can find an example implementation here.
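
As a rough illustration of the two schedules, here is a plain-Python sketch with made-up function and parameter names. It is not the linked implementation. In a Keras setup, the returned value would be assigned to the non-trainable weight variable from a callback at the start of each epoch or step.

```python
def linear_kl_weight(step, warmup_steps, max_weight=1.0):
    """Ramp linearly from 0 to max_weight over warmup_steps, then hold."""
    if warmup_steps <= 0:
        return max_weight
    return max_weight * min(1.0, step / warmup_steps)

def cyclical_kl_weight(step, cycle_steps, ramp_fraction=0.5, max_weight=1.0):
    """Restart the ramp every cycle_steps; ramp during the first
    ramp_fraction of each cycle, then hold at max_weight until the restart."""
    position = (step % cycle_steps) / cycle_steps
    return max_weight * min(1.0, position / ramp_fraction)

# Hypothetical schedule: 100-step warm-up versus 100-step cycles.
linear = [linear_kl_weight(s, 100) for s in (0, 50, 100, 200)]
cyclical = [cyclical_kl_weight(s, 100) for s in (0, 25, 75, 100)]
```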

Decoder

Which distribution?

The output distribution should, as much as possible, make sense for the type of data being generated. Some examples:

  • binarized pixels -- independent Bernoulli
  • spike counts -- independent Poisson
  • Any aggregate of many small-scale processes -- Normal (Gaussian)
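
Whichever distribution is chosen, the reconstruction loss is its negative log-likelihood evaluated at the observed data. The plain-Python sketch below (hypothetical helper names, one scalar observation at a time) shows what that loss looks like for the three cases in the list. With independent outputs, the per-element values are summed.

```python
import math

def bernoulli_nll(x, p):
    """x in {0, 1}; p is the predicted probability that x == 1."""
    return -(x * math.log(p) + (1 - x) * math.log(1.0 - p))

def poisson_nll(k, rate):
    """k is an observed non-negative integer count; rate > 0."""
    return rate - k * math.log(rate) + math.lgamma(k + 1)

def normal_nll(x, loc, scale):
    """x is a real value; scale is the predicted standard deviation."""
    return 0.5 * math.log(2.0 * math.pi * scale ** 2) + 0.5 * ((x - loc) / scale) ** 2

# Hypothetical observations and predictions.
pixel_loss = bernoulli_nll(1, 0.9)
spike_loss = poisson_nll(3, 2.5)
signal_loss = normal_nll(0.2, 0.0, 1.0)
```

For a Normal with fixed scale, the loss reduces to a scaled squared error plus a constant, which is one reason 'mse' is a common stand-in.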

Biological signals sometimes follow Gaussian distributions. When they don't, it's usually a good idea to transform them so that they do, because many data science tools work best with Gaussian data.

For similar reasons, it's quite common to scale the data so that they have a standard deviation of 1.0.

What about covariances? This should be considered separately for every dataset. For the present data, the signals were created by mixing latents, so it is expected that signals with contributions from the same latents will covary, and therefore we should generate outputs with covariances. But in practice it doesn't matter here and it slows down training so we'll go with diagonal-only.

Model

To the model's outputs= kwarg we pass a tuple of both the latent distribution and the output distribution. We do this for one reason only: to monitor the KL divergence loss during training. If you recall, the dataset Y was mapped to a (zeros, x) tuple to give us 2 outputs. The compiled model has 2 loss functions: the first calculates the KL divergence of the latent from the prior, and the second calculates the negative log-likelihood of the data under the output distribution. The first loss has a weight of 0.0 and is thus not used in updating the model variables. It is not needed because the KL divergence for updating the model variables is calculated in the activity_regularizer attached to the latent distribution in the encoder model. Yet, Keras' model.fit will still print out the KL loss (here called "q_z_loss").

If we didn't want to monitor the KL-loss, we could simplify things a little by removing the first output from the model, changing the dataset from outputting (zeros, x) to only output x, removing the first loss-fn from the loss kwarg in model.compile, and getting rid of the loss_weights kwarg.