Implements a continuous normalizing flow X->Y defined via an ODE.
Inherits From: Bijector
tfp.bijectors.FFJORD( state_time_derivative_fn, ode_solve_fn=None, trace_augmentation_fn=trace_jacobian_hutchinson, initial_time=0.0, final_time=1.0, validate_args=False, dtype=tf.float32, name='ffjord' )
This bijector implements a continuous dynamics transformation parameterized by a differential equation, where initial and terminal conditions correspond to domain (X) and image (Y) i.e.
    d/dt[state(t)] = state_time_derivative_fn(t, state(t))
    state(initial_time) = X
    state(final_time) = Y
For this transformation, the value of log_det_jacobian follows another differential equation, which reduces its computation to integrating the trace of the Jacobian along the trajectory:
    state_time_derivative = state_time_derivative_fn(t, state(t))
    d/dt[log_det_jac(t)] = Tr(jacobian(state_time_derivative, state(t)))
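As a rough illustration of the two coupled equations above, the following pure-Python sketch integrates a state and its log-det-Jacobian together. The matrix A, the rk4_step helper, and the integrate function are assumptions made for this example and are not part of the FFJORD API. For linear dynamics the Jacobian trace is constant, so the result can be compared with Tr(A) * (final_time - initial_time).

```python
import math

def rk4_step(f, t, y, h):
    """One classical Runge-Kutta step for y' = f(t, y), with y a list of floats."""
    k1 = f(t, y)
    k2 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k1)])
    k3 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k2)])
    k4 = f(t + h, [yi + h * ki for yi, ki in zip(y, k3)])
    return [yi + h / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]

# Hypothetical linear dynamics d/dt[state] = A @ state, whose Jacobian is A.
A = [[-0.5, 1.0],
     [0.0, 0.2]]

def state_time_derivative_fn(t, state):
    return [sum(a * s for a, s in zip(row, state)) for row in A]

def augmented_fn(t, aug):
    # aug = [state_0, state_1, log_det_jac]
    dstate = state_time_derivative_fn(t, aug[:2])
    trace = A[0][0] + A[1][1]  # exact Jacobian trace for these linear dynamics
    return dstate + [trace]

def integrate(x, initial_time=0.0, final_time=1.0, num_steps=100):
    """Integrates state and log_det_jac from initial_time to final_time."""
    h = (final_time - initial_time) / num_steps
    aug = list(x) + [0.0]  # log_det_jac starts at zero
    t = initial_time
    for _ in range(num_steps):
        aug = rk4_step(augmented_fn, t, aug, h)
        t += h
    return aug[:2], aug[2]

y, log_det_jac = integrate([1.0, 2.0])
expected_log_det_jac = (A[0][0] + A[1][1]) * (1.0 - 0.0)
# Analytic check for the second coordinate: state_1(t) = 2 * exp(0.2 * t).
analytic_y1 = 2.0 * math.exp(0.2)
```

In FFJORD itself the trace is generally estimated rather than computed exactly, and the integration is delegated to ode_solve_fn.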
The FFJORD constructor takes two function arguments, ode_solve_fn and trace_augmentation_fn, that customize integration of the differential equation and trace estimation.
Differential equation integration is performed by a call to ode_solve_fn. A custom ode_solve_fn must accept the following arguments:
- ode_fn(time, state, **condition_kwargs): Differential equation to be solved. Custom ode_solve_fns may optionally support conditional inputs by accepting a constants dict arg and computing gradients wrt the provided values in **condition_kwargs.
- initial_time: Scalar float or floating Tensor representing the initial time.
- initial_state: Floating Tensor representing the initial state.
- solution_times: 1D floating Tensor of solution times.
It must return a Tensor of shape [solution_times.shape, initial_state.shape] representing state values evaluated at solution_times. In addition, ode_solve_fn must support nested structures. For more details see the interface of tfp.math.ode.Solver.solve().
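The argument order and return layout described above can be sketched with a small fixed-step solver. This is only an illustration under simplifying assumptions: it works on plain Python lists rather than Tensors, does not handle nested structures or condition_kwargs, and the steps_per_interval parameter is hypothetical, so it is not a drop-in ode_solve_fn for FFJORD.

```python
import math

def ode_solve_fn(ode_fn, initial_time, initial_state, solution_times,
                 steps_per_interval=200):
    """Fixed-step midpoint solver mirroring the argument order described above.

    Returns one state per entry of solution_times, i.e. a nested list of shape
    [len(solution_times), len(initial_state)].
    """
    states = []
    t, state = float(initial_time), list(initial_state)
    for t_end in solution_times:
        h = (t_end - t) / steps_per_interval
        for _ in range(steps_per_interval):
            k1 = ode_fn(t, state)
            mid = [s + 0.5 * h * k for s, k in zip(state, k1)]
            k2 = ode_fn(t + 0.5 * h, mid)
            state = [s + h * k for s, k in zip(state, k2)]
            t += h
        t = float(t_end)  # avoid drift from repeated floating-point addition
        states.append(list(state))
    return states

# Usage: dy/dt = -y has the exact solution y(t) = y(0) * exp(-t).
states = ode_solve_fn(lambda t, y: [-v for v in y], 0.0, [1.0, 2.0], [0.5, 1.0])
exact = [[math.exp(-t), 2.0 * math.exp(-t)] for t in (0.5, 1.0)]
```

A production solver would typically use adaptive step sizes, as the default DormandPrince solver does.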
Trace estimation is computed simultaneously with state_time_derivative using augmented_state_time_derivative_fn, which is generated by trace_augmentation_fn. trace_augmentation_fn takes state_time_derivative_fn, state.shape and state.dtype arguments and returns an augmented_state_time_derivative_fn callable that computes both state_time_derivative and unreduced trace_estimation.
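To make the augmentation idea concrete, here is a minimal pure-Python sketch of a Hutchinson-style trace estimator wrapped in the (ode_fn, state_shape, state_dtype) calling convention. Several choices are assumptions for illustration only: Rademacher (+1/-1) probe vectors, a central finite difference in place of automatic differentiation, the num_samples, seed and eps parameters, and a trace that is already reduced to a scalar (the text above describes an unreduced estimate).

```python
import random

def hutchinson_trace_augmentation_fn(ode_fn, state_shape, state_dtype,
                                     num_samples=64, seed=0, eps=1e-4):
    """Returns augmented_ode_fn(time, state) -> (derivative, trace_estimate).

    state_dtype is accepted only to mirror the interface; it is unused here.
    """
    dim = state_shape[0]
    rng = random.Random(seed)
    # Probes are drawn once, so the augmented function is deterministic.
    probes = [[rng.choice((-1.0, 1.0)) for _ in range(dim)]
              for _ in range(num_samples)]

    def augmented_ode_fn(time, state):
        derivative = ode_fn(time, state)
        total = 0.0
        for v in probes:
            # Jacobian-vector product J @ v via a central finite difference.
            plus = ode_fn(time, [s + eps * vi for s, vi in zip(state, v)])
            minus = ode_fn(time, [s - eps * vi for s, vi in zip(state, v)])
            jvp = [(p - m) / (2 * eps) for p, m in zip(plus, minus)]
            # v^T J v is an unbiased estimate of Tr(J) when E[v v^T] = I.
            total += sum(vi * ji for vi, ji in zip(v, jvp))
        return derivative, total / num_samples

    return augmented_ode_fn

# Usage with hypothetical diagonal dynamics; the Jacobian trace is 2.7.
diag_fn = lambda t, s: [-0.5 * s[0], 0.2 * s[1], 3.0 * s[2]]
augmented = hutchinson_trace_augmentation_fn(diag_fn, [3], float, num_samples=4)
derivative, trace_estimate = augmented(0.0, [1.0, 2.0, 3.0])
```

With a diagonal Jacobian every Rademacher probe recovers the trace exactly; for general Jacobians the estimate is noisy and improves as num_samples grows.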
Custom ode_solve_fn and trace_augmentation_fn examples:

    # custom_solver_fn: `callable(f, t_initial, t_solutions, y_initial, ...)`
    # custom_solver_kwargs: Additional arguments to pass to custom_solver_fn.
    def ode_solve_fn(ode_fn, initial_time, initial_state, solution_times):
        results = custom_solver_fn(ode_fn, initial_time, solution_times,
                                   initial_state, **custom_solver_kwargs)
        return results
    ffjord = tfb.FFJORD(state_time_derivative_fn, ode_solve_fn=ode_solve_fn)

    # state_time_derivative_fn: `callable(time, state)`
    # trace_jac_fn: `callable(time, state)` unreduced jacobian trace function
    def trace_augmentation_fn(ode_fn, state_shape, state_dtype):
        def augmented_ode_fn(time, state):
            return ode_fn(time, state), trace_jac_fn(time, state)
        return augmented_ode_fn
    ffjord = tfb.FFJORD(state_time_derivative_fn,
                        trace_augmentation_fn=trace_augmentation_fn)
For more details on FFJORD and continuous normalizing flows see [1], [2].
Usage example:
    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions
    tfb = tfp.bijectors

    # state_time_derivative_fn: `Callable(time, state)` -> state_time_derivative
    # e.g. Neural network with inputs and outputs of the same shapes and dtypes.
    bijector = tfb.FFJORD(state_time_derivative_fn=state_time_derivative_fn)
    y = bijector.forward(x)  # forward mapping
    x = bijector.inverse(y)  # inverse mapping
    base = tfd.Normal(tf.zeros_like(x), tf.ones_like(x))  # Base distribution
    transformed_distribution = tfd.TransformedDistribution(base, bijector)
References
[1]: Chen, T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural ordinary differential equations. In Advances in neural information processing systems (pp. 6571-6583)
[2]: Grathwohl, W., Chen, R. T., Bettencourt, J., Sutskever, I., & Duvenaud, D. (2018). FFJORD: Free-form continuous dynamics for scalable reversible generative models. arXiv preprint arXiv:1810.01367. http://arxiv.org/abs/1810.01367
Args | |
---|---|
state_time_derivative_fn | Python callable taking arguments time (a scalar representing time) and state (a Tensor representing the state at given time ) returning the time derivative of the state at given time . |
ode_solve_fn | Python callable taking arguments ode_fn (same as state_time_derivative_fn above), initial_time (a scalar representing the initial time of integration), initial_state (a Tensor of floating dtype representing the initial state) and solution_times (1D Tensor of floating dtype representing the times at which to obtain the solution) returning a Tensor of shape [time_axis, initial_state.shape]. Will take [final_time] as the solution_times argument and state_time_derivative_fn as ode_fn argument. For details on providing custom ode_solve_fn see class docstring. If None a DormandPrince solver from tfp.math.ode is used. Default value: None |
trace_augmentation_fn | Python callable taking arguments ode_fn (Python callable, same as state_time_derivative_fn above), state_shape (TensorShape of the state), dtype (same as dtype of the state) and returning a Python callable taking arguments time (a scalar representing the time at which the function is evaluated), state (a Tensor representing the state at given time) that computes a tuple (ode_fn(time, state), jacobian_trace_estimation). jacobian_trace_estimation should represent the trace of the Jacobian of ode_fn with respect to state. state_time_derivative_fn will be passed as ode_fn argument. For details on providing custom trace_augmentation_fn see class docstring. Default value: tfp.bijectors.ffjord.trace_jacobian_hutchinson |
initial_time | Scalar float representing the time to which the x value of the bijector corresponds. Passed as initial_time to ode_solve_fn. For the default solver, can be a Python float or floating scalar Tensor. Default value: 0. |
final_time | Scalar float representing the time to which the y value of the bijector corresponds. Passed as solution_times to ode_solve_fn. For the default solver, can be a Python float or floating scalar Tensor. Default value: 1. |
validate_args | Python 'bool' indicating whether to validate input. Default value: False |
dtype | tf.DType to prefer when converting args to Tensor s. Else, we fall back to a common dtype inferred from the args, finally falling back to float32. |
name | Python str name prefixed to Ops created by this function. |
Attributes | |
---|---|
dtype | |
forward_min_event_ndims | Returns the minimal number of dimensions bijector.forward operates on. Multipart bijectors return structured |
graph_parents | Returns this Bijector 's graph_parents as a Python list. |
inverse_min_event_ndims | Returns the minimal number of dimensions bijector.inverse operates on. Multipart bijectors return structured |
is_constant_jacobian | Returns true iff the Jacobian matrix is not a function of x. |
name | Returns the string name of this Bijector . |
name_scope | Returns a tf.name_scope instance for this class. |
non_trainable_variables | Sequence of non-trainable variables owned by this module and its submodules. |
parameters | Dictionary of parameters used to instantiate this Bijector . |
submodules | Sequence of all sub-modules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on). |
trainable_variables | Sequence of trainable variables owned by this module and its submodules. |
validate_args | Returns True if Tensor arguments will be validated. |
variables | Sequence of variables owned by this module and its submodules. |
Methods
copy
copy( **override_parameters_kwargs )
Creates a copy of the bijector.
Args | |
---|---|
**override_parameters_kwargs | String/value dictionary of initialization arguments to override with new values. |
Returns | |
---|---|
bijector | A new instance of type(self) initialized from the union of self.parameters and override_parameters_kwargs, i.e., dict(self.parameters, **override_parameters_kwargs) . |
experimental_batch_shape
experimental_batch_shape( x_event_ndims=None, y_event_ndims=None )
Returns the batch shape of this bijector for inputs of the given rank.
The batch shape of a bijector describes the set of distinct transformations it represents on events of a given size. For example: the bijector tfb.Scale([1., 2.]) has batch shape [2] for scalar events (event_ndims = 0), because applying it to a scalar event produces two scalar outputs, the result of two different scaling transformations. The same bijector has batch shape [] for vector events, because applying it to a vector produces (via elementwise multiplication) a single vector output.
Bijectors that operate independently on multiple state parts, such as tfb.JointMap, must broadcast to a coherent batch shape. Some events may not be valid: for example, the bijector tfb.JointMap([tfb.Scale([1., 2.]), tfb.Scale([1., 2., 3.])]) does not produce a valid batch shape when event_ndims = [0, 0], since the batch shapes of the two parts are inconsistent. The same bijector does define valid batch shapes of [], [2], and [3] if event_ndims is [1, 1], [0, 1], or [1, 0], respectively.
Since transforming a single event produces a scalar log-det-Jacobian, the batch shape of a bijector with non-constant Jacobian is expected to equal the shape of forward_log_det_jacobian(x, event_ndims=x_event_ndims) or inverse_log_det_jacobian(y, event_ndims=y_event_ndims), for x or y of the specified ndims.
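The shape arithmetic in the examples above can be mimicked with a few lines of plain Python. The helpers below (broadcast_shapes, scale_batch_shape, joint_batch_shape) are hypothetical names for this sketch and only model an elementwise scale whose parameter shape is given as a list; they do not call the library.

```python
def broadcast_shapes(a, b):
    """NumPy-style broadcast of two shapes given as lists of ints."""
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"Incompatible batch shapes: {a} vs {b}")
        result.append(max(da, db))
    return result[::-1]

def scale_batch_shape(scale_shape, event_ndims):
    """Parameter dimensions left over after the event dimensions are removed."""
    return scale_shape[:max(len(scale_shape) - event_ndims, 0)]

def joint_batch_shape(scale_shapes, event_ndims):
    """Broadcasts the batch shapes of several independent scale parts."""
    shape = []
    for part_shape, part_ndims in zip(scale_shapes, event_ndims):
        shape = broadcast_shapes(shape, scale_batch_shape(part_shape, part_ndims))
    return shape

# Scale([1., 2.]) has parameter shape [2].
scalar_event_batch = scale_batch_shape([2], event_ndims=0)  # [2]
vector_event_batch = scale_batch_shape([2], event_ndims=1)  # []

# Two parts with parameter shapes [2] and [3], as in the JointMap example.
mixed = joint_batch_shape([[2], [3]], [0, 1])               # [2]
```

Calling joint_batch_shape([[2], [3]], [0, 0]) raises ValueError in this sketch, mirroring the inconsistent case described above.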
Args | |
---|---|
x_event_ndims | Optional Python int (structure) number of dimensions in a probabilistic event passed to forward ; this must be greater than or equal to self.forward_min_event_ndims . If None , defaults to self.forward_min_event_ndims . Mutually exclusive with y_event_ndims . Default value: None . |
y_event_ndims | Optional Python int (structure) number of dimensions in a probabilistic event passed to inverse ; this must be greater than or equal to self.inverse_min_event_ndims . Mutually exclusive with x_event_ndims . Default value: None . |
Returns | |
---|---|
batch_shape | TensorShape batch shape of this bijector for a value with the given event rank. May be unknown or partially defined. |
experimental_batch_shape_tensor
experimental_batch_shape_tensor( x_event_ndims=None, y_event_ndims=None )
Returns the batch shape of this bijector for inputs of the given rank.
The batch shape of a bijector describes the set of distinct transformations it represents on events of a given size. For example: the bijector tfb.Scale([1., 2.]) has batch shape [2] for scalar events (event_ndims = 0), because applying it to a scalar event produces two scalar outputs, the result of two different scaling transformations. The same bijector has batch shape [] for vector events, because applying it to a vector produces (via elementwise multiplication) a single vector output.
Bijectors that operate independently on multiple state parts, such as tfb.JointMap, must broadcast to a coherent batch shape. Some events may not be valid: for example, the bijector tfb.JointMap([tfb.Scale([1., 2.]), tfb.Scale([1., 2., 3.])]) does not produce a valid batch shape when event_ndims = [0, 0], since the batch shapes of the two parts are inconsistent. The same bijector does define valid batch shapes of [], [2], and [3] if event_ndims is [1, 1], [0, 1], or [1, 0], respectively.
Since transforming a single event produces a scalar log-det-Jacobian, the batch shape of a bijector with non-constant Jacobian is expected to equal the shape of forward_log_det_jacobian(x, event_ndims=x_event_ndims) or inverse_log_det_jacobian(y, event_ndims=y_event_ndims), for x or y of the specified ndims.
Args | |
---|---|
x_event_ndims | Optional Python int (structure) number of dimensions in a probabilistic event passed to forward ; this must be greater than or equal to self.forward_min_event_ndims . If None , defaults to self.forward_min_event_ndims . Mutually exclusive with y_event_ndims . Default value: None . |
y_event_ndims | Optional Python int (structure) number of dimensions in a probabilistic event passed to inverse ; this must be greater than or equal to self.inverse_min_event_ndims . Mutually exclusive with x_event_ndims . Default value: None . |
Returns | |
---|---|
batch_shape_tensor | integer Tensor batch shape of this bijector for a value with the given event rank. |
experimental_compute_density_correction
experimental_compute_density_correction( x, tangent_space, backward_compat=False, **kwargs )
Density correction for this transformation wrt the tangent space, at x.
Subclasses of Bijector may call the most specific applicable method of TangentSpace, based on whether the transformation is dimension-preserving, coordinate-wise, a projection, or something more general. The backward-compatible assumption is that the transformation is dimension-preserving (goes from R^n to R^n).
Args | |
---|---|
x | Tensor (structure). The point at which to calculate the density. |
tangent_space | TangentSpace or one of its subclasses. The tangent to the support manifold at x . |
backward_compat | bool specifying whether to assume that the Bijector is dimension-preserving. |
**kwargs | Optional keyword arguments forwarded to tangent space methods. |
Returns | |
---|---|
density_correction | Tensor representing the density correction---in log space---under the transformation that this Bijector denotes. |
Raises | |
---|---|
TypeError | if backward_compat is False but no method of TangentSpace has been called explicitly. |
forward
forward( x, name='forward', **kwargs )
Returns the forward Bijector evaluation, i.e., Y = g(X).
Args | |
---|---|
x | Tensor (structure). The input to the 'forward' evaluation. |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
Tensor (structure). |
Raises | |
---|---|
TypeError | if self.dtype is specified and x.dtype is not self.dtype . |
NotImplementedError | if _forward is not implemented. |
forward_dtype
forward_dtype( dtype=UNSPECIFIED, name='forward_dtype', **kwargs )
Returns the dtype returned by forward for the provided input.
forward_event_ndims
forward_event_ndims( event_ndims, **kwargs )
Returns the number of event dimensions produced by forward.
Args | |
---|---|
event_ndims | Structure of Python and/or Tensor int s, and/or None values. The structure should match that of self.forward_min_event_ndims , and all non-None values must be greater than or equal to the corresponding value in self.forward_min_event_ndims . |
**kwargs | Optional keyword arguments forwarded to nested bijectors. |
Returns | |
---|---|
forward_event_ndims | Structure of integers and/or None values matching self.inverse_min_event_ndims . These are computed using 'prefer static' semantics: if any inputs are None , some or all of the outputs may be None , indicating that the output dimension could not be inferred (conversely, if all inputs are non-None , all outputs will be non-None ). If all input event_ndims are Python int s, all of the (non-None ) outputs will be Python int s; otherwise, some or all of the outputs may be Tensor int s. |
forward_event_shape
forward_event_shape( input_shape )
Shape of a single sample from a single batch as a TensorShape.
Same meaning as forward_event_shape_tensor. May be only partially defined.
Args | |
---|---|
input_shape | TensorShape (structure) indicating event-portion shape passed into forward function. |
Returns | |
---|---|
forward_event_shape_tensor | TensorShape (structure) indicating event-portion shape after applying forward . Possibly unknown. |
forward_event_shape_tensor
forward_event_shape_tensor( input_shape, name='forward_event_shape_tensor' )
Shape of a single sample from a single batch as an int32 1D Tensor.
Args | |
---|---|
input_shape | Tensor , int32 vector (structure) indicating event-portion shape passed into forward function. |
name | name to give to the op |
Returns | |
---|---|
forward_event_shape_tensor | Tensor , int32 vector (structure) indicating event-portion shape after applying forward . |
forward_log_det_jacobian
forward_log_det_jacobian( x, event_ndims=None, name='forward_log_det_jacobian', **kwargs )
Returns the forward_log_det_jacobian.
Args | |
---|---|
x | Tensor (structure). The input to the 'forward' Jacobian determinant evaluation. |
event_ndims | Optional number of dimensions in the probabilistic events being transformed; this must be greater than or equal to self.forward_min_event_ndims . If event_ndims is specified, the log Jacobian determinant is summed to produce a scalar log-determinant for each event. Otherwise (if event_ndims is None ), no reduction is performed. Multipart bijectors require structured event_ndims, such that the batch rank rank(y[i]) - event_ndims[i] is the same for all elements i of the structured input. In most cases (with the exception of tfb.JointMap ) they further require that event_ndims[i] - self.inverse_min_event_ndims[i] is the same for all elements i of the structured input. Default value: None (equivalent to self.forward_min_event_ndims ). |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
Tensor (structure), if this bijector is injective. If not injective this is not implemented. |
Raises | |
---|---|
TypeError | if y 's dtype is incompatible with the expected output dtype. |
NotImplementedError | if neither _forward_log_det_jacobian nor {_inverse , _inverse_log_det_jacobian } are implemented, or this is a non-injective bijector. |
ValueError | if the value of event_ndims is not valid for this bijector. |
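The effect of event_ndims on the reduction can be illustrated with a plain-Python stand-in for an elementwise bijector. The function below is hypothetical and models y = exp(x), for which log|dy/dx| = x; it handles only a batch of vectors and event_ndims of 0 or 1.

```python
def elementwise_exp_fldj(x_batch, event_ndims):
    """Forward log-det-Jacobian of y = exp(x) for a batch of vectors.

    event_ndims=0 treats every scalar as its own event (no reduction);
    event_ndims=1 treats each vector as one event and sums over its entries.
    """
    if event_ndims == 0:
        return [list(row) for row in x_batch]
    if event_ndims == 1:
        return [sum(row) for row in x_batch]
    raise ValueError("this sketch supports event_ndims of 0 or 1 only")

batch = [[0.1, 0.2], [1.0, 2.0]]
per_scalar = elementwise_exp_fldj(batch, event_ndims=0)  # shape [2, 2]
per_vector = elementwise_exp_fldj(batch, event_ndims=1)  # shape [2]
```

Each extra event dimension removes one trailing axis from the result, leaving one log-determinant per event.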
inverse
inverse( y, name='inverse', **kwargs )
Returns the inverse Bijector evaluation, i.e., X = g^{-1}(Y).
Args | |
---|---|
y | Tensor (structure). The input to the 'inverse' evaluation. |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
Tensor (structure), if this bijector is injective. If not injective, returns the k-tuple containing the unique k points (x1, ..., xk) such that g(xi) = y . |
Raises | |
---|---|
TypeError | if y 's structured dtype is incompatible with the expected output dtype. |
NotImplementedError | if _inverse is not implemented. |
inverse_dtype
inverse_dtype( dtype=UNSPECIFIED, name='inverse_dtype', **kwargs )
Returns the dtype returned by inverse for the provided input.
inverse_event_ndims
inverse_event_ndims( event_ndims, **kwargs )
Returns the number of event dimensions produced by inverse.
Args | |
---|---|
event_ndims | Structure of Python and/or Tensor int s, and/or None values. The structure should match that of self.inverse_min_event_ndims , and all non-None values must be greater than or equal to the corresponding value in self.inverse_min_event_ndims . |
**kwargs | Optional keyword arguments forwarded to nested bijectors. |
Returns | |
---|---|
inverse_event_ndims | Structure of integers and/or None values matching self.forward_min_event_ndims . These are computed using 'prefer static' semantics: if any inputs are None , some or all of the outputs may be None , indicating that the output dimension could not be inferred (conversely, if all inputs are non-None , all outputs will be non-None ). If all input event_ndims are Python int s, all of the (non-None ) outputs will be Python int s; otherwise, some or all of the outputs may be Tensor int s. |
inverse_event_shape
inverse_event_shape( output_shape )
Shape of a single sample from a single batch as a TensorShape.
Same meaning as inverse_event_shape_tensor. May be only partially defined.
Args | |
---|---|
output_shape | TensorShape (structure) indicating event-portion shape passed into inverse function. |
Returns | |
---|---|
inverse_event_shape_tensor | TensorShape (structure) indicating event-portion shape after applying inverse . Possibly unknown. |
inverse_event_shape_tensor
inverse_event_shape_tensor( output_shape, name='inverse_event_shape_tensor' )
Shape of a single sample from a single batch as an int32 1D Tensor.
Args | |
---|---|
output_shape | Tensor , int32 vector (structure) indicating event-portion shape passed into inverse function. |
name | name to give to the op |
Returns | |
---|---|
inverse_event_shape_tensor | Tensor , int32 vector (structure) indicating event-portion shape after applying inverse . |
inverse_log_det_jacobian
inverse_log_det_jacobian( y, event_ndims=None, name='inverse_log_det_jacobian', **kwargs )
Returns the (log o det o Jacobian o inverse)(y).
Mathematically, returns: log(det(dX/dY))(Y). (Recall that: X = g^{-1}(Y).)
Note that forward_log_det_jacobian is the negative of this function, evaluated at g^{-1}(y).
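The sign relationship noted above can be checked numerically on a simple scalar map. The sketch below uses y = exp(x) as a stand-in bijector written in plain Python; the function names mirror the method names for readability but are local definitions, not library calls.

```python
import math

def forward(x):
    return math.exp(x)            # y = g(x)

def inverse(y):
    return math.log(y)            # x = g^{-1}(y)

def forward_log_det_jacobian(x):
    return x                      # log|dy/dx| = log(exp(x)) = x

def inverse_log_det_jacobian(y):
    return -math.log(y)           # log|dx/dy| = log(1 / y)

x = 0.7
y = forward(x)
fldj_at_inverse = forward_log_det_jacobian(inverse(y))
ildj = inverse_log_det_jacobian(y)
# fldj_at_inverse and ildj should sum to (approximately) zero.
residual = fldj_at_inverse + ildj
```

The same identity is what lets a bijector derive one log-det-Jacobian from the other when only one of them is implemented.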
Args | |
---|---|
y | Tensor (structure). The input to the 'inverse' Jacobian determinant evaluation. |
event_ndims | Optional number of dimensions in the probabilistic events being transformed; this must be greater than or equal to self.inverse_min_event_ndims . If event_ndims is specified, the log Jacobian determinant is summed to produce a scalar log-determinant for each event. Otherwise (if event_ndims is None ), no reduction is performed. Multipart bijectors require structured event_ndims, such that the batch rank rank(y[i]) - event_ndims[i] is the same for all elements i of the structured input. In most cases (with the exception of tfb.JointMap ) they further require that event_ndims[i] - self.inverse_min_event_ndims[i] is the same for all elements i of the structured input. Default value: None (equivalent to self.inverse_min_event_ndims ). |
name | The name to give this op. |
**kwargs | Named arguments forwarded to subclass implementation. |
Returns | |
---|---|
ildj | Tensor , if this bijector is injective. If not injective, returns the tuple of local log det Jacobians, log(det(Dg_i^{-1}(y))) , where g_i is the restriction of g to the ith partition Di . |
Raises | |
---|---|
TypeError | if x 's dtype is incompatible with the expected inverse-dtype. |
NotImplementedError | if _inverse_log_det_jacobian is not implemented. |
ValueError | if the value of event_ndims is not valid for this bijector. |
parameter_properties
@classmethod
parameter_properties( dtype=tf.float32 )
Returns a dict mapping constructor arg names to property annotations.
This dict should include an entry for each of the bijector's Tensor-valued constructor arguments.
Args | |
---|---|
dtype | Optional float dtype to assume for continuous-valued parameters. Some constraining bijectors require advance knowledge of the dtype because certain constants (e.g., tfb.Softplus.low ) must be instantiated with the same dtype as the values to be transformed. |
Returns | |
---|---|
parameter_properties | A str -> tfp.python.internal.parameter_properties.ParameterProperties dict mapping constructor argument names to ParameterProperties instances. |
with_name_scope
@classmethod
with_name_scope( method )
Decorator to automatically enter the module name scope.
    class MyModule(tf.Module):
        @tf.Module.with_name_scope
        def __call__(self, x):
            if not hasattr(self, 'w'):
                self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
            return tf.matmul(x, self.w)

Using the above module would produce tf.Variables and tf.Tensors whose names included the module name:

    mod = MyModule()
    mod(tf.ones([1, 2]))
    <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
    mod.w
    <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
    numpy=..., dtype=float32)>
Args | |
---|---|
method | The method to wrap. |
Returns | |
---|---|
The original method wrapped such that it enters the module's name scope. |
__call__
__call__( value, name=None, **kwargs )
Applies or composes the Bijector, depending on input type.
This is a convenience function which applies the Bijector instance in three different ways, depending on the input:
- If the input is a tfd.Distribution instance, return tfd.TransformedDistribution(distribution=input, bijector=self).
- If the input is a tfb.Bijector instance, return tfb.Chain([self, input]).
- Otherwise, return self.forward(input).
Args | |
---|---|
value | A tfd.Distribution , tfb.Bijector , or a (structure of) Tensor . |
name | Python str name given to ops created by this function. |
**kwargs | Additional keyword arguments passed into the created tfd.TransformedDistribution , tfb.Bijector , or self.forward . |
Returns | |
---|---|
composition | A tfd.TransformedDistribution if the input was a tfd.Distribution , a tfb.Chain if the input was a tfb.Bijector , or a (structure of) Tensor computed by self.forward . |
Examples
    sigmoid = tfb.Reciprocal()(
        tfb.Shift(shift=1.)(
            tfb.Exp()(
                tfb.Scale(scale=-1.))))
    # ==> `tfb.Chain([
    #         tfb.Reciprocal(),
    #         tfb.Shift(shift=1.),
    #         tfb.Exp(),
    #         tfb.Scale(scale=-1.),
    #      ])`  # ie, `tfb.Sigmoid()`

    log_normal = tfb.Exp()(tfd.Normal(0, 1))
    # ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`

    tfb.Exp()([-1., 0., 1.])
    # ==> tf.exp([-1., 0., 1.])
__eq__
__eq__( other )
Return self==value.
__getitem__
__getitem__( slices )
__iter__
__iter__()