Skip to content


LearnableMultivariateNormalCell (Model)

Multivariate normal distribution RNN cell.

The model is an RNN-based recurrent function that computes the parameters for a multivariate normal distribution at each timestep t. Based on:

Source code in indl/dists/
class LearnableMultivariateNormalCell(tf.keras.Model):
    """Multivariate normal distribution RNN cell.

    The model is an RNN-based recurrent function that computes the
    parameters for a multivariate normal distribution at each timestep `t`.
    """

    def __init__(self, units: int, out_dim: int,
                 shift_std: float = 0.1, cell_type: str = 'lstm', offdiag: bool = False):
        """Constructs a learnable multivariate normal cell.

        Args:
          units: Dimensionality of the RNN function parameters.
          out_dim: The dimensionality of the distribution.
          shift_std: Shift applied to MVN std before building the dist. Providing a shift
            toward the expected std allows the input values to be closer to 0.
          cell_type: an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.
          offdiag: set True to allow non-zero covariance (within-timestep) in the returned distribution.

        Raises:
          ValueError: If `cell_type` is not recognized.
        """
        super(LearnableMultivariateNormalCell, self).__init__()
        self.offdiag = offdiag
        self.output_dimensions = out_dim
        self.units = units
        cell_key = cell_type.upper()
        if cell_key.endswith('LSTM'):
            # NOTE: implementation=1 was found to be required in some (Jupyter) environments.
            self.rnn_cell = tfkl.LSTMCell(self.units, implementation=1, name="mvncell")
        elif cell_key.endswith('GRU'):
            self.rnn_cell = tfkl.GRUCell(self.units, name="mvncell")
        elif cell_key.endswith('RNN'):
            self.rnn_cell = tfkl.SimpleRNNCell(self.units, name="mvncell")
        elif cell_key.endswith('GRUCLIP'):
            from indl.rnn.gru_clip import GRUClipCell
            self.rnn_cell = GRUClipCell(self.units, name="mvncell")
        else:
            raise ValueError("cell_type %s not recognized" % cell_type)

        self.loc_layer = tfkl.Dense(self.output_dimensions, name="mvncell_loc")
        # Number of scale parameters: what remains of the distribution's parameter
        # count after the `out_dim` loc entries (lower-triangular entries if offdiag,
        # one entry per dimension otherwise).
        if offdiag:
            n_scale_dim = tfpl.MultivariateNormalTriL.params_size(out_dim) - out_dim
        else:
            n_scale_dim = tfpl.IndependentNormal.params_size(out_dim) - out_dim
        self.scale_untransformed_layer = tfkl.Dense(n_scale_dim, name="mvndiagcell_scale")
        # Inverse softplus of shift_std: a zero pre-activation in `call` yields a
        # scale of ~shift_std (plus 1e-5).
        self._scale_shift = np.log(np.exp(shift_std) - 1).astype(np.float32)

    def zero_state(self, sample_batch_shape=()):
        """Returns an initial state for the RNN cell.

        Args:
          sample_batch_shape: A 0D or 1D tensor of the combined sample and
            batch shape. An empty shape (the default) means an unbatched call.

        Returns:
          A tuple of the initial previous output at timestep 0 of shape
          [sample_batch_shape, dimensions], and the cell state.
        """
        sample_batch_shape = tf.reshape(
            tf.convert_to_tensor(value=sample_batch_shape, dtype=tf.int32), [-1])
        # The RNN batch size is the last entry of sample_batch_shape. For an empty
        # shape the product over the empty slice is 1, matching `call`, which
        # promotes unbatched inputs to a batch of one.
        batch_size = tf.reduce_prod(sample_batch_shape[-1:])
        zero_state = self.rnn_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
        out_shape = tf.concat((sample_batch_shape, [self.output_dimensions]), axis=-1)
        previous_output = tf.zeros(out_shape)
        return previous_output, zero_state

    def call(self, inputs, state):
        """Runs the model to generate a distribution for a single timestep.

        This generates a batched MultivariateNormalDiag distribution (or a
        MultivariateNormalTriL distribution when `offdiag` is set) using the
        output of the recurrent model at the current timestep to
        parameterize the distribution.

        Args:
          inputs: The sampled value of `z` at the previous timestep, i.e.,
            `z_{t-1}`, of shape [..., dimensions].
            `z_0` should be set to the empty matrix.
          state: The recurrent state of `self.rnn_cell` (e.g. the
            (hidden, cell) state of an LSTM cell).

        Returns:
          A tuple of the distribution and the state of the recurrent function
          at the end of the current timestep. The distribution will have event
          shape [dimensions], batch shape [...], and sample shape
          [sample_shape, ..., dimensions].
        """
        # Capture the caller's batch shape dynamically (before any reshape) so it
        # also works when static dimensions are unknown.
        batch_shape = tf.shape(inputs)[:-1]
        # In order to allow the user to pass in a single example without a batch
        # dimension, we always expand the input to at least two dimensions, then
        # fix the output shape to remove the batch dimension if necessary.
        if len(inputs.shape) < 2:
            inputs = tf.reshape(inputs, [1, -1])
        out, state = self.rnn_cell(inputs, state)

        vec_shape = tf.concat((batch_shape, [self.output_dimensions]), 0)
        loc = tf.reshape(self.loc_layer(out), vec_shape)
        scale = self.scale_untransformed_layer(out)
        scale = tf.nn.softplus(scale + self._scale_shift) + 1e-5
        if self.offdiag:
            # Pack the d*(d+1)/2 entries into a lower-triangular matrix, as in
            # make_mvn_dist_fn. NOTE: softplus also keeps off-diagonal entries positive.
            scale = tfb.FillTriangular()(scale)
            mat_shape = tf.concat(
                (batch_shape, [self.output_dimensions, self.output_dimensions]), 0)
            scale = tf.reshape(scale, mat_shape)
            return tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale), state
        scale = tf.reshape(scale, vec_shape)
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale), state

__init__(self, units, out_dim, shift_std=0.1, cell_type='lstm', offdiag=False) special

Constructs a learnable multivariate normal cell.


Name Type Description Default
units int

Dimensionality of the RNN function parameters.

out_dim int

The dimensionality of the distribution.

shift_std float

Shift applied to MVN std before building the dist. Providing a shift toward the expected std allows the input values to be closer to 0.

cell_type str

an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.

offdiag bool

set True to allow non-zero covariance (within-timestep) in the returned distribution.

Source code in indl/dists/
def __init__(self, units: int, out_dim: int,
             shift_std: float = 0.1, cell_type: str = 'lstm', offdiag: bool = False):
    """Constructs a learnable multivariate normal cell.

    Args:
      units: Dimensionality of the RNN function parameters.
      out_dim: The dimensionality of the distribution.
      shift_std: Shift applied to MVN std before building the dist. Providing a shift
        toward the expected std allows the input values to be closer to 0.
      cell_type: an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.
      offdiag: set True to allow non-zero covariance (within-timestep) in the returned distribution.

    Raises:
      ValueError: If `cell_type` is not recognized.
    """
    super(LearnableMultivariateNormalCell, self).__init__()
    self.offdiag = offdiag
    self.output_dimensions = out_dim
    self.units = units
    cell_key = cell_type.upper()
    if cell_key.endswith('LSTM'):
        # NOTE: implementation=1 was found to be required in some (Jupyter) environments.
        self.rnn_cell = tfkl.LSTMCell(self.units, implementation=1, name="mvncell")
    elif cell_key.endswith('GRU'):
        self.rnn_cell = tfkl.GRUCell(self.units, name="mvncell")
    elif cell_key.endswith('RNN'):
        self.rnn_cell = tfkl.SimpleRNNCell(self.units, name="mvncell")
    elif cell_key.endswith('GRUCLIP'):
        from indl.rnn.gru_clip import GRUClipCell
        self.rnn_cell = GRUClipCell(self.units, name="mvncell")
    else:
        raise ValueError("cell_type %s not recognized" % cell_type)

    self.loc_layer = tfkl.Dense(self.output_dimensions, name="mvncell_loc")
    # Number of scale parameters: what remains of the distribution's parameter
    # count after the `out_dim` loc entries (lower-triangular entries if offdiag,
    # one entry per dimension otherwise).
    if offdiag:
        n_scale_dim = tfpl.MultivariateNormalTriL.params_size(out_dim) - out_dim
    else:
        n_scale_dim = tfpl.IndependentNormal.params_size(out_dim) - out_dim
    self.scale_untransformed_layer = tfkl.Dense(n_scale_dim, name="mvndiagcell_scale")
    # Inverse softplus of shift_std: a zero pre-activation in `call` yields a
    # scale of ~shift_std (plus 1e-5).
    self._scale_shift = np.log(np.exp(shift_std) - 1).astype(np.float32)

call(self, inputs, state)

Runs the model to generate a distribution for a single timestep.

This generates a batched MultivariateNormalDiag distribution (MultivariateNormalTriL when offdiag is set) using the output of the recurrent model at the current timestep to parameterize the distribution.


Name Type Description Default
inputs

The sampled value of z at the previous timestep, i.e., z_{t-1}, of shape [..., dimensions]. z_0 should be set to the empty matrix.

state

A tuple containing the (hidden, cell) state.



Type Description

A tuple of a MultivariateNormalDiag distribution, and the state of the recurrent function at the end of the current timestep. The distribution will have event shape [dimensions], batch shape [...], and sample shape [sample_shape, ..., dimensions].

Source code in indl/dists/
def call(self, inputs, state):
    """Runs the model to generate a distribution for a single timestep.

    This generates a batched MultivariateNormalDiag distribution (or a
    MultivariateNormalTriL distribution when `offdiag` is set) using the
    output of the recurrent model at the current timestep to
    parameterize the distribution.

    Args:
      inputs: The sampled value of `z` at the previous timestep, i.e.,
        `z_{t-1}`, of shape [..., dimensions].
        `z_0` should be set to the empty matrix.
      state: The recurrent state of `self.rnn_cell` (e.g. the
        (hidden, cell) state of an LSTM cell).

    Returns:
      A tuple of the distribution and the state of the recurrent function
      at the end of the current timestep. The distribution will have event
      shape [dimensions], batch shape [...], and sample shape
      [sample_shape, ..., dimensions].
    """
    # Capture the caller's batch shape dynamically (before any reshape) so it
    # also works when static dimensions are unknown.
    batch_shape = tf.shape(inputs)[:-1]
    # In order to allow the user to pass in a single example without a batch
    # dimension, we always expand the input to at least two dimensions, then
    # fix the output shape to remove the batch dimension if necessary.
    if len(inputs.shape) < 2:
        inputs = tf.reshape(inputs, [1, -1])
    out, state = self.rnn_cell(inputs, state)

    vec_shape = tf.concat((batch_shape, [self.output_dimensions]), 0)
    loc = tf.reshape(self.loc_layer(out), vec_shape)
    scale = self.scale_untransformed_layer(out)
    scale = tf.nn.softplus(scale + self._scale_shift) + 1e-5
    if self.offdiag:
        # Pack the d*(d+1)/2 entries into a lower-triangular matrix, as in
        # make_mvn_dist_fn. NOTE: softplus also keeps off-diagonal entries positive.
        scale = tfb.FillTriangular()(scale)
        mat_shape = tf.concat(
            (batch_shape, [self.output_dimensions, self.output_dimensions]), 0)
        scale = tf.reshape(scale, mat_shape)
        return tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale), state
    scale = tf.reshape(scale, vec_shape)
    return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale), state

zero_state(self, sample_batch_shape=())

Returns an initial state for the RNN cell.


Name Type Description Default
sample_batch_shape

A 0D or 1D tensor of the combined sample and batch shape.



Type Description

A tuple of the initial previous output at timestep 0 of shape [sample_batch_shape, dimensions], and the cell state.

Source code in indl/dists/
def zero_state(self, sample_batch_shape=()):
    """Returns an initial state for the RNN cell.

    Args:
      sample_batch_shape: A 0D or 1D tensor of the combined sample and
        batch shape. An empty shape (the default) means an unbatched call.

    Returns:
      A tuple of the initial previous output at timestep 0 of shape
      [sample_batch_shape, dimensions], and the cell state.
    """
    sample_batch_shape = tf.reshape(
        tf.convert_to_tensor(value=sample_batch_shape, dtype=tf.int32), [-1])
    # The RNN batch size is the last entry of sample_batch_shape. For an empty
    # shape the product over the empty slice is 1, matching `call`, which
    # promotes unbatched inputs to a batch of one.
    batch_size = tf.reduce_prod(sample_batch_shape[-1:])
    zero_state = self.rnn_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    out_shape = tf.concat((sample_batch_shape, [self.output_dimensions]), axis=-1)
    previous_output = tf.zeros(out_shape)
    return previous_output, zero_state

LearnableMultivariateNormalDiag (Model)

Learnable multivariate diagonal normal distribution.

The model is a multivariate normal distribution with learnable mean and stddev parameters.

See make_mvn_prior for a description.

Source code in indl/dists/
class LearnableMultivariateNormalDiag(tf.keras.Model):
    """Learnable multivariate diagonal normal distribution.

    The model is a multivariate normal distribution with learnable
    `mean` and `stddev` parameters.

    See make_mvn_prior for a description.
    """

    def __init__(self, dimensions, init_std=1.0, trainable_mean=True, trainable_var=True):
        """Constructs a learnable multivariate diagonal normal model.

        Args:
          dimensions: An integer corresponding to the dimensionality of the
            distribution.
          init_std: Initial standard deviation. If `trainable_var`, the initial
            values are drawn from a normal with mean `init_std` and stddev
            `init_std / 10`.
          trainable_mean: If True, the mean is a `tf.Variable` initialized from a
            normal with stddev 0.1; otherwise it is fixed at zeros.
          trainable_var: If True, the scale is a `tfp.util.TransformedVariable`;
            otherwise it is fixed at `init_std`.
        """
        super(LearnableMultivariateNormalDiag, self).__init__()

        with tf.name_scope(self._name):
            self.dimensions = dimensions
            if trainable_mean:
                self._mean = tf.Variable(tf.random.normal([dimensions], stddev=0.1), name="mean")
            else:
                self._mean = tf.zeros(dimensions)
            if trainable_var:
                # Inverse softplus of init_std, so the underlying unconstrained
                # variable starts near 0.
                _scale_shift = np.log(np.exp(init_std) - 1).astype(np.float32)
                self._scale = tfp.util.TransformedVariable(
                    tf.random.normal([dimensions], mean=init_std, stddev=init_std/10, dtype=tf.float32),
                    bijector=tfb.Chain([tfb.Shift(1e-5), tfb.Softplus(), tfb.Shift(_scale_shift)]))
            else:
                self._scale = init_std * tf.ones(dimensions)

    def __call__(self, *args, **kwargs):
        # Allow this Model to be called without inputs.
        dummy = tf.zeros(self.dimensions)
        return super(LearnableMultivariateNormalDiag, self).__call__(
            dummy, *args, **kwargs)

    def call(self, inputs):
        """Runs the model to generate multivariate normal distribution.

        Args:
          inputs: Unused.

        Returns:
          A MultivariateNormalDiag distribution with event shape
          [dimensions], batch shape [], and sample shape [sample_shape,
          dimensions].
        """
        del inputs  # unused
        with tf.name_scope(self._name):
            return tfd.MultivariateNormalDiag(loc=self.loc, scale_diag=self.scale_diag)

    @property
    def loc(self):
        """The mean of the normal distribution."""
        return self._mean

    @property
    def scale_diag(self):
        """The diagonal standard deviation of the normal distribution."""
        return self._scale

loc property readonly

The mean of the normal distribution.

scale_diag property readonly

The diagonal standard deviation of the normal distribution.

__init__(self, dimensions, init_std=1.0, trainable_mean=True, trainable_var=True) special

Constructs a learnable multivariate diagonal normal model.


Name Type Description Default
dimensions

An integer corresponding to the dimensionality of the distribution.

Source code in indl/dists/
def __init__(self, dimensions, init_std=1.0, trainable_mean=True, trainable_var=True):
    """Constructs a learnable multivariate diagonal normal model.

    Args:
      dimensions: An integer corresponding to the dimensionality of the
        distribution.
      init_std: Initial standard deviation. If `trainable_var`, the initial
        values are drawn from a normal with mean `init_std` and stddev
        `init_std / 10`.
      trainable_mean: If True, the mean is a `tf.Variable` initialized from a
        normal with stddev 0.1; otherwise it is fixed at zeros.
      trainable_var: If True, the scale is a `tfp.util.TransformedVariable`;
        otherwise it is fixed at `init_std`.
    """
    super(LearnableMultivariateNormalDiag, self).__init__()

    with tf.name_scope(self._name):
        self.dimensions = dimensions
        if trainable_mean:
            self._mean = tf.Variable(tf.random.normal([dimensions], stddev=0.1), name="mean")
        else:
            self._mean = tf.zeros(dimensions)
        if trainable_var:
            # Inverse softplus of init_std, so the underlying unconstrained
            # variable starts near 0.
            _scale_shift = np.log(np.exp(init_std) - 1).astype(np.float32)
            self._scale = tfp.util.TransformedVariable(
                tf.random.normal([dimensions], mean=init_std, stddev=init_std/10, dtype=tf.float32),
                bijector=tfb.Chain([tfb.Shift(1e-5), tfb.Softplus(), tfb.Shift(_scale_shift)]))
        else:
            self._scale = init_std * tf.ones(dimensions)

call(self, inputs)

Runs the model to generate multivariate normal distribution.


Name Type Description Default
inputs

Unused.




Type Description

A MultivariateNormalDiag distribution with event shape [dimensions], batch shape [], and sample shape [sample_shape, dimensions].

Source code in indl/dists/
def call(self, inputs):
    """Runs the model to generate multivariate normal distribution.

    Args:
      inputs: Unused.

    Returns:
      A MultivariateNormalDiag distribution with event shape
      [dimensions], batch shape [], and sample shape [sample_shape,
      dimensions].
    """
    del inputs  # unused
    with tf.name_scope(self._name):
        return tfd.MultivariateNormalDiag(loc=self.loc, scale_diag=self.scale_diag)

make_learnable_mvn_params(ndim, init_std=1.0, trainable_mean=True, trainable_var=True, offdiag=False)

Return mean (loc) and stddev (scale) parameters for initializing multivariate normal distributions. If trainable_mean then it will be initialized with random normal (stddev=0.1), otherwise zeros. If trainable_var then it will be initialized with random normal centered at a value such that the bijector transformation yields the value in init_std. When init_std is 1.0 (default) then the inverse-bijected value is approximately 0.0. If not trainable_var then scale is a vector or matrix of init_std of appropriate shape for the dist.


Name Type Description Default
ndim int

Number of dimensions.

init_std float

Initial value for the standard deviation.

trainable_mean bool

Whether or not the mean (loc) is a trainable tf.Variable.

trainable_var bool

Whether or not the variance (scale) is a trainable tf.Variable.

offdiag bool

Whether or not off-diagonal elements are allowed.


Returns: loc, scale

Source code in indl/dists/
def make_learnable_mvn_params(ndim: int, init_std: float = 1.0, trainable_mean: bool = True, trainable_var: bool = True,
                              offdiag: bool = False):
    """
    Return mean (loc) and stddev (scale) parameters for initializing multivariate normal distributions.
    If trainable_mean then it will be initialized with random normal (stddev=0.1), otherwise zeros.
    If trainable_var then it will be initialized with random normal centered at a value such that the
        bijector transformation yields the value in init_std. When init_std is 1.0 (default) then the
        inverse-bijected value is approximately 0.0. With offdiag, the initial scale is lower-triangular
        with diagonal ~init_std and off-diagonal entries ~0, mirroring the non-trainable init_std * eye.
        If not trainable_var then scale is a vector or matrix of init_std of appropriate shape for the dist.

    Args:
        ndim: Number of dimensions.
        init_std: Initial value for the standard deviation.
        trainable_mean: Whether or not the mean (loc) is a trainable tf.Variable.
        trainable_var: Whether or not the variance (scale) is a trainable tf.Variable.
        offdiag: Whether or not off-diagonal elements are allowed.

    Returns: loc, scale
    """
    if trainable_mean:
        loc = tf.Variable(tf.random.normal([ndim], stddev=0.1, dtype=tf.float32))
    else:
        loc = tf.zeros(ndim)

    # Initialize the variance (scale), trainable or not, offdiag or not.
    if not trainable_var:
        scale = init_std * tf.eye(ndim) if offdiag else init_std * tf.ones(ndim)
        return loc, scale

    _scale_shift = np.log(np.exp(init_std) - 1).astype(np.float32)  # tfp.math.softplus_inverse(init_std)
    # Positive (diagonal) scale: x -> softplus(x + _scale_shift) + 1e-5, so x ~ 0 maps to ~init_std.
    diag_bijector = tfb.Chain([tfb.Shift(1e-5), tfb.Softplus(), tfb.Shift(_scale_shift)])
    if offdiag:
        # Lower-triangular initial value: diagonal ~ N(init_std, init_std/10) like the
        # diagonal case, strictly-lower entries ~ N(0, init_std/10). A full matrix of
        # ~init_std would start the prior with strong spurious correlations.
        init_tril = init_std * tf.eye(ndim, dtype=tf.float32) + tf.linalg.band_part(
            tf.random.normal([ndim, ndim], stddev=init_std/10, dtype=tf.float32), -1, 0)
        scale = tfp.util.TransformedVariable(
            init_tril,
            bijector=tfb.FillScaleTriL(diag_bijector=diag_bijector, diag_shift=None))
    else:
        scale = tfp.util.TransformedVariable(
            tf.random.normal([ndim], mean=init_std, stddev=init_std/10, dtype=tf.float32),
            bijector=diag_bijector)

    return loc, scale

make_mvn_dist_fn(_x_, ndim, shift_std=1.0, offdiag=False, loc_name=None, scale_name=None, use_mvn_diag=True)

Take a 1-D tensor and use it to parameterize an MVN distribution. This doesn't return the distribution itself, but the function that makes the distribution and its arguments, as (make_dist_fn, [loc, scale]). You can supply these to tfpl.DistributionLambda.


Name Type Description Default
_x_ Tensor required
ndim int required
shift_std float 1.0
offdiag bool False
loc_name Optional[str] None
scale_name Optional[str] None
use_mvn_diag bool True


Type Description
Tuple[Callable[[tensorflow.python.framework.ops.Tensor, tensorflow.python.framework.ops.Tensor], tensorflow_probability.python.distributions.distribution.Distribution], List[tensorflow.python.framework.ops.Tensor]]

make_dist_fn, [loc, scale]

Source code in indl/dists/
def make_mvn_dist_fn(_x_: tf.Tensor, ndim: int, shift_std: float = 1.0, offdiag: bool = False,
                     loc_name: Optional[str] = None, scale_name: Optional[str] = None, use_mvn_diag: bool = True
                     ) -> Tuple[Callable[[tf.Tensor, tf.Tensor], tfd.Distribution], List[tf.Tensor]]:
    """
    Take a 1-D tensor and use it to parameterize a MVN dist.
    This doesn't return the distribution, but the function to make the distribution and its arguments,
    as (make_dist_fn, [loc, scale]). You can supply them to tfpl.DistributionLambda.

    Args:
        _x_: Input tensor; loc and scale are Dense projections of it.
        ndim: Dimensionality of the distribution's event.
        shift_std: Scale produced (plus 1e-5) when the scale pre-activation is 0.
        offdiag: If True, build a MultivariateNormalTriL (lower-triangular scale).
        loc_name: Name of the Dense layer producing loc.
        scale_name: Name of the Dense layer producing the scale.
        use_mvn_diag: If not offdiag, use MultivariateNormalDiag (True) or
            Independent(Normal) with only the last axis as the event (False).

    Returns:
        make_dist_fn, [loc, scale]
    """
    # Inverse softplus of shift_std.
    _scale_shift = np.log(np.exp(shift_std) - 1).astype(np.float32)
    _loc = tfkl.Dense(ndim, name=loc_name)(_x_)

    n_scale_dim = (tfpl.MultivariateNormalTriL.params_size(ndim) - ndim) if offdiag\
        else (tfpl.IndependentNormal.params_size(ndim) - ndim)
    _scale = tfkl.Dense(n_scale_dim, name=scale_name)(_x_)
    _scale = tf.math.softplus(_scale + _scale_shift) + 1e-5
    if offdiag:
        # NOTE: softplus above also constrains the off-diagonal entries to be positive.
        _scale = tfb.FillTriangular()(_scale)

        def make_dist_fn(t):
            return tfd.MultivariateNormalTriL(loc=t[0], scale_tril=t[1])
    elif use_mvn_diag:  # Match type with prior
        def make_dist_fn(t):
            return tfd.MultivariateNormalDiag(loc=t[0], scale_diag=t[1])
    else:
        def make_dist_fn(t):
            # Only the last axis is the event, consistent with the MVN branches.
            return tfd.Independent(tfd.Normal(loc=t[0], scale=t[1]), reinterpreted_batch_ndims=1)
    return make_dist_fn, [_loc, _scale]

make_mvn_prior(ndim, init_std=1.0, trainable_mean=True, trainable_var=True, offdiag=False)

Creates a tensorflow-probability distribution: MultivariateNormalTriL if offdiag, else MultivariateNormalDiag. Mean (loc) and sigma (scale) can be trainable or not. Mean initializes to random.normal around 0 (stddev=0.1) if trainable, else zeros. Scale initializes to init_std if not trainable. If it is trainable, it initializes to a tfp TransformedVariable that will be centered at 0 for easy training under the hood, but will be transformed via softplus to give something initially close to init_std.

loc and scale are tracked by the MVNDiag class.

For the LFADS ics prior, trainable_mean=True, trainable_var=False. For the LFADS cos prior (if not using AR1), trainable_mean=False, trainable_var=True. In either case, var was initialized with 0.1 (==> logvar with log(0.1)). Unlike LFADS' LearnableDiagonalGaussian, here we don't support multi-dimensional distributions, just a vector.

See also LearnableMultivariateNormalDiag for a tf.keras.Model version of this.


Name Type Description Default
ndim int

Latent dimension of the distribution. Only a vector (1-D) event is currently supported.

init_std float

initial standard deviation of the gaussian. If trainable_var then the initial standard deviation will be drawn from a random.normal distribution with mean init_std and stddev 1/10th of that.

trainable_mean bool

If the mean should be a tf.Variable

trainable_var bool

If the variance/stddev/scale (whatever you call it) should be a tf.Variable

offdiag bool

If the variance-covariance matrix is allowed non-zero off-diagonal elements.



Type Description
Union[tensorflow_probability.python.distributions.mvn_diag.MultivariateNormalDiag, tensorflow_probability.python.distributions.mvn_tril.MultivariateNormalTriL]

A tensorflow-probability distribution (either MultivariateNormalTriL or MultivariateNormalDiag).

Source code in indl/dists/
def make_mvn_prior(ndim: int, init_std: float = 1.0, trainable_mean: bool = True, trainable_var: bool = True,
                   offdiag: bool = False) -> Union[tfd.MultivariateNormalDiag, tfd.MultivariateNormalTriL]:
    """
    Creates a tensorflow-probability distribution:
        MultivariateNormalTriL if offdiag else MultivariateNormalDiag.
    Mean (loc) and sigma (scale) can be trainable or not.
    Mean initializes to random.normal around 0 (stddev=0.1) if trainable, else zeros.
    Scale initializes to init_std if not trainable. If it is trainable, it initializes
     to a tfp TransformedVariable that will be centered at 0 for easy training under the hood,
     but will be transformed via softplus to give something initially close to init_std.

    loc and scale are tracked by the MVNDiag class.

    For LFADS ics prior, trainable_mean=True, trainable_var=False.
    For LFADS cos prior (if not using AR1), trainable_mean=False, trainable_var=True.
    In either case, var was initialized with 0.1 (==> logvar with log(0.1)).
    Unlike the LFADS' LearnableDiagonalGaussian, here we don't support multi-dimensional, just a vector.

    See also LearnableMultivariateNormalDiag for a tf.keras.Model version of this.

    Args:
        ndim: latent dimension of the distribution. Only a vector (1-D) event is supported.
        init_std: initial standard deviation of the gaussian. If trainable_var then the initial standard deviation
            will be drawn from a random.normal distribution with mean init_std and stddev 1/10th of that.
        trainable_mean: If the mean should be a tf.Variable
        trainable_var: If the variance/stddev/scale (whatever you call it) should be a tf.Variable
        offdiag: If the variance-covariance matrix is allowed non-zero off-diagonal elements.

    Returns:
        A tensorflow-probability distribution (either MultivariateNormalTriL or MultivariateNormalDiag).
    """
    loc, scale = make_learnable_mvn_params(ndim, init_std=init_std,
                                           trainable_mean=trainable_mean, trainable_var=trainable_var,
                                           offdiag=offdiag)

    # Initialize the prior.
    if offdiag:
        # Note: Diag must be > 0, upper triangular must be 0, and lower triangular may be != 0.
        prior = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
    else:
        prior = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        # kl_exact needs same dist types for prior and latent.
        # We would switch to the next line if we switched our latent to using tf.Independent(tfd.Normal)
        # prior = tfd.Independent(tfd.Normal(loc=tf.zeros(ndim), scale=1), reinterpreted_batch_ndims=1)

    return prior

make_variational(x, dist_dim, init_std=1.0, offdiag=False, samps=1, loc_name='loc', scale_name='scale', dist_name='q', use_mvn_diag=True)

Take an input tensor and return a multivariate normal distribution parameterized by that input tensor.


Name Type Description Default
x Tensor

input tensor

dist_dim int

the dimensionality of the distribution

init_std float

initial stddev SHIFT of the distribution when input is 0.

offdiag bool

whether or not to include covariances

samps int

the number of samples to draw when using implied convert_to_tensor_fn


loc_name str

not used (I need to handle naming better)

scale_name str

not used

dist_name str

not used

use_mvn_diag bool

whether to use tfd.MultivariateNormal(Diag|TriL) (True) or tfd.Independent(tfd.Normal) (False) Latter is untested. Note that the mvn dists will put the timesteps dimension (if present in the input) into the "batch dimension" while the "event" dimension will be the last dimension only. You can use tfd.Independent(q_dist, reinterpreted_batch_ndims=1) to move the timestep dimension to the event dimension if necessary. (tfd.Independent doesn't play well with tf.keras.Model inputs/outputs).



Type Description
Union[tensorflow_probability.python.distributions.mvn_diag.MultivariateNormalDiag, tensorflow_probability.python.distributions.mvn_tril.MultivariateNormalTriL, tensorflow_probability.python.distributions.independent.Independent]

A tfd.Distribution. The distribution is of type MultivariateNormalDiag (or MultivariateNormalTriL if offdiag) if use_mvn_diag is set.

Source code in indl/dists/
def make_variational(x: tf.Tensor, dist_dim: int,
                     init_std: float = 1.0, offdiag: bool = False,
                     samps: int = 1,
                     loc_name="loc", scale_name="scale",
                     dist_name="q",
                     use_mvn_diag: bool = True
                     ) -> Union[tfd.MultivariateNormalDiag, tfd.MultivariateNormalTriL, tfd.Independent]:
    """
    Take an input tensor and return a multivariate normal distribution parameterized by that input tensor.

    Args:
        x: input tensor
        dist_dim: the dimensionality of the distribution
        init_std: initial stddev SHIFT of the distribution when input is 0.
        offdiag: whether or not to include covariances
        samps: the number of samples to draw when using implied convert_to_tensor_fn
        loc_name: not used (I need to handle naming better)
        scale_name: not used
        dist_name: not used
        use_mvn_diag: whether to use tfd.MultivariateNormal(Diag|TriL) (True) or tfd.Independent(tfd.Normal) (False)
            Latter is untested. Note that the mvn dists will put the timesteps dimension (if present in the input)
            into the "batch dimension" while the "event" dimension will be the last dimension only.
            You can use tfd.Independent(q_dist, reinterpreted_batch_ndims=1) to move the timestep dimension to the
            event dimension if necessary. (tfd.Independent doesn't play well with tf.keras.Model inputs/outputs).

    Returns:
        A tfd.Distribution. The distribution is of type MultivariateNormalDiag (or MultivariateNormalTriL if offdiag)
        if use_mvn_diag is set.
    """
    make_dist_fn, dist_params = make_mvn_dist_fn(x, dist_dim, shift_std=init_std,
                                                 # loc_name=loc_name, scale_name=scale_name,
                                                 offdiag=offdiag, use_mvn_diag=use_mvn_diag)

    # convert_to_tensor_fn: a Python callable that takes a tfd.Distribution and returns a
    # tf.Tensor-like object. A training-dependent converter (sample when training, mean
    # otherwise) could not be made to work; take q_f.value() / q_f.mean() explicitly in
    # train_step instead.
    if samps <= 1:
        convert_fn = tfd.Distribution.sample
    else:
        def convert_fn(d):
            return d.sample(samps)

    q_dist = tfpl.DistributionLambda(make_distribution_fn=make_dist_fn,
                                     convert_to_tensor_fn=convert_fn)(dist_params)
    # if tf.shape(x).shape[0] > 2:
    #     q_dist = tfd.Independent(q_dist, reinterpreted_batch_ndims=1)
    return q_dist


AR1ProcessMVNGenerator (IProcessMVNGenerator)

Similar to LFADS' LearnableAutoRegressive1Prior. Here we use the terminology from:

The autoregressive function takes the form: $E(X_t) = E(c) + \phi \, E(X_{t-1}) + e_t$. $E(c)$ is a constant. $\phi$ is a parameter, equivalent to $\exp(-1/\tau) = \exp(-\exp(-\log \tau))$, where $\tau$ is a time constant. $e_t$ is zero-mean white noise with variance $\sigma_e^2$.

When there's no previous sample, $E(X_t) = E(c) + e_t$, which is a draw from $\mathcal{N}(c, \sigma_e^2)$. When there is a previous sample, $E(X_t) = E(c) + \phi \, E(X_{t-1}) + e_t$, which means a draw from $\mathcal{N}(c + \phi X_{t-1}, \sigma_p^2)$, where $\sigma_p^2 = \phi^2 \operatorname{var}(X_{t-1}) + \sigma_e^2 = \sigma_e^2 / (1 - \phi^2)$, or equivalently $\log \sigma_p^2 = \log \sigma_e^2 - (\log(1 - \phi) + \log(1 + \phi))$.

Note that this could be roughly equivalent to tfd.Autoregressive if it was passed a distribution_fn with the same transition.

See issue:

Source code in indl/dists/
class AR1ProcessMVNGenerator(IProcessMVNGenerator):
    """
    Similar to LFADS' LearnableAutoRegressive1Prior.
    Here we use the terminology of the standard AR(1) process
    (the original reference link is missing from this copy).

    The autoregressive function takes the form:
        E(X_t) = E(c) + phi * E(X_{t-1}) + e_t
    E(c) is a constant.
    phi is a parameter, which is equivalent to exp(-1/tau) = exp(-exp(-logtau)),
        where tau is a time constant.
    e_t is white noise with zero-mean with evar = sigma_e**2

    When there's no previous sample, E(X_t) = E(c) + e_t,
        which is a draw from N(c, sigma_e**2)
    When there is a previous sample, E(X_t) = E(c) + phi * E(X_{t-1}) + e_t,
        which means a draw from N(c + phi * X_{t-1}, sigma_p**2)
        where sigma_p**2 = phi**2 * var(X_{t-1}) + sigma_e**2 = sigma_e**2 / (1 - phi**2)
        or logpvar = logevar - (log(1 - phi) + log(1 + phi))

    Note that this could be roughly equivalent to tfd.Autoregressive if it was passed
    a `distribution_fn` with the same transition.

    `_phi` and `_p_scale` are properties recomputed from the underlying variables on
    every access, so training log(tau) or sigma_e is reflected in `get_dist` (a value
    computed once in `__init__` would be a frozen snapshot in eager mode).

    See the related issue (link missing from this copy).
    """
    def __init__(self, init_taus: Union[float, List[float]],
                 init_std: Union[float, List[float]] = 0.1,
                 trainable_mean: bool = False,
                 trainable_tau: bool = True,
                 trainable_var: bool = True,
                 offdiag: bool = False):
        """
        Args:
            init_taus: Initial value(s) of the time constant tau, one per latent dimension.
                A scalar is treated as a single-dimension process.
            init_std: Initial value of sigma_e, the std of the innovation noise e_t.
            trainable_mean: set True if the mean (e_c) is trainable.
            trainable_tau: set True to make tau (stored as log(tau)) trainable.
            trainable_var: set True if the noise scale sigma_e is trainable.
            offdiag: set True to allow non-zero covariance (within-timestep) in the
                returned distribution (MultivariateNormalTriL instead of MultivariateNormalDiag).
        """
        self._offdiag = offdiag
        if isinstance(init_taus, (int, float)):
            init_taus = [init_taus]
        # tf.math.log is not defined for integer tensors, so make sure taus are floats.
        init_taus = [float(tau) for tau in init_taus]
        # TODO: Add time axis for easier broadcasting
        ndim = len(init_taus)
        # NOTE(review): the tail of this call was lost in extraction. trainable_mean,
        # trainable_var and offdiag are otherwise unused here, so they are forwarded;
        # confirm the keyword names against make_learnable_mvn_params.
        self._e_c, self._e_scale = make_learnable_mvn_params(ndim, init_std=init_std,
                                                             trainable_mean=trainable_mean,
                                                             trainable_var=trainable_var,
                                                             offdiag=offdiag)
        self._logtau = tf.Variable(tf.math.log(init_taus), dtype=tf.float32, trainable=trainable_tau)

    @property
    def _phi(self):
        # phi = exp(-1/tau) = exp(-exp(-log(tau))); recomputed so it tracks _logtau.
        return tf.exp(-tf.exp(-self._logtau))

    @property
    def _p_scale(self):
        # sigma_p**2 = sigma_e**2 / (1 - phi**2). _e_scale is a scale (std), not a variance
        # (it is passed as scale_diag / scale_tril), so the log-variance correction from the
        # class docstring is halved: sigma_p = sigma_e * exp(-0.5 * (log(1 - phi) + log(1 + phi))).
        # Multiplying (instead of exp(log(_e_scale) + ...)) avoids NaNs for non-positive
        # entries such as the off-diagonal of a scale_tril.
        # NOTE(review): with offdiag, phi broadcasts along the last (column) axis of the
        # [D, D] scale, as in the original formula; confirm that is the intended scaling.
        phi = self._phi
        log_inflation = -0.5 * (tf.math.log(1 - phi) + tf.math.log(1 + phi))
        return self._e_scale * tf.exp(log_inflation)

    def _make_mvn(self, loc, scale):
        """Builds a TriL (offdiag) or Diag multivariate normal from `loc` and `scale`."""
        if self._offdiag:
            return tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)

    def get_dist(self, timesteps, samples=1, batch_size=1, fixed=False):
        """Generates `timesteps` steps of the AR(1) process by ancestral sampling.

        Args:
            timesteps: Number of timesteps to generate.
            samples: Number of samples to draw.
            batch_size: Number of sequences to draw per sample.
            fixed: Currently ignored; kept for interface parity with the other generators.

        Returns:
            A tuple (sample, dist). `sample` has shape [samples, batch_size, timesteps, D],
            where D is the length of the learned mean. `dist` is built from the per-step
            locs and scales concatenated over the time axis, with that axis reinterpreted
            as part of the event.
        """
        locs = []
        scales = []
        sample_list = []

        # Add a leading time dimension to the per-dimension parameters.
        e_c = tf.expand_dims(self._e_c, 0)
        e_scale = tf.expand_dims(self._e_scale, 0)
        p_scale = tf.expand_dims(self._p_scale, 0)
        phi = self._phi

        # The "previous sample" before step 0 is zero: shape [samples, batch_size, 1, D].
        sample = tf.expand_dims(tf.expand_dims(tf.zeros_like(e_c), 0), 0)
        sample = tf.tile(sample, [samples, batch_size, 1, 1])
        for t in range(timesteps):
            loc = e_c + phi * sample
            # Step 0 has no previous sample and uses sigma_e; later steps use sigma_p
            # (as described in the class docstring).
            scale = p_scale if t > 0 else e_scale
            sample = self._make_mvn(loc, scale).sample()
            locs.append(loc)
            scales.append(scale)
            sample_list.append(sample)

        sample = tf.concat(sample_list, axis=2)
        loc = tf.concat(locs, axis=2)
        # Each per-step scale is [1, D] (diag) or [1, D, D] (tril); concatenating on the
        # leading axis gives [timesteps, D] or [timesteps, D, D]. (axis=-2 agreed with this
        # only for diag; for tril it produced [1, timesteps * D, D].)
        scale = tf.concat(scales, axis=0)

        dist = tfd.Independent(self._make_mvn(loc, scale), reinterpreted_batch_ndims=1)
        return sample, dist

__init__(self, init_taus, init_std=0.1, trainable_mean=False, trainable_tau=True, trainable_var=True, offdiag=False) special


Name Type Description Default
init_taus Union[float, List[float]]

Initial values of tau

init_std Union[float, List[float]]

Initial value of sigma_e

trainable_mean bool

set True if the mean (e_c) is trainable.

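trainable_tau bool

set True to make tau (stored as log(tau)) trainable.

trainable_var bool

set True if the noise variance (sigma_e) is trainable.

offdiag bool

set True to allow non-zero covariance (within-timestep) in the returned distribution.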
Source code in indl/dists/
def __init__(self, init_taus: Union[float, List[float]],
             init_std: Union[float, List[float]] = 0.1,
             trainable_mean: bool = False,
             trainable_tau: bool = True,
             trainable_var: bool = True,
             offdiag: bool = False):
    """
    Args:
        init_taus: Initial value(s) of the time constant tau, one per latent dimension.
            A scalar is treated as a single-dimension process.
        init_std: Initial value of sigma_e, the std of the innovation noise e_t.
        trainable_mean: set True if the mean (e_c) is trainable.
        trainable_tau: set True to make tau (stored as log(tau)) trainable.
        trainable_var: set True if the noise scale sigma_e is trainable.
        offdiag: set True to allow non-zero covariance (within-timestep) in the
            returned distribution.
    """
    self._offdiag = offdiag
    if isinstance(init_taus, (int, float)):
        init_taus = [init_taus]
    # tf.math.log is not defined for integer tensors, so make sure taus are floats.
    init_taus = [float(tau) for tau in init_taus]
    # TODO: Add time axis for easier broadcasting
    ndim = len(init_taus)
    # NOTE(review): the tail of this call was lost in extraction. trainable_mean,
    # trainable_var and offdiag are otherwise unused here, so they are forwarded;
    # confirm the keyword names against make_learnable_mvn_params.
    self._e_c, self._e_scale = make_learnable_mvn_params(ndim, init_std=init_std,
                                                         trainable_mean=trainable_mean,
                                                         trainable_var=trainable_var,
                                                         offdiag=offdiag)
    self._logtau = tf.Variable(tf.math.log(init_taus), dtype=tf.float32, trainable=trainable_tau)
    # NOTE(review): in eager mode _phi and _p_scale are computed once from the variables'
    # initial values, so later training of log(tau) / sigma_e does not update them.
    # Consider recomputing them on access (e.g. as properties).
    # phi = exp(-1/tau) = exp(-exp(-log(tau)))
    self._phi = tf.exp(-tf.exp(-self._logtau))
    # sigma_p = sigma_e / sqrt(1 - phi**2). _e_scale is a std, hence the 0.5 on the
    # log-variance correction; multiplying avoids log() of non-positive entries.
    self._p_scale = self._e_scale * tf.exp(
        -0.5 * (tf.math.log(1 - self._phi) + tf.math.log(1 + self._phi)))

RNNMVNGenerator (IProcessMVNGenerator)

Similar to DSAE's LearnableMultivariateNormalDiagCell

Source code in indl/dists/
class RNNMVNGenerator(IProcessMVNGenerator):
    """
    Similar to DSAE's LearnableMultivariateNormalDiagCell
    """
    def __init__(self, units: int, out_dim: int, cell_type: str, shift_std: float = 0.1, offdiag: bool = False):
        """
        Args:
            units: Dimensionality of the RNN function parameters.
            out_dim: The dimensionality of the distribution.
            cell_type: an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.
            shift_std: Shift applied to MVN std before building the dist. Providing a shift
                toward the expected std allows the input values to be closer to 0.
            offdiag: set True to allow non-zero covariance (within-timestep) in the returned distribution.
        """
        self.cell = LearnableMultivariateNormalCell(units, out_dim, cell_type=cell_type,
                                                    shift_std=shift_std, offdiag=offdiag)

    def get_dist(self, timesteps, samples=1, batch_size=1, fixed=True):
        """
        Samples from self.cell `timesteps` times.
        On each step, the previous (sample, state) is fed back into the cell
        (zero_state used for 0th step).

        The cell returns a multivariate normal diagonal distribution for each timestep.
        We collect each timestep-dist's params (loc and scale), then use them to create
        the return value: a single MVN diag dist that has a dimension for timesteps.

        The cell returns a full dist for each timestep so that we can 'sample' it.
        If our sample size is 1, and our cell is an RNN cell, then this is roughly equivalent
        to doing a generative RNN (init state = zeros, return_sequences=True) then passing
        those values through a pair of Dense layers to parameterize a single MVNDiag.

        Args:
            timesteps: Number of times to sample from the dynamic_prior_cell. Output will have
                `timesteps` entries along axis 2.
            samples: Number of samples to draw from the latent distribution.
            batch_size: Number of sequences to sample.
            fixed: Boolean for whether or not to share the same random
                    sample across all sequences in batch.

        Returns:
            A tuple (sample, dist). `sample` stacks the per-step samples along axis 2 (and is
            broadcast to `batch_size` along axis 1 when `fixed`). `dist` is an MVN built from
            the stacked per-step loc/scale, with the time axis reinterpreted as event.
        """
        if fixed:
            sample_batch_size = 1
        else:
            sample_batch_size = batch_size

        sample, state = self.cell.zero_state([samples, sample_batch_size])
        locs = []
        scales = []
        sample_list = []
        scale_parm_name = "scale_tril" if self.cell.offdiag else "scale_diag"  # TODO: Check this for offdiag
        for _ in range(timesteps):
            dist, state = self.cell(sample, state)
            sample = dist.sample()
            # NOTE(review): these appends were lost in extraction; restored following the
            # DSAE sample_dynamic_prior pattern (dist.parameters["loc"] / [scale name]).
            locs.append(dist.parameters["loc"])
            scales.append(dist.parameters[scale_parm_name])
            sample_list.append(sample)

        sample = tf.stack(sample_list, axis=2)
        loc = tf.stack(locs, axis=2)
        scale = tf.stack(scales, axis=2)

        if fixed:  # tile along the batch axis
            sample = sample + tf.zeros([batch_size, 1, 1])

        if self.cell.offdiag:
            dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
        else:
            dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)

        dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)
        return sample, dist

__init__(self, units, out_dim, cell_type, shift_std=0.1, offdiag=False) special


Name Type Description Default
units int

Dimensionality of the RNN function parameters.

out_dim int

The dimensionality of the distribution.

cell_type str

an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.

shift_std float

Shift applied to MVN std before building the dist. Providing a shift toward the expected std allows the input values to be closer to 0.

offdiag bool

set True to allow non-zero covariance (within-timestep) in the returned distribution.

Source code in indl/dists/
def __init__(self, units: int, out_dim: int, cell_type: str, shift_std: float = 0.1, offdiag: bool = False):
    """
    Args:
        units: Dimensionality of the RNN function parameters.
        out_dim: The dimensionality of the distribution.
        cell_type: an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.
        shift_std: Shift applied to MVN std before building the dist. Providing a shift
            toward the expected std allows the input values to be closer to 0.
        offdiag: set True to allow non-zero covariance (within-timestep) in the returned distribution.
    """
    # All arguments are forwarded to the per-timestep distribution cell.
    self.cell = LearnableMultivariateNormalCell(units, out_dim, cell_type=cell_type,
                                                shift_std=shift_std, offdiag=offdiag)

get_dist(self, timesteps, samples=1, batch_size=1, fixed=True)

    Samples from self.cell `timesteps` times.
    On each step, the previous (sample, state) is fed back into the cell
    (zero_state used for 0th step).

    The cell returns a multivariate normal diagonal distribution for each timestep.
    We collect each timestep-dist's params (loc and scale), then use them to create
    the return value: a single MVN diag dist that has a dimension for timesteps.

    The cell returns a full dist for each timestep so that we can 'sample' it.
    If our sample size is 1, and our cell is an RNN cell, then this is roughly equivalent
    to doing a generative RNN (init state = zeros, return_sequences=True) then passing
    those values through a pair of Dense layers to parameterize a single MVNDiag.

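    Args:
        timesteps: Number of times to sample from the dynamic_prior_cell. Output will have
            `timesteps` entries along the time axis (axis 2).
        samples: Number of samples to draw from the latent distribution.
        batch_size: Number of sequences to sample.
        fixed: Boolean for whether or not to share the same random
            sample across all sequences in batch.

    Returns:
        A tuple (sample, dist) of the stacked per-step samples and the distribution built
        from the stacked per-step loc and scale parameters.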

Source code in indl/dists/
    def get_dist(self, timesteps, samples=1, batch_size=1, fixed=True):
        """
        Samples from self.cell `timesteps` times.
        On each step, the previous (sample, state) is fed back into the cell
        (zero_state used for 0th step).

        The cell returns a multivariate normal diagonal distribution for each timestep.
        We collect each timestep-dist's params (loc and scale), then use them to create
        the return value: a single MVN diag dist that has a dimension for timesteps.

        The cell returns a full dist for each timestep so that we can 'sample' it.
        If our sample size is 1, and our cell is an RNN cell, then this is roughly equivalent
        to doing a generative RNN (init state = zeros, return_sequences=True) then passing
        those values through a pair of Dense layers to parameterize a single MVNDiag.

        Args:
            timesteps: Number of times to sample from the dynamic_prior_cell. Output will have
                `timesteps` entries along axis 2.
            samples: Number of samples to draw from the latent distribution.
            batch_size: Number of sequences to sample.
            fixed: Boolean for whether or not to share the same random
                    sample across all sequences in batch.

        Returns:
            A tuple (sample, dist). `sample` stacks the per-step samples along axis 2 (and is
            broadcast to `batch_size` along axis 1 when `fixed`). `dist` is an MVN built from
            the stacked per-step loc/scale, with the time axis reinterpreted as event.
        """
        if fixed:
            sample_batch_size = 1
        else:
            sample_batch_size = batch_size

        sample, state = self.cell.zero_state([samples, sample_batch_size])
        locs = []
        scales = []
        sample_list = []
        scale_parm_name = "scale_tril" if self.cell.offdiag else "scale_diag"  # TODO: Check this for offdiag
        for _ in range(timesteps):
            dist, state = self.cell(sample, state)
            sample = dist.sample()
            # NOTE(review): these appends were lost in extraction; restored following the
            # DSAE sample_dynamic_prior pattern (dist.parameters["loc"] / [scale name]).
            locs.append(dist.parameters["loc"])
            scales.append(dist.parameters[scale_parm_name])
            sample_list.append(sample)

        sample = tf.stack(sample_list, axis=2)
        loc = tf.stack(locs, axis=2)
        scale = tf.stack(scales, axis=2)

        if fixed:  # tile along the batch axis
            sample = sample + tf.zeros([batch_size, 1, 1])

        if self.cell.offdiag:
            dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
        else:
            dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)

        dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)
        return sample, dist

RNNMultivariateNormalDiag (MultivariateNormalDiag)

Source code in indl/dists/
class RNNMultivariateNormalDiag(tfd.MultivariateNormalDiag):
    """MultivariateNormalDiag whose loc/scale_diag come from unrolling an RNN cell.

    The cell is stepped `n_timesteps` times from zero state and zero input, feeding each
    step's output back in as the next input. The per-step `loc` and `scale_diag`
    parameters are concatenated to parameterize this distribution.
    """
    def __init__(self, cell, n_timesteps=1, output_dim=None, name="rnn_mvn_diag", **kwargs):
        """
        Args:
            cell: RNN cell called as `cell(inputs, (h, c))` returning `(output, new_states)`,
                where `output.parameters["distribution"]` exposes `loc` and `scale_diag`.
                Must have a `units` attribute. NOTE(review): the (h, c) state suggests an
                LSTM-style cell such as a distribution-emitting LSTM cell; confirm.
            n_timesteps: Number of times to step the cell.
            output_dim: Input/output size of the cell. If the cell has an `output_dim`
                attribute it is overwritten with this value (when given) and then used;
                otherwise defaults to `cell.units`.
            name: Name of the distribution.
            **kwargs: Forwarded to `tfd.MultivariateNormalDiag`.
        """
        self.cell = cell
        if output_dim is not None and hasattr(self.cell, 'output_dim'):
            self.cell.output_dim = output_dim
        if hasattr(self.cell, 'output_dim'):
            output_dim = self.cell.output_dim
        else:
            output_dim = output_dim or self.cell.units

        h0 = tf.zeros([1, self.cell.units])
        c0 = tf.zeros([1, self.cell.units])
        input0 = tf.zeros((1, output_dim))

        if hasattr(cell, 'reset_dropout_mask'):
            # NOTE(review): the body of this branch was lost in extraction; resetting the
            # Keras dropout masks before unrolling is assumed.
            cell.reset_dropout_mask()
            if hasattr(cell, 'reset_recurrent_dropout_mask'):
                cell.reset_recurrent_dropout_mask()

        input_ = input0
        states_ = (h0, c0)
        successive_outputs = []
        for _ in range(n_timesteps):
            input_, states_ = self.cell(input_, states_)
            successive_outputs.append(input_)

        # NOTE(review): the concat axis was lost in extraction; axis=0 (the size-1 batch
        # axis of each step) is assumed, giving one row per timestep.
        loc = tf.concat([out.parameters["distribution"].parameters["loc"]
                         for out in successive_outputs],
                        axis=0)
        scale_diag = tf.concat([out.parameters["distribution"].parameters["scale_diag"]
                                for out in successive_outputs],
                               axis=0)

        super(RNNMultivariateNormalDiag, self).__init__(loc=loc, scale_diag=scale_diag, name=name, **kwargs)

cross_entropy(self, other, name='cross_entropy')

Computes the (Shannon) cross entropy.

Denote this distribution (self) by P and the other distribution by Q. Assuming P, Q are absolutely continuous with respect to one another and permit densities p(x) dr(x) and q(x) dr(x), (Shannon) cross entropy is defined as:

H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)

where F denotes the support of the random variable X ~ P.


Name Type Description Default

tfp.distributions.Distribution instance.


Python str prepended to names of ops created by this function.



Type Description

self.dtype Tensor with shape [B1, ..., Bn] representing n different calculations of (Shannon) cross entropy.

Source code in indl/dists/
def cross_entropy(self, other, name='cross_entropy'):
  """Computes the (Shannon) cross entropy.

  Denote this distribution (`self`) by `P` and the `other` distribution by
  `Q`. Assuming `P, Q` are absolutely continuous with respect to
  one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shannon)
  cross entropy is defined as:

      H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)

  where `F` denotes the support of the random variable `X ~ P`.

  Args:
    other: `tfp.distributions.Distribution` instance.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
      representing `n` different calculations of (Shannon) cross entropy.
  """
  # Enter this distribution's name/control scope (which applies its assertions)
  # before delegating to the subclass implementation.
  with self._name_and_control_scope(name):
    return self._cross_entropy(other)

kl_divergence(self, other, name='kl_divergence')

Computes the Kullback--Leibler divergence.

Denote this distribution (self) by p and the other distribution by q. Assuming p, q are absolutely continuous with respect to reference measure r, the KL divergence is defined as:

KL[p, q] = E_p[log(p(X)/q(X))]
         = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
         = H[p, q] - H[p]

where F denotes the support of the random variable X ~ p, H[., .] denotes (Shannon) cross entropy, and H[.] denotes (Shannon) entropy.


Name Type Description Default

tfp.distributions.Distribution instance.


Python str prepended to names of ops created by this function.



Type Description

self.dtype Tensor with shape [B1, ..., Bn] representing n different calculations of the Kullback-Leibler divergence.

Source code in indl/dists/
def kl_divergence(self, other, name='kl_divergence'):
  """Computes the Kullback--Leibler divergence.

  Denote this distribution (`self`) by `p` and the `other` distribution by
  `q`. Assuming `p, q` are absolutely continuous with respect to reference
  measure `r`, the KL divergence is defined as:

      KL[p, q] = E_p[log(p(X)/q(X))]
               = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
               = H[p, q] - H[p]

  where `F` denotes the support of the random variable `X ~ p`, `H[., .]`
  denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.

  Args:
    other: `tfp.distributions.Distribution` instance.
    name: Python `str` prepended to names of ops created by this function.

  Returns:
    kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
      representing `n` different calculations of the Kullback-Leibler
      divergence.
  """
  # NOTE: We do not enter a `self._name_and_control_scope` here.  We rely on
  # `tfd.kl_divergence(self, other)` to use `_name_and_control_scope` to apply
  # assertions on both Distributions.
  # Subclasses that override `Distribution.kl_divergence` or `_kl_divergence`
  # must ensure that assertions are applied for both `self` and `other`.
  return self._kl_divergence(other)

TiledMVNGenerator (IProcessMVNGenerator)

Similar to LFADS' LearnableDiagonalGaussian. Uses a single learnable loc and scale which are tiled across timesteps.

Source code in indl/dists/
class TiledMVNGenerator(IProcessMVNGenerator):
    """
    Similar to LFADS' LearnableDiagonalGaussian.
    Uses a single learnable loc and scale which are tiled across timesteps.
    """
    def __init__(self, latent_dim: int, init_std: float = 0.1,
                 trainable_mean: bool = True, trainable_var: bool = True,
                 offdiag: bool = False):
        """
        Args:
            latent_dim: Number of dimensions in a single timestep  (params['f_latent_size'])
            init_std: Initial value of standard deviation (params['q_z_init_std'])
            trainable_mean: True if mean should be trainable (params['z_prior_train_mean'])
            trainable_var: True if variance should be trainable (params['z_prior_train_var'])
            offdiag: True if off-diagonal elements (non-orthogonality) allowed. (params['z_prior_off_diag'])
        """
        self._offdiag = offdiag
        # NOTE(review): the tail of this call was lost in extraction. trainable_mean,
        # trainable_var and offdiag are otherwise unused here, so they are forwarded;
        # confirm the keyword names against make_learnable_mvn_params.
        self._loc, self._scale = make_learnable_mvn_params(latent_dim, init_std=init_std,
                                                           trainable_mean=trainable_mean,
                                                           trainable_var=trainable_var,
                                                           offdiag=offdiag)

    def get_dist(self, timesteps, samples=1, batch_size=1, fixed=False):
        """
        Tiles the saved loc and scale across `timesteps` then uses them to create a MVN
        dist. Each timestep has the same loc and scale, but if it were sampled then each
        timestep would return different values.

        Args:
            timesteps: Number of timesteps to tile the parameters over.
            samples: Number of samples to draw.
            batch_size: Number of sequences to draw per sample.
            fixed: If True, draw one sample per `samples` entry and share it across the
                batch (matches the other generators' `fixed` option). Defaults to False.

        Returns:
            A tuple (sample, dist). `dist` is an Independent MVN (TriL if offdiag, else
            Diag) whose event covers the tiled time axis; `sample` has shape
            [samples, batch_size, timesteps, latent_dim].
        """
        loc = tf.tile(tf.expand_dims(self._loc, 0), [timesteps, 1])
        scale = tf.expand_dims(self._scale, 0)
        if self._offdiag:
            scale = tf.tile(scale, [timesteps, 1, 1])
            dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
        else:
            scale = tf.tile(scale, [timesteps, 1])
            dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)
        if fixed:
            # [samples, 1, T, D] broadcast against [batch_size, 1, 1] -> [samples, batch_size, T, D]
            sample = dist.sample([samples, 1])
            sample = sample + tf.zeros([batch_size, 1, 1], dtype=sample.dtype)
        else:
            sample = dist.sample([samples, batch_size])
        return sample, dist

__init__(self, latent_dim, init_std=0.1, trainable_mean=True, trainable_var=True, offdiag=False) special


Name Type Description Default
latent_dim int

Number of dimensions in a single timestep (params['f_latent_size'])

init_std float

Initial value of standard deviation (params['q_z_init_std'])

trainable_mean bool

True if mean should be trainable (params['z_prior_train_mean'])

trainable_var bool

True if variance should be trainable (params['z_prior_train_var'])

offdiag bool

True if off-diagonal elements (non-orthogonality) allowed. (params['z_prior_off_diag'])

Source code in indl/dists/
def __init__(self, latent_dim: int, init_std: float = 0.1,
             trainable_mean: bool = True, trainable_var: bool = True,
             offdiag: bool = False):
    """
    Args:
        latent_dim: Number of dimensions in a single timestep  (params['f_latent_size'])
        init_std: Initial value of standard deviation (params['q_z_init_std'])
        trainable_mean: True if mean should be trainable (params['z_prior_train_mean'])
        trainable_var: True if variance should be trainable (params['z_prior_train_var'])
        offdiag: True if off-diagonal elements (non-orthogonality) allowed. (params['z_prior_off_diag'])
    """
    self._offdiag = offdiag
    # NOTE(review): the tail of this call was lost in extraction. trainable_mean,
    # trainable_var and offdiag are otherwise unused here, so they are forwarded;
    # confirm the keyword names against make_learnable_mvn_params.
    self._loc, self._scale = make_learnable_mvn_params(latent_dim, init_std=init_std,
                                                       trainable_mean=trainable_mean,
                                                       trainable_var=trainable_var,
                                                       offdiag=offdiag)

get_dist(self, timesteps, samples=1, batch_size=1)

Tiles the saved loc and scale to the same shape as posterior then uses them to create a MVN dist with appropriate shape. Each timestep has the same loc and scale but if it were sampled then each timestep would return different values.


Name Type Description Default
timesteps required
samples 1
batch_size 1


Type Description

MVNDiag distribution of the same shape as posterior

Source code in indl/dists/
def get_dist(self, timesteps, samples=1, batch_size=1, fixed=False):
    """
    Tiles the saved loc and scale across `timesteps` then uses them to create a MVN
    dist. Each timestep has the same loc and scale, but if it were sampled then each
    timestep would return different values.

    Args:
        timesteps: Number of timesteps to tile the parameters over.
        samples: Number of samples to draw.
        batch_size: Number of sequences to draw per sample.
        fixed: If True, draw one sample per `samples` entry and share it across the
            batch (matches the other generators' `fixed` option). Defaults to False.

    Returns:
        A tuple (sample, dist). `dist` is an Independent MVN (TriL if offdiag, else
        Diag) whose event covers the tiled time axis; `sample` has shape
        [samples, batch_size, timesteps, latent_dim].
    """
    loc = tf.tile(tf.expand_dims(self._loc, 0), [timesteps, 1])
    scale = tf.expand_dims(self._scale, 0)
    if self._offdiag:
        scale = tf.tile(scale, [timesteps, 1, 1])
        dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
    else:
        scale = tf.tile(scale, [timesteps, 1])
        dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
    dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)
    if fixed:
        # [samples, 1, T, D] broadcast against [batch_size, 1, 1] -> [samples, batch_size, T, D]
        sample = dist.sample([samples, 1])
        sample = sample + tf.zeros([batch_size, 1, 1], dtype=sample.dtype)
    else:
        sample = dist.sample([samples, batch_size])
    return sample, dist

VariationalLSTMCell (LSTMCell)

Source code in indl/dists/
class VariationalLSTMCell(tfkl.LSTMCell):
    """LSTMCell whose per-step output is mapped to a distribution.

    `call` runs the parent LSTM step and passes the cell output through
    `self.make_dist_model`, returning `(dist, state)` instead of `(output, state)`.
    """

    def __init__(self, units,
        # NOTE(review): the rest of this signature was lost when this listing was
        # extracted. The body reads `make_dist_fn`, `make_dist_model`, `scale_shift`
        # and `kwargs`, so they are presumably parameters (defaults unknown) - restore
        # them from the original source.
        super(VariationalLSTMCell, self).__init__(units, **kwargs)
        self.make_dist_fn = make_dist_fn
        self.make_dist_model = make_dist_model

        # For some reason the below code doesn't work during build.
        # So I don't know how to use the outer VariationalRNN to set this cell's output_size
        if self.make_dist_fn is None:
            # Default: diagonal MVN from a (loc, scale_diag) pair.
            self.make_dist_fn = lambda t: tfd.MultivariateNormalDiag(loc=t[0], scale_diag=t[1])
        if self.make_dist_model is None:
            # Default head: two Dense layers on the cell output produce loc and a
            # positive scale (softplus + 1e-5 floor), wrapped in a DistributionLambda.
            fake_cell_output = tfkl.Input((self.units,))
            loc = tfkl.Dense(self.output_size, name="VarLSTMCell_loc")(fake_cell_output)
            scale = tfkl.Dense(self.output_size, name="VarLSTMCell_scale")(fake_cell_output)
            scale = tf.nn.softplus(scale + scale_shift) + 1e-5
            dist_layer = tfpl.DistributionLambda(
                # NOTE(review): the make_distribution_fn argument is missing from this
                # listing; presumably self.make_dist_fn - confirm against the source.
                # TODO: convert_to_tensor_fn=lambda s: s.sample(N_SAMPLES)
            )([loc, scale])
            self.make_dist_model = tf.keras.Model(fake_cell_output, dist_layer)

    def build(self, input_shape):
        """Builds the parent LSTMCell weights only (the distribution head is built in __init__)."""
        super(VariationalLSTMCell, self).build(input_shape)
        # It would be good to defer making self.make_dist_model until here,
        # but it doesn't work for some reason.

    # def input_zero(self, inputs_):
    #    input0 = inputs_[..., -1, :]
    #    input0 = tf.matmul(input0, tf.zeros((input0.shape[-1], self.units)))
    #    dist0 = self.make_dist_model(input0)
    #    return dist0

    def call(self, inputs, states, training=None):
        """Runs one LSTM step and returns `(self.make_dist_model(output), new_state)`."""
        # Inputs may be a distribution fed back from the previous step; convert to a tensor.
        inputs = tf.convert_to_tensor(inputs)
        output, state = super(VariationalLSTMCell, self).call(inputs, states, training=training)
        dist = self.make_dist_model(output)
        return dist, state

build(self, input_shape)

Creates the variables of the layer (optional, for subclass implementers).

This is a method that implementers of subclasses of Layer or Model can override if they need a state-creation step in-between layer instantiation and layer call.

This is typically used to create the weights of Layer subclasses.


Name Type Description Default

Instance of TensorShape, or list of instances of TensorShape if the layer expects a list of inputs (one instance per input).

Source code in indl/dists/
def build(self, input_shape):
    """Create the weights of the underlying LSTMCell.

    The distribution head (`self.make_dist_model`) is intentionally not built here:
    deferring its construction to build() was tried and did not work.
    """
    lstm_parent = super(VariationalLSTMCell, self)
    lstm_parent.build(input_shape)

call(self, inputs, states, training=None)

This is where the layer's logic lives.

Note that the `call()` method in tf.keras is a little different from the keras API: in the keras API you can pass masking support for layers as additional arguments, whereas tf.keras has a `compute_mask()` method to support masking.


Name Type Description Default

Input tensor, or list/tuple of input tensors.


Additional keyword arguments. Currently unused.



Type Description

A tensor or list/tuple of tensors.

Source code in indl/dists/
def call(self, inputs, states, training=None):
    """Run one LSTM step and map the cell output to a distribution.

    Returns:
        (dist, state): the result of `self.make_dist_model` applied to the LSTM
        output, and the new LSTM state.
    """
    lstm_step = super(VariationalLSTMCell, self).call
    cell_output, new_state = lstm_step(tf.convert_to_tensor(inputs), states, training=training)
    return self.make_dist_model(cell_output), new_state