dists
LearnableMultivariateNormalCell (Model)
Multivariate normal distribution RNN cell.
The model is an RNN-based recurrent function that computes the parameters for a multivariate normal distribution at each timestep `t`.
Based on:
https://github.com/tensorflow/probability/blob/698e0101aecf46c42858db7952ee3024e091c291/tensorflow_probability/examples/disentangled_vae.py#L242
Source code in indl/dists/__init__.py
class LearnableMultivariateNormalCell(tf.keras.Model):
"""Multivariate normal distribution RNN cell.
The model is a RNN-based recurrent function that computes the
parameters for a multivariate normal distribution at each timestep `t`.
Based on:
https://github.com/tensorflow/probability/blob/698e0101aecf46c42858db7952ee3024e091c291/tensorflow_probability/examples/disentangled_vae.py#L242
"""
def __init__(self, units: int, out_dim: int,
shift_std: float = 0.1, cell_type: str = 'lstm', offdiag: bool = False):
"""Constructs a learnable multivariate normal cell.
Args:
units: Dimensionality of the RNN function parameters.
out_dim: The dimensionality of the distribution.
shift_std: Shift applied to MVN std before building the dist. Providing a shift
toward the expected std allows the input values to be closer to 0.
cell_type: an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.
offdiag: set True to allow non-zero covariance (within-timestep) in the returned distribution.
"""
super(LearnableMultivariateNormalCell, self).__init__()
self.offdiag = offdiag
self.output_dimensions = out_dim
self.units = units
if cell_type.upper().endswith('LSTM'):
self.rnn_cell = tfkl.LSTMCell(self.units, implementation=1, name="mvncell")
# why does the jupyter notebook version require implementation=1 but not in pycharm?
elif cell_type.upper().endswith('GRU'):
self.rnn_cell = tfkl.GRUCell(self.units, name="mvncell")
elif cell_type.upper().endswith('RNN'):
self.rnn_cell = tfkl.SimpleRNNCell(self.units, name="mvncell")
elif cell_type.upper().endswith('GRUCLIP'):
from indl.rnn.gru_clip import GRUClipCell
self.rnn_cell = GRUClipCell(self.units, name="mvncell")
else:
raise ValueError("cell_type %s not recognized" % cell_type)
self.loc_layer = tfkl.Dense(self.output_dimensions, name="mvncell_loc")
n_scale_dim = (tfpl.MultivariateNormalTriL.params_size(out_dim) - out_dim) if offdiag\
else (tfpl.IndependentNormal.params_size(out_dim) - out_dim)
self.scale_untransformed_layer = tfkl.Dense(n_scale_dim, name="mvndiagcell_scale")
self._scale_shift = np.log(np.exp(shift_std) - 1).astype(np.float32)
#def build(self, input_shape):
#super(LearnableMultivariateNormalDiagCell, self).build(input_shape)
#self.lstm_cell.build(input_shape)
#self.loc_layer.build(input_shape)
#self.scale_untransformed_layer.build(input_shape)
#self.built = True
def zero_state(self, sample_batch_shape=()):
"""Returns an initial state for the RNN cell.
Args:
sample_batch_shape: A 0D or 1D tensor of the combined sample and
batch shape.
Returns:
A tuple of the initial previous output at timestep 0 of shape
[sample_batch_shape, dimensions], and the cell state.
"""
zero_state = self.rnn_cell.get_initial_state(batch_size=sample_batch_shape[-1], dtype=tf.float32)
sample_batch_shape = tf.convert_to_tensor(value=sample_batch_shape, dtype=tf.int32)
out_shape = tf.concat((sample_batch_shape, [self.output_dimensions]), axis=-1)
previous_output = tf.zeros(out_shape)
return previous_output, zero_state
def call(self, inputs, state):
"""Runs the model to generate a distribution for a single timestep.
This generates a batched MultivariateNormalDiag distribution using
the output of the recurrent model at the current timestep to
parameterize the distribution.
Args:
inputs: The sampled value of `z` at the previous timestep, i.e.,
`z_{t-1}`, of shape [..., dimensions].
`z_0` should be set to the empty matrix.
state: A tuple containing the (hidden, cell) state.
Returns:
A tuple of a MultivariateNormalDiag distribution, and the state of
the recurrent function at the end of the current timestep. The
distribution will have event shape [dimensions], batch shape
[...], and sample shape [sample_shape, ..., dimensions].
"""
# In order to allow the user to pass in a single example without a batch
# dimension, we always expand the input to at least two dimensions, then
# fix the output shape to remove the batch dimension if necessary.
original_shape = inputs.shape
if len(original_shape) < 2:
inputs = tf.reshape(inputs, [1, -1])
out, state = self.rnn_cell(inputs, state)
parms_shape = tf.concat((original_shape[:-1], [self.output_dimensions]), 0)
loc = tf.reshape(self.loc_layer(out), parms_shape)
scale = self.scale_untransformed_layer(out)
scale = tf.nn.softplus(scale + self._scale_shift) + 1e-5
scale = tf.reshape(scale, parms_shape)
if self.offdiag:
return tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale), state  # return state, as in the diag branch
else:
return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale), state
__init__(self, units, out_dim, shift_std=0.1, cell_type='lstm', offdiag=False)
special
Constructs a learnable multivariate normal cell.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `units` | `int` | Dimensionality of the RNN function parameters. | required |
| `out_dim` | `int` | The dimensionality of the distribution. | required |
| `shift_std` | `float` | Shift applied to MVN std before building the dist. Providing a shift toward the expected std allows the input values to be closer to 0. | `0.1` |
| `cell_type` | `str` | an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive. | `'lstm'` |
| `offdiag` | `bool` | set True to allow non-zero covariance (within-timestep) in the returned distribution. | `False` |
Source code in indl/dists/__init__.py
def __init__(self, units: int, out_dim: int,
shift_std: float = 0.1, cell_type: str = 'lstm', offdiag: bool = False):
"""Constructs a learnable multivariate normal cell.
Args:
units: Dimensionality of the RNN function parameters.
out_dim: The dimensionality of the distribution.
shift_std: Shift applied to MVN std before building the dist. Providing a shift
toward the expected std allows the input values to be closer to 0.
cell_type: an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.
offdiag: set True to allow non-zero covariance (within-timestep) in the returned distribution.
"""
super(LearnableMultivariateNormalCell, self).__init__()
self.offdiag = offdiag
self.output_dimensions = out_dim
self.units = units
if cell_type.upper().endswith('LSTM'):
self.rnn_cell = tfkl.LSTMCell(self.units, implementation=1, name="mvncell")
# why does the jupyter notebook version require implementation=1 but not in pycharm?
elif cell_type.upper().endswith('GRU'):
self.rnn_cell = tfkl.GRUCell(self.units, name="mvncell")
elif cell_type.upper().endswith('RNN'):
self.rnn_cell = tfkl.SimpleRNNCell(self.units, name="mvncell")
elif cell_type.upper().endswith('GRUCLIP'):
from indl.rnn.gru_clip import GRUClipCell
self.rnn_cell = GRUClipCell(self.units, name="mvncell")
else:
raise ValueError("cell_type %s not recognized" % cell_type)
self.loc_layer = tfkl.Dense(self.output_dimensions, name="mvncell_loc")
n_scale_dim = (tfpl.MultivariateNormalTriL.params_size(out_dim) - out_dim) if offdiag\
else (tfpl.IndependentNormal.params_size(out_dim) - out_dim)
self.scale_untransformed_layer = tfkl.Dense(n_scale_dim, name="mvndiagcell_scale")
self._scale_shift = np.log(np.exp(shift_std) - 1).astype(np.float32)
#def build(self, input_shape):
#super(LearnableMultivariateNormalDiagCell, self).build(input_shape)
#self.lstm_cell.build(input_shape)
#self.loc_layer.build(input_shape)
#self.scale_untransformed_layer.build(input_shape)
#self.built = True
call(self, inputs, state)
Runs the model to generate a distribution for a single timestep.
This generates a batched MultivariateNormalDiag distribution using the output of the recurrent model at the current timestep to parameterize the distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | | The sampled value of `z` at the previous timestep, i.e. `z_{t-1}`, of shape [..., dimensions]. `z_0` should be set to the empty matrix. | required |
| `state` | | A tuple containing the (hidden, cell) state. | required |

Returns:

| Type | Description |
|---|---|
| | A tuple of a MultivariateNormalDiag distribution, and the state of the recurrent function at the end of the current timestep. The distribution will have event shape [dimensions], batch shape [...], and sample shape [sample_shape, ..., dimensions]. |
Source code in indl/dists/__init__.py
def call(self, inputs, state):
"""Runs the model to generate a distribution for a single timestep.
This generates a batched MultivariateNormalDiag distribution using
the output of the recurrent model at the current timestep to
parameterize the distribution.
Args:
inputs: The sampled value of `z` at the previous timestep, i.e.,
`z_{t-1}`, of shape [..., dimensions].
`z_0` should be set to the empty matrix.
state: A tuple containing the (hidden, cell) state.
Returns:
A tuple of a MultivariateNormalDiag distribution, and the state of
the recurrent function at the end of the current timestep. The
distribution will have event shape [dimensions], batch shape
[...], and sample shape [sample_shape, ..., dimensions].
"""
# In order to allow the user to pass in a single example without a batch
# dimension, we always expand the input to at least two dimensions, then
# fix the output shape to remove the batch dimension if necessary.
original_shape = inputs.shape
if len(original_shape) < 2:
inputs = tf.reshape(inputs, [1, -1])
out, state = self.rnn_cell(inputs, state)
parms_shape = tf.concat((original_shape[:-1], [self.output_dimensions]), 0)
loc = tf.reshape(self.loc_layer(out), parms_shape)
scale = self.scale_untransformed_layer(out)
scale = tf.nn.softplus(scale + self._scale_shift) + 1e-5
scale = tf.reshape(scale, parms_shape)
if self.offdiag:
return tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale), state  # return state, as in the diag branch
else:
return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale), state
zero_state(self, sample_batch_shape=())
Returns an initial state for the RNN cell.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sample_batch_shape` | | A 0D or 1D tensor of the combined sample and batch shape. | `()` |

Returns:

| Type | Description |
|---|---|
| | A tuple of the initial previous output at timestep 0 of shape [sample_batch_shape, dimensions], and the cell state. |
Source code in indl/dists/__init__.py
def zero_state(self, sample_batch_shape=()):
"""Returns an initial state for the RNN cell.
Args:
sample_batch_shape: A 0D or 1D tensor of the combined sample and
batch shape.
Returns:
A tuple of the initial previous output at timestep 0 of shape
[sample_batch_shape, dimensions], and the cell state.
"""
zero_state = self.rnn_cell.get_initial_state(batch_size=sample_batch_shape[-1], dtype=tf.float32)
sample_batch_shape = tf.convert_to_tensor(value=sample_batch_shape, dtype=tf.int32)
out_shape = tf.concat((sample_batch_shape, [self.output_dimensions]), axis=-1)
previous_output = tf.zeros(out_shape)
return previous_output, zero_state
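Example (an illustrative sketch, not part of the package docs; unit/batch sizes are arbitrary). It mirrors the sampling loop used by `RNNMVNGenerator.get_dist` in `indl/dists/sequential.py`:

```python
import tensorflow as tf
from indl.dists import LearnableMultivariateNormalCell

cell = LearnableMultivariateNormalCell(units=16, out_dim=3, cell_type='gru')

# zero_state takes a (samples, batch) shape; here 1 sample for a batch of 8.
sample, state = cell.zero_state(sample_batch_shape=[1, 8])
samples = []
for t in range(20):
    dist, state = cell(sample, state)  # per-timestep MVN-diag distribution
    sample = dist.sample()
    samples.append(sample)
trajectory = tf.stack(samples, axis=2)  # [1, 8, 20, 3]
```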
LearnableMultivariateNormalDiag (Model)
Learnable multivariate diagonal normal distribution.
The model is a multivariate normal distribution with learnable `mean` and `stddev` parameters.
See `make_mvn_prior` for a description.
Source code in indl/dists/__init__.py
class LearnableMultivariateNormalDiag(tf.keras.Model):
"""Learnable multivariate diagonal normal distribution.
The model is a multivariate normal distribution with learnable
`mean` and `stddev` parameters.
See make_mvn_prior for a description.
"""
def __init__(self, dimensions, init_std=1.0, trainable_mean=True, trainable_var=True):
"""Constructs a learnable multivariate diagonal normal model.
Args:
dimensions: An integer corresponding to the dimensionality of the
distribution.
"""
super(LearnableMultivariateNormalDiag, self).__init__()
with tf.name_scope(self._name):
self.dimensions = dimensions
if trainable_mean:
self._mean = tf.Variable(tf.random.normal([dimensions], stddev=0.1), name="mean")
else:
self._mean = tf.zeros(dimensions)
if trainable_var:
_scale_shift = np.log(np.exp(init_std) - 1).astype(np.float32)
self._scale = tfp.util.TransformedVariable(
tf.random.normal([dimensions], mean=init_std, stddev=init_std/10, dtype=tf.float32),
bijector=tfb.Chain([tfb.Shift(1e-5), tfb.Softplus(), tfb.Shift(_scale_shift)]),
name="transformed_scale")
else:
self._scale = init_std * tf.ones(dimensions)
def __call__(self, *args, **kwargs):
# Allow this Model to be called without inputs.
dummy = tf.zeros(self.dimensions)
return super(LearnableMultivariateNormalDiag, self).__call__(
dummy, *args, **kwargs)
def call(self, inputs):
"""Runs the model to generate multivariate normal distribution.
Args:
inputs: Unused.
Returns:
A MultivariateNormalDiag distribution with event shape
[dimensions], batch shape [], and sample shape [sample_shape,
dimensions].
"""
del inputs # unused
with tf.name_scope(self._name):
return tfd.MultivariateNormalDiag(loc=self.loc, scale_diag=self.scale_diag)
@property
def loc(self):
"""The mean of the normal distribution."""
return self._mean
@property
def scale_diag(self):
"""The diagonal standard deviation of the normal distribution."""
return self._scale
loc property readonly
The mean of the normal distribution.
scale_diag property readonly
The diagonal standard deviation of the normal distribution.
__init__(self, dimensions, init_std=1.0, trainable_mean=True, trainable_var=True)
special
Constructs a learnable multivariate diagonal normal model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dimensions` | | An integer corresponding to the dimensionality of the distribution. | required |
Source code in indl/dists/__init__.py
def __init__(self, dimensions, init_std=1.0, trainable_mean=True, trainable_var=True):
"""Constructs a learnable multivariate diagonal normal model.
Args:
dimensions: An integer corresponding to the dimensionality of the
distribution.
"""
super(LearnableMultivariateNormalDiag, self).__init__()
with tf.name_scope(self._name):
self.dimensions = dimensions
if trainable_mean:
self._mean = tf.Variable(tf.random.normal([dimensions], stddev=0.1), name="mean")
else:
self._mean = tf.zeros(dimensions)
if trainable_var:
_scale_shift = np.log(np.exp(init_std) - 1).astype(np.float32)
self._scale = tfp.util.TransformedVariable(
tf.random.normal([dimensions], mean=init_std, stddev=init_std/10, dtype=tf.float32),
bijector=tfb.Chain([tfb.Shift(1e-5), tfb.Softplus(), tfb.Shift(_scale_shift)]),
name="transformed_scale")
else:
self._scale = init_std * tf.ones(dimensions)
call(self, inputs)
Runs the model to generate multivariate normal distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | | Unused. | required |

Returns:

| Type | Description |
|---|---|
| | A MultivariateNormalDiag distribution with event shape [dimensions], batch shape [], and sample shape [sample_shape, dimensions]. |
Source code in indl/dists/__init__.py
def call(self, inputs):
"""Runs the model to generate multivariate normal distribution.
Args:
inputs: Unused.
Returns:
A MultivariateNormalDiag distribution with event shape
[dimensions], batch shape [], and sample shape [sample_shape,
dimensions].
"""
del inputs # unused
with tf.name_scope(self._name):
return tfd.MultivariateNormalDiag(loc=self.loc, scale_diag=self.scale_diag)
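Example (illustrative sketch; `dimensions` and sample counts are arbitrary):

```python
from indl.dists import LearnableMultivariateNormalDiag

static_prior = LearnableMultivariateNormalDiag(dimensions=8, init_std=0.1)
prior_dist = static_prior()      # callable without inputs; returns a MultivariateNormalDiag
z = prior_dist.sample(5)         # shape [5, 8]
log_p = prior_dist.log_prob(z)   # shape [5]
# The mean and the transformed scale are tracked as trainable variables of the model.
```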
make_learnable_mvn_params(ndim, init_std=1.0, trainable_mean=True, trainable_var=True, offdiag=False)
Return mean (loc) and stddev (scale) parameters for initializing multivariate normal distributions.
If trainable_mean, the mean is initialized with a random normal (stddev=0.1); otherwise it is zeros.
If trainable_var, the scale is initialized with a random normal centered at a value such that the bijector transformation yields init_std; when init_std is 1.0 (default), the inverse-bijected value is approximately 0.0.
If not trainable_var, then scale is a vector or matrix of init_std of appropriate shape for the dist.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `ndim` | `int` | Number of dimensions. | required |
| `init_std` | `float` | Initial value for the standard deviation. | `1.0` |
| `trainable_mean` | `bool` | Whether or not the mean (loc) is a trainable tf.Variable. | `True` |
| `trainable_var` | `bool` | Whether or not the variance (scale) is a trainable tf.Variable. | `True` |
| `offdiag` | `bool` | Whether or not off-diagonal elements are allowed. | `False` |
Returns: loc, scale
Source code in indl/dists/__init__.py
def make_learnable_mvn_params(ndim: int, init_std: float = 1.0, trainable_mean: bool = True, trainable_var: bool = True,
offdiag: bool = False):
"""
Return mean (loc) and stddev (scale) parameters for initializing multivariate normal distributions.
If trainable_mean then it will be initialized with random normal (stddev=0.1), otherwise zeros.
If trainable_var then it will be initialized with random normal centered at a value such that the
bijector transformation yields the value in init_std. When init_std is 1.0 (default) then the
inverse-bijected value is approximately 0.0.
If not trainable_var then scale is a vector or matrix of init_std of appropriate shape for the dist.
Args:
ndim: Number of dimensions.
init_std: Initial value for the standard deviation.
trainable_mean: Whether or not the mean (loc) is a trainable tf.Variable.
trainable_var: Whether or not the variance (scale) is a trainable tf.Variable.
offdiag: Whether or not off-diagonal elements are allowed.
Returns: loc, scale
"""
if trainable_mean:
loc = tf.Variable(tf.random.normal([ndim], stddev=0.1, dtype=tf.float32))
else:
loc = tf.zeros(ndim)
# Initialize the variance (scale), trainable or not, offdiag or not.
if trainable_var:
if offdiag:
_ndim = [ndim, ndim]
scale = tfp.util.TransformedVariable(
# init_std * tf.eye(ndim, dtype=tf.float32),
tf.random.normal(_ndim, mean=init_std, stddev=init_std/10, dtype=tf.float32),
tfp.bijectors.FillScaleTriL(),
name="prior_scale")
else:
_scale_shift = np.log(np.exp(init_std) - 1).astype(np.float32) # tfp.math.softplus_inverse(init_std)
scale = tfp.util.TransformedVariable(
# init_std * tf.ones(ndim, dtype=tf.float32),
tf.random.normal([ndim], mean=init_std, stddev=init_std/10, dtype=tf.float32),
tfb.Chain([tfb.Shift(1e-5), tfb.Softplus(), tfb.Shift(_scale_shift)]),
name="prior_scale")
else:
if offdiag:
scale = init_std * tf.eye(ndim)
else:
scale = init_std * tf.ones(ndim)
return loc, scale
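Example (illustrative sketch; this is essentially what `make_mvn_prior` below does with the returned parameters):

```python
import tensorflow_probability as tfp
from indl.dists import make_learnable_mvn_params

tfd = tfp.distributions

loc, scale = make_learnable_mvn_params(ndim=6, init_std=0.1,
                                       trainable_mean=True, trainable_var=True)
prior = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
# loc is a tf.Variable and scale is a tfp.util.TransformedVariable, so gradients
# taken through prior.log_prob(...) can update both during training.
```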
make_mvn_dist_fn(_x_, ndim, shift_std=1.0, offdiag=False, loc_name=None, scale_name=None, use_mvn_diag=True)
Take a 1-D tensor and use it to parameterize an MVN dist. This doesn't return the distribution itself, but the function to make the distribution and its arguments: `make_dist_fn, [loc, scale]`. You can supply these to `tfpl.DistributionLambda`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `_x_` | `Tensor` | 1-D input tensor used to parameterize the distribution. | required |
| `ndim` | `int` | The dimensionality of the distribution. | required |
| `shift_std` | `float` | Shift applied to the MVN std before building the dist (see `init_std` in `make_variational`). | `1.0` |
| `offdiag` | `bool` | Whether or not to include covariances. | `False` |
| `loc_name` | `Optional[str]` | Name for the Dense layer producing `loc`. | `None` |
| `scale_name` | `Optional[str]` | Name for the Dense layer producing `scale`. | `None` |
| `use_mvn_diag` | `bool` | Whether to use tfd.MultivariateNormal(Diag\|TriL) (True) or tfd.Independent(tfd.Normal) (False). | `True` |

Returns:

| Type | Description |
|---|---|
| `Tuple[Callable[[tf.Tensor, tf.Tensor], tfd.Distribution], List[tf.Tensor]]` | make_dist_fn, [loc, scale] |
Source code in indl/dists/__init__.py
def make_mvn_dist_fn(_x_: tf.Tensor, ndim: int, shift_std: float = 1.0, offdiag: bool = False,
loc_name: Optional[str] = None, scale_name: Optional[str] = None, use_mvn_diag: bool = True
) -> Tuple[Callable[[tf.Tensor, tf.Tensor], tfd.Distribution], List[tf.Tensor]]:
"""
Take a 1-D tensor and use it to parameterize a MVN dist.
This doesn't return the distribution, but the function to make the distribution and its arguments.
make_dist_fn, [loc, scale]
You can supply it to tfpl.DistributionLambda
Args:
_x_:
ndim:
shift_std:
offdiag:
loc_name:
scale_name:
use_mvn_diag:
Returns:
make_dist_fn, [loc, scale]
"""
_scale_shift = np.log(np.exp(shift_std) - 1).astype(np.float32)
_loc = tfkl.Dense(ndim, name=loc_name)(_x_)
n_scale_dim = (tfpl.MultivariateNormalTriL.params_size(ndim) - ndim) if offdiag\
else (tfpl.IndependentNormal.params_size(ndim) - ndim)
_scale = tfkl.Dense(n_scale_dim, name=scale_name)(_x_)
_scale = tf.math.softplus(_scale + _scale_shift) + 1e-5
if offdiag:
_scale = tfb.FillTriangular()(_scale)
make_dist_fn = lambda t: tfd.MultivariateNormalTriL(loc=t[0], scale_tril=t[1])
else:
if use_mvn_diag: # Match type with prior
make_dist_fn = lambda t: tfd.MultivariateNormalDiag(loc=t[0], scale_diag=t[1])
else:
make_dist_fn = lambda t: tfd.Independent(tfd.Normal(loc=t[0], scale=t[1]))
return make_dist_fn, [_loc, _scale]
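Example (illustrative sketch of the `tfpl.DistributionLambda` usage mentioned above; the encoder widths are arbitrary):

```python
import tensorflow as tf
import tensorflow_probability as tfp
from indl.dists import make_mvn_dist_fn

tfkl = tf.keras.layers
tfpl = tfp.layers

inputs = tfkl.Input(shape=(32,))                  # hypothetical 1-D features
_x_ = tfkl.Dense(24, activation='relu')(inputs)

make_dist_fn, dist_params = make_mvn_dist_fn(_x_, ndim=4, shift_std=1.0,
                                             offdiag=False, use_mvn_diag=True)
q = tfpl.DistributionLambda(make_distribution_fn=make_dist_fn)(dist_params)
model = tf.keras.Model(inputs, q)                 # model(x) yields an MVN-diag distribution
```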
make_mvn_prior(ndim, init_std=1.0, trainable_mean=True, trainable_var=True, offdiag=False)
Creates a tensorflow-probability distribution: MultivariateNormalTriL if offdiag else MultivariateNormalDiag.
Mean (loc) and sigma (scale) can be trainable or not. Mean initializes to random.normal around 0 (stddev=0.1) if trainable, else zeros. Scale initializes to init_std if not trainable. If it is trainable, it initializes to a tfp TransformedVariable that will be centered at 0 for easy training under the hood, but will be transformed via softplus to give something initially close to init_std.
loc and scale are tracked by the MVNDiag class.
For the LFADS ics prior, trainable_mean=True, trainable_var=False. For the LFADS cos prior (if not using AR1), trainable_mean=False, trainable_var=True. In either case, var was initialized with 0.1 (==> logvar with log(0.1)). Unlike LFADS' LearnableDiagonalGaussian, here we don't support multi-dimensional arrays, just a vector.
See also LearnableMultivariateNormalDiag for a tf.keras.Model version of this.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `ndim` | `int` | Latent dimension of the distribution. Currently only supports a 1-D vector. | required |
| `init_std` | `float` | Initial standard deviation of the gaussian. If trainable_var then the initial standard deviation will be drawn from a random.normal distribution with mean init_std and stddev 1/10th of that. | `1.0` |
| `trainable_mean` | `bool` | If the mean should be a tf.Variable. | `True` |
| `trainable_var` | `bool` | If the variance/stddev/scale should be a tf.Variable. | `True` |
| `offdiag` | `bool` | If the variance-covariance matrix is allowed non-zero off-diagonal elements. | `False` |

Returns:

| Type | Description |
|---|---|
| `Union[tfd.MultivariateNormalDiag, tfd.MultivariateNormalTriL]` | A tensorflow-probability distribution (either MultivariateNormalTriL or MultivariateNormalDiag). |
Source code in indl/dists/__init__.py
def make_mvn_prior(ndim: int, init_std: float = 1.0, trainable_mean: bool = True, trainable_var: bool = True,
offdiag: bool = False) -> Union[tfd.MultivariateNormalDiag, tfd.MultivariateNormalTriL]:
"""
Creates a tensorflow-probability distribution:
MultivariateNormalTriL if offdiag else MultivariateNormalDiag
Mean (loc) and sigma (scale) can be trainable or not.
Mean initializes to random.normal around 0 (stddev=0.1) if trainable, else zeros.
Scale initializes to init_std if not trainable. If it is trainable, it initializes
to a tfp TransformedVariable that will be centered at 0 for easy training under the hood,
but will be transformed via softplus to give something initially close to init_std.
loc and scale are tracked by the MVNDiag class.
For LFADS ics prior, trainable_mean=True, trainable_var=False
For LFADS cos prior (if not using AR1), trainable_mean=False, trainable_var=True
In either case, var was initialized with 0.1 (==> logvar with log(0.1))
Unlike the LFADS' LearnableDiagonalGaussian, here we don't support multi-dimensional, just a vector.
See also LearnableMultivariateNormalDiag for a tf.keras.Model version of this.
Args:
ndim: latent dimension of distribution. Currently only supports 1 d (I think )
init_std: initial standard deviation of the gaussian. If trainable_var then the initial standard deviation
will be drawn from a random.normal distribution with mean init_std and stddev 1/10th of that.
trainable_mean: If the mean should be a tf.Variable
trainable_var: If the variance/stddev/scale (whatever you call it) should be a tf.Variable
offdiag: If the variance-covariance matrix is allowed non-zero off-diagonal elements.
Returns:
A tensorflow-probability distribution (either MultivariateNormalTriL or MultivariateNormalDiag).
"""
loc, scale = make_learnable_mvn_params(ndim, init_std=init_std,
trainable_mean=trainable_mean, trainable_var=trainable_var,
offdiag=offdiag)
# Initialize the prior.
if offdiag:
# Note: Diag must be > 0, upper triangular must be 0, and lower triangular may be != 0.
prior = tfd.MultivariateNormalTriL(
loc=loc,
scale_tril=scale
)
else:
prior = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
# kl_exact needs same dist types for prior and latent.
# We would switch to the next line if we switched our latent to using tf.Independent(tfd.Normal)
# prior = tfd.Independent(tfd.Normal(loc=tf.zeros(ndim), scale=1), reinterpreted_batch_ndims=1)
return prior
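Example (illustrative sketch of pairing the prior with a same-type posterior so the KL term can be computed exactly; the posterior here is a placeholder):

```python
import tensorflow as tf
import tensorflow_probability as tfp
from indl.dists import make_mvn_prior

tfd = tfp.distributions

prior = make_mvn_prior(ndim=4, init_std=0.1, trainable_mean=True, trainable_var=False)

# Placeholder posterior with matching distribution type:
q_z = tfd.MultivariateNormalDiag(loc=tf.zeros(4), scale_diag=tf.ones(4))
kl = tfd.kl_divergence(q_z, prior)  # analytic KL because both are MVN
```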
make_variational(x, dist_dim, init_std=1.0, offdiag=False, samps=1, loc_name='loc', scale_name='scale', dist_name='q', use_mvn_diag=True)
Take an input tensor and return a multivariate normal distribution parameterized by that input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | input tensor | required |
| `dist_dim` | `int` | the dimensionality of the distribution | required |
| `init_std` | `float` | initial stddev SHIFT of the distribution when input is 0. | `1.0` |
| `offdiag` | `bool` | whether or not to include covariances | `False` |
| `samps` | `int` | the number of samples to draw when using implied convert_to_tensor_fn | `1` |
| `loc_name` | | not used (I need to handle naming better) | `'loc'` |
| `scale_name` | | not used | `'scale'` |
| `dist_name` | | not used | `'q'` |
| `use_mvn_diag` | `bool` | whether to use tfd.MultivariateNormal(Diag\|TriL) (True) or tfd.Independent(tfd.Normal) (False). Latter is untested. Note that the mvn dists will put the timesteps dimension (if present in the input) into the "batch dimension" while the "event" dimension will be the last dimension only. You can use tfd.Independent(q_dist, reinterpreted_batch_ndims=1) to move the timestep dimension to the event dimension if necessary. (tfd.Independent doesn't play well with tf.keras.Model inputs/outputs.) | `True` |

Returns:

| Type | Description |
|---|---|
| `Union[tfd.MultivariateNormalDiag, tfd.MultivariateNormalTriL, tfd.Independent]` | A tfd.Distribution. The distribution is of type MultivariateNormalDiag (or MultivariateNormalTriL if offdiag) if use_mvn_diag is set. |
Source code in indl/dists/__init__.py
def make_variational(x: tf.Tensor, dist_dim: int,
init_std: float = 1.0, offdiag: bool = False,
samps: int = 1,
loc_name="loc", scale_name="scale",
dist_name="q",
use_mvn_diag: bool = True
) -> Union[tfd.MultivariateNormalDiag, tfd.MultivariateNormalTriL, tfd.Independent]:
"""
Take an input tensor and return a multivariate normal distribution parameterized by that input tensor.
Args:
x: input tensor
dist_dim: the dimensionality of the distribution
init_std: initial stddev SHIFT of the distribution when input is 0.
offdiag: whether or not to include covariances
samps: the number of samples to draw when using implied convert_to_tensor_fn
loc_name: not used (I need to handle naming better)
scale_name: not used
dist_name: not used
use_mvn_diag: whether to use tfd.MultivariateNormal(Diag|TriL) (True) or tfd.Independent(tfd.Normal) (False)
Latter is untested. Note that the mvn dists will put the timesteps dimension (if present in the input)
into the "batch dimension" while the "event" dimension will be the last dimension only.
You can use tfd.Independent(q_dist, reinterpreted_batch_ndims=1) to move the timestep dimension to the
event dimension if necessary. (tfd.Independent doesn't play well with tf.keras.Model inputs/outputs).
Returns:
A tfd.Distribution. The distribution is of type MultivariateNormalDiag (or MultivariateNormalTriL if offdiag)
if use_mvn_diag is set.
"""
make_dist_fn, dist_params = make_mvn_dist_fn(x, dist_dim, shift_std=init_std,
offdiag=offdiag,
# loc_name=loc_name, scale_name=scale_name,
use_mvn_diag=use_mvn_diag)
# Python `callable` that takes a `tfd.Distribution`
# instance and returns a `tf.Tensor`-like object.
"""
# Unfortunately I couldn't get this to work :(
# Will have to explicitly q_f.value() | qf.mean() from dist in train_step
def custom_convert_fn(d, training=None):
if training is None:
training = K.learning_phase()
output = tf_utils.smart_cond(training,
lambda: d.sample(samps),
lambda: d.mean()
)
return output
def convert_fn(d):
return K.in_train_phase(tfd.Distribution.sample if samps <= 1 else lambda: d.sample(samps),
lambda: d.mean())
"""
convert_fn = tfd.Distribution.sample if samps <= 1 else lambda d: d.sample(samps)
q_dist = tfpl.DistributionLambda(make_distribution_fn=make_dist_fn,
convert_to_tensor_fn=convert_fn,
)(dist_params)
# if tf.shape(x).shape[0] > 2:
# q_dist = tfd.Independent(q_dist, reinterpreted_batch_ndims=1)
return q_dist
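Example (illustrative sketch; input width and latent size are arbitrary):

```python
import tensorflow as tf
from indl.dists import make_variational

tfkl = tf.keras.layers

inputs = tfkl.Input(shape=(16,))   # hypothetical encoder features
q_z = make_variational(inputs, dist_dim=4, init_std=0.1, samps=1)
encoder = tf.keras.Model(inputs, q_z)
# encoder(x) returns a tensor-coercible MVN-diag distribution; call .mean() or
# .sample(), or let downstream layers coerce it via convert_to_tensor_fn.
```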
sequential
AR1ProcessMVNGenerator (IProcessMVNGenerator)
Similar to LFADS' LearnableAutoRegressive1Prior. Here we use the terminology from: https://en.wikipedia.org/wiki/Autoregressive_model#Example:_An_AR(1)_process
The autoregressive function takes the form:
E(X_t) = E(c) + phi * E(X_{t-1}) + e_t
where E(c) is a constant, phi is a parameter equivalent to exp(-1/tau) = exp(-exp(-logtau)) with tau a time constant, and e_t is zero-mean white noise with variance evar = sigma_e**2.
When there is no previous sample, E(X_t) = E(c) + e_t, which is a draw from N(c, sigma_e**2).
When there is a previous sample, E(X_t) = E(c) + phi * E(X_{t-1}) + e_t, which is a draw from N(c + phi * X_{t-1}, sigma_p**2), where sigma_p**2 = phi**2 * var(X_{t-1}) + sigma_e**2 = sigma_e**2 / (1 - phi**2), or logpvar = logevar - (log(1 - phi) + log(1 + phi)).
Note that this could be roughly equivalent to tfd.Autoregressive if it was passed a `distribution_fn` with the same transition.
See issue: https://github.com/snel-repo/lfads-cd/issues/1
Source code in indl/dists/sequential.py
class AR1ProcessMVNGenerator(IProcessMVNGenerator):
"""
Similar to LFADS' LearnableAutoRegressive1Prior.
Here we use the terminology from:
https://en.wikipedia.org/wiki/Autoregressive_model#Example:_An_AR(1)_process
The autoregressive function takes the form:
E(X_t) = E(c) + phi * E(X_{t-1}) + e_t
E(c) is a constant.
phi is a parameter, which is equivalent to exp(-1/tau) = exp(-exp(-logtau)).
where tau is a time constant.
e_t is white noise with zero-mean with evar = sigma_e**2
When there's no previous sample, E(X_t) = E(c) + e_t,
which is a draw from N(c, sigma_e**2)
When there is a previous sample, E(X_t) = E(c) + phi * E(X_{t-1}) + e_t,
which means a draw from N(c + phi * X_{t-1}, sigma_p**2)
where sigma_p**2 = phi**2 * var(X_{t-1}) + sigma_e**2 = sigma_e**2 / (1 - phi**2)
or logpvar = logevar - (log(1 - phi) + log(1 + phi))
Note that this could be roughly equivalent to tfd.Autoregressive if it was passed
a `distribution_fn` with the same transition.
See issue: https://github.com/snel-repo/lfads-cd/issues/1
"""
def __init__(self, init_taus: Union[float, List[float]],
init_std: Union[float, List[float]] = 0.1,
trainable_mean: bool = False,
trainable_tau: bool = True,
trainable_var: bool = True,
offdiag: bool = False):
"""
Args:
init_taus: Initial values of tau
init_std: Initial value of sigma_e
trainable_mean: set True if the mean (e_c) is trainable.
trainable_tau: set True if tau is trainable.
trainable_var: set True if the noise variance (sigma_e) is trainable.
offdiag: set True to allow non-zero covariance (within-timestep).
"""
self._offdiag = offdiag
if isinstance(init_taus, float):
init_taus = [init_taus]
# TODO: Add time axis for easier broadcasting
ndim = len(init_taus)
self._e_c, self._e_scale = make_learnable_mvn_params(ndim, init_std=init_std,
trainable_mean=trainable_mean,
trainable_var=trainable_var,
offdiag=offdiag)
self._logtau = tf.Variable(tf.math.log(init_taus), dtype=tf.float32, trainable=trainable_tau)
self._phi = tf.exp(-tf.exp(-self._logtau))
self._p_scale = tf.exp(tf.math.log(self._e_scale) - (tf.math.log(1 - self._phi) + tf.math.log(1 + self._phi)))
def get_dist(self, timesteps, samples=1, batch_size=1, fixed=False):
locs = []
scales = []
sample_list = []
# Add a time dimension
e_c = tf.expand_dims(self._e_c, 0)
e_scale = tf.expand_dims(self._e_scale, 0)
p_scale = tf.expand_dims(self._p_scale, 0)
sample = tf.expand_dims(tf.expand_dims(tf.zeros_like(e_c), 0), 0)
sample = tf.tile(sample, [samples, batch_size, 1, 1])
for _ in range(timesteps):
loc = e_c + self._phi * sample
scale = p_scale if _ > 0 else e_scale
locs.append(loc)
scales.append(scale)
if self._offdiag:
dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
else:
dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
sample = dist.sample()
sample_list.append(sample)
sample = tf.concat(sample_list, axis=2)
loc = tf.concat(locs, axis=2)
scale = tf.concat(scales, axis=-2)
if self._offdiag:
dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
else:
dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)
return sample, dist
__init__(self, init_taus, init_std=0.1, trainable_mean=False, trainable_tau=True, trainable_var=True, offdiag=False)
special
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `init_taus` | `Union[float, List[float]]` | Initial values of tau. | required |
| `init_std` | `Union[float, List[float]]` | Initial value of sigma_e. | `0.1` |
| `trainable_mean` | `bool` | set True if the mean (e_c) is trainable. | `False` |
| `trainable_tau` | `bool` | set True if tau is trainable. | `True` |
| `trainable_var` | `bool` | set True if the noise variance (sigma_e) is trainable. | `True` |
| `offdiag` | `bool` | set True to allow non-zero covariance (within-timestep). | `False` |
Source code in indl/dists/sequential.py
def __init__(self, init_taus: Union[float, List[float]],
init_std: Union[float, List[float]] = 0.1,
trainable_mean: bool = False,
trainable_tau: bool = True,
trainable_var: bool = True,
offdiag: bool = False):
"""
Args:
init_taus: Initial values of tau
init_std: Initial value of sigma_e
trainable_mean: set True if the mean (e_c) is trainable.
trainable_tau: set True if tau is trainable.
trainable_var: set True if the noise variance (sigma_e) is trainable.
offdiag: set True to allow non-zero covariance (within-timestep).
"""
self._offdiag = offdiag
if isinstance(init_taus, float):
init_taus = [init_taus]
# TODO: Add time axis for easier broadcasting
ndim = len(init_taus)
self._e_c, self._e_scale = make_learnable_mvn_params(ndim, init_std=init_std,
trainable_mean=trainable_mean,
trainable_var=trainable_var,
offdiag=offdiag)
self._logtau = tf.Variable(tf.math.log(init_taus), dtype=tf.float32, trainable=trainable_tau)
self._phi = tf.exp(-tf.exp(-self._logtau))
self._p_scale = tf.exp(tf.math.log(self._e_scale) - (tf.math.log(1 - self._phi) + tf.math.log(1 + self._phi)))
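Example (illustrative sketch; taus, timesteps, and batch size are arbitrary):

```python
from indl.dists.sequential import AR1ProcessMVNGenerator

gen = AR1ProcessMVNGenerator(init_taus=[10.0, 10.0], init_std=0.1)
sample, prior = gen.get_dist(timesteps=50, samples=1, batch_size=8)
# sample: [samples, batch, timesteps, 2]; prior is a tfd.Independent whose
# log_prob reduces over the timestep axis, e.g. prior.log_prob(sample) -> [1, 8].
```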
RNNMVNGenerator (IProcessMVNGenerator)
Similar to DSAE's LearnableMultivariateNormalDiagCell
Source code in indl/dists/sequential.py
class RNNMVNGenerator(IProcessMVNGenerator):
"""
Similar to DSAE's LearnableMultivariateNormalDiagCell
"""
def __init__(self, units: int, out_dim: int, cell_type: str, shift_std: float = 0.1, offdiag: bool = False):
"""
Args:
units: Dimensionality of the RNN function parameters.
out_dim: The dimensionality of the distribution.
cell_type: an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.
shift_std: Shift applied to MVN std before building the dist. Providing a shift
toward the expected std allows the input values to be closer to 0.
offdiag: set True to allow non-zero covariance (within-timestep) in the returned distribution.
"""
self.cell = LearnableMultivariateNormalCell(units, out_dim, cell_type=cell_type,
shift_std=shift_std, offdiag=offdiag)
def get_dist(self, timesteps, samples=1, batch_size=1, fixed=True):
"""
Samples from self.cell `timesteps` times.
On each step, the previous (sample, state) is fed back into the cell
(zero_state used for 0th step).
The cell returns a multivariate normal diagonal distribution for each timestep.
We collect each timestep-dist's params (loc and scale), then use them to create
the return value: a single MVN diag dist that has a dimension for timesteps.
The cell returns a full dist for each timestep so that we can 'sample' it.
If our sample size is 1, and our cell is an RNN cell, then this is roughly equivalent
to doing a generative RNN (init state = zeros, return_sequences=True) then passing
those values through a pair of Dense layers to parameterize a single MVNDiag.
Args:
timesteps: Number of times to sample from the dynamic_prior_cell. Output will have
samples: Number of samples to draw from the latent distribution.
batch_size: Number of sequences to sample.
fixed: Boolean for whether or not to share the same random
sample across all sequences in batch.
https://github.com/tensorflow/probability/blob/698e0101aecf46c42858db7952ee3024e091c291/tensorflow_probability/examples/disentangled_vae.py#L887
Returns:
"""
if fixed:
sample_batch_size = 1
else:
sample_batch_size = batch_size
sample, state = self.cell.zero_state([samples, sample_batch_size])
locs = []
scales = []
sample_list = []
scale_parm_name = "scale_tril" if self.cell.offdiag else "scale_diag" # TODO: Check this for offdiag
for _ in range(timesteps):
dist, state = self.cell(sample, state)
sample = dist.sample()
locs.append(dist.parameters["loc"])
scales.append(dist.parameters[scale_parm_name])
sample_list.append(sample)
sample = tf.stack(sample_list, axis=2)
loc = tf.stack(locs, axis=2)
scale = tf.stack(scales, axis=2)
if fixed: # tile along the batch axis
sample = sample + tf.zeros([batch_size, 1, 1])
if self.cell.offdiag:
dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
else:
dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)
return sample, dist
__init__(self, units, out_dim, cell_type, shift_std=0.1, offdiag=False)
special
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `units` | `int` | Dimensionality of the RNN function parameters. | required |
| `out_dim` | `int` | The dimensionality of the distribution. | required |
| `cell_type` | `str` | an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive. | required |
| `shift_std` | `float` | Shift applied to MVN std before building the dist. Providing a shift toward the expected std allows the input values to be closer to 0. | `0.1` |
| `offdiag` | `bool` | set True to allow non-zero covariance (within-timestep) in the returned distribution. | `False` |
Source code in indl/dists/sequential.py
def __init__(self, units: int, out_dim: int, cell_type: str, shift_std: float = 0.1, offdiag: bool = False):
"""
Args:
units: Dimensionality of the RNN function parameters.
out_dim: The dimensionality of the distribution.
cell_type: an RNN cell type among 'lstm', 'gru', 'rnn', 'gruclip'. case-insensitive.
shift_std: Shift applied to MVN std before building the dist. Providing a shift
toward the expected std allows the input values to be closer to 0.
offdiag: set True to allow non-zero covariance (within-timestep) in the returned distribution.
"""
self.cell = LearnableMultivariateNormalCell(units, out_dim, cell_type=cell_type,
shift_std=shift_std, offdiag=offdiag)
get_dist(self, timesteps, samples=1, batch_size=1, fixed=True)
Samples from self.cell `timesteps` times.
On each step, the previous (sample, state) is fed back into the cell
(zero_state used for 0th step).
The cell returns a multivariate normal diagonal distribution for each timestep.
We collect each timestep-dist's params (loc and scale), then use them to create
the return value: a single MVN diag dist that has a dimension for timesteps.
The cell returns a full dist for each timestep so that we can 'sample' it.
If our sample size is 1, and our cell is an RNN cell, then this is roughly equivalent
to doing a generative RNN (init state = zeros, return_sequences=True) then passing
those values through a pair of Dense layers to parameterize a single MVNDiag.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `timesteps` | | Number of times to sample from the dynamic_prior_cell. | required |
| `samples` | | Number of samples to draw from the latent distribution. | `1` |
| `batch_size` | | Number of sequences to sample. | `1` |
| `fixed` | | Boolean for whether or not to share the same random sample across all sequences in batch. See https://github.com/tensorflow/probability/blob/698e0101aecf46c42858db7952ee3024e091c291/tensorflow_probability/examples/disentangled_vae.py#L887 | `True` |

Returns:
Source code in indl/dists/sequential.py
def get_dist(self, timesteps, samples=1, batch_size=1, fixed=True):
"""
Samples from self.cell `timesteps` times.
On each step, the previous (sample, state) is fed back into the cell
(zero_state used for 0th step).
The cell returns a multivariate normal diagonal distribution for each timestep.
We collect each timestep-dist's params (loc and scale), then use them to create
the return value: a single MVN diag dist that has a dimension for timesteps.
The cell returns a full dist for each timestep so that we can 'sample' it.
If our sample size is 1, and our cell is an RNN cell, then this is roughly equivalent
to doing a generative RNN (init state = zeros, return_sequences=True) then passing
those values through a pair of Dense layers to parameterize a single MVNDiag.
Args:
timesteps: Number of times to sample from the dynamic_prior_cell. Output will have
samples: Number of samples to draw from the latent distribution.
batch_size: Number of sequences to sample.
fixed: Boolean for whether or not to share the same random
sample across all sequences in batch.
https://github.com/tensorflow/probability/blob/698e0101aecf46c42858db7952ee3024e091c291/tensorflow_probability/examples/disentangled_vae.py#L887
Returns:
"""
if fixed:
sample_batch_size = 1
else:
sample_batch_size = batch_size
sample, state = self.cell.zero_state([samples, sample_batch_size])
locs = []
scales = []
sample_list = []
scale_parm_name = "scale_tril" if self.cell.offdiag else "scale_diag" # TODO: Check this for offdiag
for _ in range(timesteps):
dist, state = self.cell(sample, state)
sample = dist.sample()
locs.append(dist.parameters["loc"])
scales.append(dist.parameters[scale_parm_name])
sample_list.append(sample)
sample = tf.stack(sample_list, axis=2)
loc = tf.stack(locs, axis=2)
scale = tf.stack(scales, axis=2)
if fixed: # tile along the batch axis
sample = sample + tf.zeros([batch_size, 1, 1])
if self.cell.offdiag:
dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
else:
dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)
return sample, dist
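Example (illustrative sketch; sizes are arbitrary):

```python
from indl.dists.sequential import RNNMVNGenerator

gen = RNNMVNGenerator(units=16, out_dim=4, cell_type='gru', shift_std=0.1)
sample, prior = gen.get_dist(timesteps=50, samples=1, batch_size=8, fixed=True)
# sample: [samples, batch, timesteps, 4]; prior.log_prob(z) accepts a latent
# trajectory z of matching shape and reduces over the timestep axis.
```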
RNNMultivariateNormalDiag (MultivariateNormalDiag)
Source code in indl/dists/sequential.py
class RNNMultivariateNormalDiag(tfd.MultivariateNormalDiag):
def __init__(self, cell, n_timesteps=1, output_dim=None, name="rnn_mvn_diag", **kwargs):
self.cell = cell
if output_dim is not None and hasattr(self.cell, 'output_dim'):
self.cell.output_dim = output_dim
if hasattr(self.cell, 'output_dim'):
output_dim = self.cell.output_dim
else:
output_dim = output_dim or self.cell.units
h0 = tf.zeros([1, self.cell.units])
c0 = tf.zeros([1, self.cell.units])
input0 = tf.zeros((1, output_dim))
if hasattr(cell, 'reset_dropout_mask'):
self.cell.reset_dropout_mask()
self.cell.reset_recurrent_dropout_mask()
input_ = input0
states_ = (h0, c0)
successive_outputs = []
for i in range(n_timesteps):
input_, states_ = self.cell(input_, states_)
successive_outputs.append(input_)
loc = tf.concat([_.parameters["distribution"].parameters["loc"]
for _ in successive_outputs],
axis=0)
scale_diag = tf.concat([_.parameters["distribution"].parameters["scale_diag"]
for _ in successive_outputs],
axis=0)
super(RNNMultivariateNormalDiag, self).__init__(loc=loc, scale_diag=scale_diag, name=name, **kwargs)
cross_entropy(self, other, name='cross_entropy')
Computes the (Shannon) cross entropy.
Denote this distribution (`self`) by `P` and the `other` distribution by `Q`. Assuming `P, Q` are absolutely continuous with respect to one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, the (Shannon) cross entropy is defined as:
H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
where `F` denotes the support of the random variable `X ~ P`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `other` | | `tfp.distributions.Distribution` instance. | required |
| `name` | | Python `str` prepended to names of ops created by this function. | `'cross_entropy'` |

Returns:

| Type | Description |
|---|---|
| `cross_entropy` | `self.dtype` Tensor with shape [B1, ..., Bn] representing n different calculations of (Shannon) cross entropy. |
Source code in indl/dists/sequential.py
def cross_entropy(self, other, name='cross_entropy'):
"""Computes the (Shannon) cross entropy.
Denote this distribution (`self`) by `P` and the `other` distribution by
`Q`. Assuming `P, Q` are absolutely continuous with respect to
one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shannon)
cross entropy is defined as:
```none
H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
```
where `F` denotes the support of the random variable `X ~ P`.
Args:
other: `tfp.distributions.Distribution` instance.
name: Python `str` prepended to names of ops created by this function.
Returns:
cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
representing `n` different calculations of (Shannon) cross entropy.
"""
with self._name_and_control_scope(name):
return self._cross_entropy(other)
kl_divergence(self, other, name='kl_divergence')
Computes the Kullback--Leibler divergence.
Denote this distribution (`self`) by `p` and the `other` distribution by `q`. Assuming `p, q` are absolutely continuous with respect to reference measure `r`, the KL divergence is defined as:
KL[p, q] = E_p[log(p(X)/q(X))]
= -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
= H[p, q] - H[p]
where `F` denotes the support of the random variable `X ~ p`, `H[., .]` denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `other` | | `tfp.distributions.Distribution` instance. | required |
| `name` | | Python `str` prepended to names of ops created by this function. | `'kl_divergence'` |

Returns:

| Type | Description |
|---|---|
| `kl_divergence` | `self.dtype` Tensor with shape [B1, ..., Bn] representing n different calculations of the Kullback-Leibler divergence. |
Source code in indl/dists/sequential.py
def kl_divergence(self, other, name='kl_divergence'):
"""Computes the Kullback--Leibler divergence.
Denote this distribution (`self`) by `p` and the `other` distribution by
`q`. Assuming `p, q` are absolutely continuous with respect to reference
measure `r`, the KL divergence is defined as:
```none
KL[p, q] = E_p[log(p(X)/q(X))]
= -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
= H[p, q] - H[p]
```
where `F` denotes the support of the random variable `X ~ p`, `H[., .]`
denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.
Args:
other: `tfp.distributions.Distribution` instance.
name: Python `str` prepended to names of ops created by this function.
Returns:
kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
representing `n` different calculations of the Kullback-Leibler
divergence.
"""
# NOTE: We do not enter a `self._name_and_control_scope` here. We rely on
# `tfd.kl_divergence(self, other)` to use `_name_and_control_scope` to apply
# assertions on both Distributions.
#
# Subclasses that override `Distribution.kl_divergence` or `_kl_divergence`
# must ensure that assertions are applied for both `self` and `other`.
return self._kl_divergence(other)
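These two methods are inherited from `tfd.MultivariateNormalDiag`; a small standalone illustration of what they compute:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

p = tfd.MultivariateNormalDiag(loc=[0.0, 0.0], scale_diag=[1.0, 1.0])
q = tfd.MultivariateNormalDiag(loc=[0.5, -0.5], scale_diag=[1.5, 0.5])
kl_pq = p.kl_divergence(q)    # scalar Tensor, KL[p || q]
ce_pq = p.cross_entropy(q)    # H[p, q] = H[p] + KL[p || q]
```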
TiledMVNGenerator (IProcessMVNGenerator)
Similar to LFADS' LearnableDiagonalGaussian. Uses a single learnable loc and scale which are tiled across timesteps.
Source code in indl/dists/sequential.py
class TiledMVNGenerator(IProcessMVNGenerator):
"""
Similar to LFADS' LearnableDiagonalGaussian.
Uses a single learnable loc and scale which are tiled across timesteps.
"""
def __init__(self, latent_dim: int, init_std: float = 0.1,
trainable_mean: bool = True, trainable_var: bool = True,
offdiag: bool = False):
"""
Args:
latent_dim: Number of dimensions in a single timestep (params['f_latent_size'])
init_std: Initial value of standard deviation (params['q_z_init_std'])
trainable_mean: True if mean should be trainable (params['z_prior_train_mean'])
trainable_var: True if variance should be trainable (params['z_prior_train_var'])
offdiag: True if off-diagonal elements (non-orthogonality) allowed. (params['z_prior_off_diag'])
"""
self._offdiag = offdiag
self._loc, self._scale = make_learnable_mvn_params(latent_dim, init_std=init_std,
trainable_mean=trainable_mean,
trainable_var=trainable_var,
offdiag=offdiag)
def get_dist(self, timesteps, samples=1, batch_size=1):
"""
Tiles the saved loc and scale to the same shape as `posterior` then uses them to
create a MVN dist with appropriate shape. Each timestep has the same loc and
scale but if it were sampled then each timestep would return different values.
Args:
timesteps:
samples:
batch_size:
Returns:
MVNDiag distribution of the same shape as `posterior`
"""
loc = tf.tile(tf.expand_dims(self._loc, 0), [timesteps, 1])
scale = tf.expand_dims(self._scale, 0)
if self._offdiag:
scale = tf.tile(scale, [timesteps, 1, 1])
dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
else:
scale = tf.tile(scale, [timesteps, 1])
dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)
return dist.sample([samples, batch_size]), dist
__init__(self, latent_dim, init_std=0.1, trainable_mean=True, trainable_var=True, offdiag=False)
special
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `latent_dim` | `int` | Number of dimensions in a single timestep (params['f_latent_size']) | required |
| `init_std` | `float` | Initial value of standard deviation (params['q_z_init_std']) | `0.1` |
| `trainable_mean` | `bool` | True if mean should be trainable (params['z_prior_train_mean']) | `True` |
| `trainable_var` | `bool` | True if variance should be trainable (params['z_prior_train_var']) | `True` |
| `offdiag` | `bool` | True if off-diagonal elements (non-orthogonality) allowed (params['z_prior_off_diag']) | `False` |
Source code in indl/dists/sequential.py
def __init__(self, latent_dim: int, init_std: float = 0.1,
trainable_mean: bool = True, trainable_var: bool = True,
offdiag: bool = False):
"""
Args:
latent_dim: Number of dimensions in a single timestep (params['f_latent_size'])
init_std: Initial value of standard deviation (params['q_z_init_std'])
trainable_mean: True if mean should be trainable (params['z_prior_train_mean'])
trainable_var: True if variance should be trainable (params['z_prior_train_var'])
offdiag: True if off-diagonal elements (non-orthogonality) allowed. (params['z_prior_off_diag'])
"""
self._offdiag = offdiag
self._loc, self._scale = make_learnable_mvn_params(latent_dim, init_std=init_std,
trainable_mean=trainable_mean,
trainable_var=trainable_var,
offdiag=offdiag)
get_dist(self, timesteps, samples=1, batch_size=1)
Tiles the saved loc and scale to the same shape as `posterior`, then uses them to create an MVN dist with the appropriate shape. Each timestep has the same loc and scale, but if sampled, each timestep would return different values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `timesteps` | | Number of timesteps. | required |
| `samples` | | Number of samples to draw. | `1` |
| `batch_size` | | Number of sequences to sample. | `1` |

Returns:

| Type | Description |
|---|---|
| | MVNDiag distribution of the same shape as `posterior`. |
Source code in indl/dists/sequential.py
def get_dist(self, timesteps, samples=1, batch_size=1):
"""
Tiles the saved loc and scale to the same shape as `posterior` then uses them to
create a MVN dist with appropriate shape. Each timestep has the same loc and
scale but if it were sampled then each timestep would return different values.
Args:
timesteps:
samples:
batch_size:
Returns:
MVNDiag distribution of the same shape as `posterior`
"""
loc = tf.tile(tf.expand_dims(self._loc, 0), [timesteps, 1])
scale = tf.expand_dims(self._scale, 0)
if self._offdiag:
scale = tf.tile(scale, [timesteps, 1, 1])
dist = tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale)
else:
scale = tf.tile(scale, [timesteps, 1])
dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)
return dist.sample([samples, batch_size]), dist
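Example (illustrative sketch; sizes are arbitrary):

```python
from indl.dists.sequential import TiledMVNGenerator

gen = TiledMVNGenerator(latent_dim=4, init_std=0.1)
sample, prior = gen.get_dist(timesteps=50, samples=1, batch_size=8)
# Every timestep shares the same learnable loc/scale; prior.log_prob expects
# event shape [timesteps, latent_dim], e.g. prior.log_prob(sample) -> [1, 8].
```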
VariationalLSTMCell (LSTMCell)
Source code in indl/dists/sequential.py
class VariationalLSTMCell(tfkl.LSTMCell):
def __init__(self, units,
make_dist_fn=None,
make_dist_model=None,
**kwargs):
super(VariationalLSTMCell, self).__init__(units, **kwargs)
self.make_dist_fn = make_dist_fn
self.make_dist_model = make_dist_model
# For some reason the below code doesn't work during build.
# So I don't know how to use the outer VariationalRNN to set this cell's output_size
if self.make_dist_fn is None:
self.make_dist_fn = lambda t: tfd.MultivariateNormalDiag(loc=t[0], scale_diag=t[1])
if self.make_dist_model is None:
fake_cell_output = tfkl.Input((self.units,))
loc = tfkl.Dense(self.output_size, name="VarLSTMCell_loc")(fake_cell_output)
scale = tfkl.Dense(self.output_size, name="VarLSTMCell_scale")(fake_cell_output)
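# NOTE: `scale_shift` is not defined in this listing; presumably a module-level
# softplus-inverse shift constant (cf. `_scale_shift` elsewhere in indl.dists).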
scale = tf.nn.softplus(scale + scale_shift) + 1e-5
dist_layer = tfpl.DistributionLambda(
make_distribution_fn=self.make_dist_fn,
# TODO: convert_to_tensor_fn=lambda s: s.sample(N_SAMPLES)
)([loc, scale])
self.make_dist_model = tf.keras.Model(fake_cell_output, dist_layer)
def build(self, input_shape):
super(VariationalLSTMCell, self).build(input_shape)
# It would be good to defer making self.make_dist_model until here,
# but it doesn't work for some reason.
# def input_zero(self, inputs_):
# input0 = inputs_[..., -1, :]
# input0 = tf.matmul(input0, tf.zeros((input0.shape[-1], self.units)))
# dist0 = self.make_dist_model(input0)
# return dist0
def call(self, inputs, states, training=None):
inputs = tf.convert_to_tensor(inputs)
output, state = super(VariationalLSTMCell, self).call(inputs, states, training=training)
dist = self.make_dist_model(output)
return dist, state
build(self, input_shape)
Creates the variables of the layer (optional, for subclass implementers).
This is a method that implementers of subclasses of `Layer` or `Model` can override if they need a state-creation step in-between layer instantiation and layer call.
This is typically used to create the weights of `Layer` subclasses.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_shape` | | Instance of `TensorShape`, or list of instances of `TensorShape` if the layer expects a list of inputs (one instance per input). | required |
Source code in indl/dists/sequential.py
def build(self, input_shape):
super(VariationalLSTMCell, self).build(input_shape)
# It would be good to defer making self.make_dist_model until here,
# but it doesn't work for some reason.
call(self, inputs, states, training=None)
This is where the layer's logic lives.
Note that the `call()` method in `tf.keras` is a little bit different from the `keras` API. In the `keras` API you can pass support for masking to layers as additional arguments, whereas `tf.keras` has a `compute_mask()` method to support masking.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | | Input tensor, or list/tuple of input tensors. | required |
| `**kwargs` | | Additional keyword arguments. Currently unused. | required |

Returns:

| Type | Description |
|---|---|
| | A tensor or list/tuple of tensors. |
Source code in indl/dists/sequential.py
def call(self, inputs, states, training=None):
inputs = tf.convert_to_tensor(inputs)
output, state = super(VariationalLSTMCell, self).call(inputs, states, training=training)
dist = self.make_dist_model(output)
return dist, state
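Example (a hedged usage sketch; because the default distribution head in `__init__` references a `scale_shift` constant that is not shown in this listing, the sketch builds and passes its own `make_dist_model`; all sizes are arbitrary):

```python
import tensorflow as tf
import tensorflow_probability as tfp
from indl.dists.sequential import VariationalLSTMCell

tfd, tfpl, tfkl = tfp.distributions, tfp.layers, tf.keras.layers

# Build an explicit distribution head: cell output -> (loc, scale) -> MVN-diag.
units = 16
cell_out = tfkl.Input((units,))
loc = tfkl.Dense(units, name="loc")(cell_out)
scale = tfkl.Dense(units, activation="softplus", name="scale")(cell_out)
dist = tfpl.DistributionLambda(
    lambda t: tfd.MultivariateNormalDiag(loc=t[0], scale_diag=t[1]))([loc, scale])
dist_model = tf.keras.Model(cell_out, dist)

cell = VariationalLSTMCell(units, make_dist_model=dist_model)
x = tf.zeros([8, 20])                                   # one timestep, batch of 8
states = cell.get_initial_state(batch_size=8, dtype=tf.float32)
z_dist, states = cell(x, states)                        # per-step MVN-diag distribution
z_t = z_dist.sample()
```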