Runs one step of the Metropolis-Hastings algorithm.
Inherits From: TransitionKernel
tfp.substrates.jax.mcmc.MetropolisHastings(
inner_kernel, name=None
)
The Metropolis-Hastings algorithm is a Markov chain Monte Carlo (MCMC) technique which uses a proposal distribution to eventually sample from a target distribution.

Note: inner_kernel.one_step must return kernel_results as a collections.namedtuple which must:

- have a target_log_prob field,
- optionally have a log_acceptance_correction field, and
- have only fields which are Tensor-valued.
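
For illustration, here is a minimal sketch of inner-kernel results that meet these requirements. It uses collections.namedtuple with Python floats standing in for Tensors. ToyKernelResults and check_inner_results are hypothetical names, not TFP API, and TFP's own validation may differ.

```python
import collections
import numbers

# Hypothetical results type for an inner kernel; floats stand in for Tensors.
ToyKernelResults = collections.namedtuple(
    "ToyKernelResults", ["target_log_prob", "log_acceptance_correction"])


def check_inner_results(results):
    """Checks the requirements listed above (hypothetical helper, not TFP API)."""
    # Results must be a namedtuple, i.e. a tuple with a `_fields` attribute.
    if not (isinstance(results, tuple) and hasattr(results, "_fields")):
        raise TypeError("inner kernel results must be a namedtuple")
    # `target_log_prob` is mandatory; `log_acceptance_correction` is optional.
    if "target_log_prob" not in results._fields:
        raise ValueError('results must contain the member "target_log_prob"')
    # Every field must be "Tensor-valued"; in this sketch, a real number.
    for name in results._fields:
        if not isinstance(getattr(results, name), numbers.Real):
            raise TypeError(f"field {name!r} is not numeric")
    return True


ok = ToyKernelResults(target_log_prob=-1.25, log_acceptance_correction=0.0)
print(check_inner_results(ok))  # True
```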
The Metropolis-Hastings log acceptance ratio is computed as:

    log_accept_ratio = (current_kernel_results.target_log_prob
                        - previous_kernel_results.target_log_prob
                        + current_kernel_results.log_acceptance_correction)

If current_kernel_results.log_acceptance_correction does not exist, it is presumed to be 0 (i.e., the proposal distribution is assumed to be symmetric).
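
The formula can be sketched in plain Python, assuming the results are namedtuples of floats. The results types below are hypothetical. getattr with a default of 0.0 models the "presumed to be 0" rule for kernels that carry no correction field.

```python
import collections

# Hypothetical results types; floats stand in for Tensors.
SymmetricResults = collections.namedtuple("SymmetricResults", ["target_log_prob"])
CorrectedResults = collections.namedtuple(
    "CorrectedResults", ["target_log_prob", "log_acceptance_correction"])


def log_accept_ratio(current_kernel_results, previous_kernel_results):
    """Mirrors the formula above; a missing correction is treated as 0."""
    correction = getattr(current_kernel_results, "log_acceptance_correction", 0.0)
    return (current_kernel_results.target_log_prob
            - previous_kernel_results.target_log_prob
            + correction)


prev = SymmetricResults(target_log_prob=-2.0)
print(log_accept_ratio(SymmetricResults(target_log_prob=-1.5), prev))  # 0.5
print(log_accept_ratio(CorrectedResults(-1.5, -0.25), prev))  # 0.25
```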
The most common use of log_acceptance_correction is in the Metropolis-Hastings algorithm with an asymmetric proposal, where

    accept_prob(x' | x) = min(1, p(x') / p(x) * (g(x | x') / g(x' | x)))

and

- p is the target distribution,
- g is the proposal (conditional) distribution,
- x' is the proposed state, and
- x is the current state.

The log of the parenthetical term, log g(x | x') - log g(x' | x), is the log_acceptance_correction.
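
To make the correction concrete, here is a sketch with two illustrative Gaussian proposals. These are assumptions for the example, not anything TFP prescribes. The first is a symmetric random walk, whose correction is exactly 0. The second drifts right by 0.5, so its correction is nonzero. In this sketch, proposal_log_prob(a, b) is a hypothetical helper meaning log g(a | b).

```python
import math


def normal_log_pdf(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def log_acceptance_correction(x, x_proposed, proposal_log_prob):
    """log g(x | x') - log g(x' | x), where proposal_log_prob(a, b) = log g(a | b)."""
    return proposal_log_prob(x, x_proposed) - proposal_log_prob(x_proposed, x)


# Symmetric random-walk proposal: x' ~ Normal(x, 1); the correction vanishes.
def random_walk_log_prob(a, b):
    return normal_log_pdf(a, loc=b, scale=1.0)


# Asymmetric "drifting" proposal: x' ~ Normal(x + 0.5, 1) (illustrative choice).
def drift_proposal_log_prob(a, b):
    return normal_log_pdf(a, loc=b + 0.5, scale=1.0)


x, x_new = 0.0, 1.0
print(log_acceptance_correction(x, x_new, random_walk_log_prob))     # 0.0
print(log_acceptance_correction(x, x_new, drift_proposal_log_prob))  # approximately -1.0
```

The drift makes the move from 0 to 1 easier to propose than the reverse move, so the correction is negative and lowers the acceptance probability for that move.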
The log_acceptance_correction need not correspond to a ratio of proposal densities; e.g., log_acceptance_correction has a different interpretation in Hamiltonian Monte Carlo.
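
For example, standard HMC with a unit mass matrix (an assumption for this sketch) applies the Metropolis-Hastings test to the joint density of position and momentum. The correction is therefore the decrease in kinetic energy, K(p) - K(p'), with K(p) = p**2 / 2. The sketch below uses a hand-written leapfrog integrator on the target from the Examples section. It illustrates the textbook derivation and is not TFP's internal code.

```python
def target_log_prob(x):
    # Target from the Examples section: log p(x) = -x - x**2 (up to a constant).
    return -x - x ** 2


def grad_target_log_prob(x):
    return -1.0 - 2.0 * x


def leapfrog(x, p, step_size, num_steps):
    """Standard leapfrog integrator for H(x, p) = -log p(x) + p**2 / 2."""
    for _ in range(num_steps):
        p = p + 0.5 * step_size * grad_target_log_prob(x)  # half momentum step
        x = x + step_size * p                               # full position step
        p = p + 0.5 * step_size * grad_target_log_prob(x)  # half momentum step
    return x, p


x0, p0 = 0.0, 1.0
x1, p1 = leapfrog(x0, p0, step_size=0.1, num_steps=3)

# The MH test uses the joint density exp(log p(x) - p**2 / 2), so the extra
# term beyond the target difference is the kinetic-energy drop K(p0) - K(p1).
correction = 0.5 * p0 ** 2 - 0.5 * p1 ** 2
ratio = target_log_prob(x1) - target_log_prob(x0) + correction
print(correction, ratio)  # ratio is close to 0
```

Leapfrog nearly conserves the total energy. As a result, the target-log-prob difference and the correction almost cancel, and HMC proposals are accepted with high probability.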
Examples
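
    from tensorflow_probability.substrates import jax as tfp

    # Target: -x - x**2 is the unnormalized log density of a Normal with
    # mean -0.5 and variance 0.5, since -x - x**2 = -(x + 0.5)**2 + 0.25.
    # UncalibratedHamiltonianMonteCarlo only generates proposals; wrapping it
    # in MetropolisHastings adds the accept/reject step.
    hmc = tfp.mcmc.MetropolisHastings(
        tfp.mcmc.UncalibratedHamiltonianMonteCarlo(
            target_log_prob_fn=lambda x: -x - x**2,
            step_size=0.1,
            num_leapfrog_steps=3))
    # ==> functionally equivalent to:
    # hmc = tfp.mcmc.HamiltonianMonteCarlo(
    #     target_log_prob_fn=lambda x: -x - x**2,
    #     step_size=0.1,
    #     num_leapfrog_steps=3)
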
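
The accept/reject logic that MetropolisHastings adds can also be sketched without TFP. The following sampler targets the same density, but it swaps in a symmetric Gaussian random-walk proposal instead of HMC, so its log_acceptance_correction is 0. random_walk_mh and its arguments are hypothetical.

```python
import math
import random


def target_log_prob(x):
    # Same unnormalized target as the TFP example: Normal(-0.5, variance 0.5).
    return -x - x ** 2


def random_walk_mh(num_samples, scale=1.0, x=0.0, seed=0):
    rng = random.Random(seed)
    samples = []
    log_p = target_log_prob(x)
    for _ in range(num_samples):
        # Symmetric proposal, so the log acceptance correction is 0.
        x_new = x + rng.gauss(0.0, scale)
        log_p_new = target_log_prob(x_new)
        log_accept_ratio = log_p_new - log_p
        # Accept with probability min(1, exp(log_accept_ratio));
        # 1 - random() lies in (0, 1], so the log is always defined.
        if math.log(1.0 - rng.random()) < log_accept_ratio:
            x, log_p = x_new, log_p_new
        samples.append(x)
    return samples


samples = random_walk_mh(20000)
mean = sum(samples) / len(samples)
print(round(mean, 2))  # close to the true mean, -0.5
```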
Attributes | |
---|---|
experimental_shard_axis_names | The shard axis names for members of the state. |
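inner_kernel | The TransitionKernel whose proposals are accepted or rejected; its results must meet the requirements above. |
is_calibrated | Returns True if the Markov chain converges to the specified distribution. |
name | The name of this kernel, as passed to the constructor. |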
parameters | Return dict of __init__ arguments and their values. |
Methods
bootstrap_results
bootstrap_results(
init_state
)
Returns an object with the same type as returned by one_step.
Args | |
---|---|
init_state | Tensor or Python list of Tensors representing the initial state(s) of the Markov chain(s). |
Returns | |
---|---|
kernel_results | A (possibly nested) tuple, namedtuple, or list of Tensors representing internal calculations made within this function. |
Raises | |
---|---|
ValueError | If the inner_kernel results do not contain the member "target_log_prob". |
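
Here is a simplified sketch of what bootstrap_results does, under stated assumptions. ToyInnerKernel, InnerResults, and MHResults are hypothetical stand-ins; TFP's actual results class has more fields. Recording the initial state as accepted with a log_accept_ratio of 0.0 is a choice made for this sketch.

```python
import collections

# Hypothetical stand-ins; TFP's real classes and fields differ.
InnerResults = collections.namedtuple("InnerResults", ["target_log_prob"])
MHResults = collections.namedtuple(
    "MHResults", ["accepted_results", "is_accepted", "log_accept_ratio"])


class ToyInnerKernel:
    def __init__(self, target_log_prob_fn):
        self.target_log_prob_fn = target_log_prob_fn

    def bootstrap_results(self, init_state):
        return InnerResults(target_log_prob=self.target_log_prob_fn(init_state))


def bootstrap_results(inner_kernel, init_state):
    """Wraps the inner kernel's initial results after validating them."""
    inner = inner_kernel.bootstrap_results(init_state)
    if not hasattr(inner, "target_log_prob"):
        raise ValueError('inner_kernel results must contain "target_log_prob"')
    # Nothing has been proposed yet; this sketch records the initial state
    # as accepted with a log acceptance ratio of 0.
    return MHResults(accepted_results=inner, is_accepted=True, log_accept_ratio=0.0)


results = bootstrap_results(ToyInnerKernel(lambda x: -x - x ** 2), 1.0)
print(results.accepted_results.target_log_prob)  # -2.0
```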
copy
copy(
**override_parameter_kwargs
)
Non-destructively creates a deep copy of the kernel.
Args | |
---|---|
**override_parameter_kwargs | Python string/value dictionary of initialization arguments to override with new values. |
Returns | |
---|---|
new_kernel | TransitionKernel object of the same type as self, initialized with the union of self.parameters and override_parameter_kwargs, with any shared keys overridden by the value of override_parameter_kwargs, i.e., dict(self.parameters, **override_parameter_kwargs). |
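
The merge rule dict(self.parameters, **override_parameter_kwargs) can be illustrated with a hypothetical ToyKernel class. The sketch only shows the parameter union, with overrides taking precedence. It does not reproduce TFP's copying internals.

```python
class ToyKernel:
    """Hypothetical kernel illustrating the documented copy semantics."""

    def __init__(self, inner_kernel, name=None):
        self._parameters = dict(inner_kernel=inner_kernel, name=name)

    @property
    def parameters(self):
        # Return a fresh dict so callers cannot mutate the kernel's state.
        return dict(self._parameters)

    def copy(self, **override_parameter_kwargs):
        # Union of the current parameters and the overrides; overrides win.
        parameters = dict(self.parameters, **override_parameter_kwargs)
        return type(self)(**parameters)


k1 = ToyKernel(inner_kernel="hmc", name="mh")
k2 = k1.copy(name="mh_copy")
print(k1.parameters)  # {'inner_kernel': 'hmc', 'name': 'mh'}
print(k2.parameters)  # {'inner_kernel': 'hmc', 'name': 'mh_copy'}
```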
experimental_with_shard_axes
experimental_with_shard_axes(
shard_axis_names
)
Returns a copy of the kernel with the provided shard axis names.
Args | |
---|---|
shard_axis_names | A structure of strings indicating the shard axis names for each component of this kernel's state. |
Returns | |
---|---|
A copy of the current kernel with the shard axis information. |
one_step
one_step(
current_state, previous_kernel_results, seed=None
)
Takes one step of the TransitionKernel.
Args | |
---|---|
current_state | Tensor or Python list of Tensors representing the current state(s) of the Markov chain(s). |
previous_kernel_results | A (possibly nested) tuple, namedtuple, or list of Tensors representing internal calculations made within the previous call to this function (or as returned by bootstrap_results). |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
Returns | |
---|---|
next_state | Tensor or Python list of Tensors representing the next state(s) of the Markov chain(s). |
kernel_results | A (possibly nested) tuple, namedtuple, or list of Tensors representing internal calculations made within this function. |
Raises | |
---|---|
ValueError | If the inner_kernel results do not contain the member "target_log_prob". |
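
Here is a minimal sketch of a single step in the same style, mapping (current_state, previous_kernel_results) to (next_state, kernel_results). It assumes a symmetric proposal (no correction term) and uses floats in place of Tensors. one_step, Results, and propose_fn here are hypothetical. Python's random.Random(seed) stands in for TFP's stateless seeds.

```python
import collections
import math
import random

# Hypothetical results type; TFP's real MetropolisHastings results differ.
Results = collections.namedtuple("Results", ["target_log_prob", "is_accepted"])


def one_step(current_state, previous_kernel_results, propose_fn,
             target_log_prob_fn, seed=None):
    """One symmetric-proposal MH step: returns (next_state, kernel_results)."""
    rng = random.Random(seed)
    proposed_state = propose_fn(current_state, rng)
    proposed_log_prob = target_log_prob_fn(proposed_state)
    log_accept_ratio = proposed_log_prob - previous_kernel_results.target_log_prob
    # log(u) with u in (0, 1]; accept when it falls below the log ratio.
    is_accepted = math.log(1.0 - rng.random()) < log_accept_ratio
    if is_accepted:
        return proposed_state, Results(proposed_log_prob, True)
    # On rejection, keep the current state and carry the old results forward.
    return current_state, Results(previous_kernel_results.target_log_prob, False)


target = lambda x: -x - x ** 2
init = Results(target_log_prob=target(3.0), is_accepted=True)
# A proposal that jumps to the mode has a positive log ratio and is always accepted.
state, res = one_step(3.0, init, lambda x, rng: -0.5, target, seed=42)
print(state, res.is_accepted)  # -0.5 True
```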