Source code for easyvvuq.sampling.mcmc
from .base import BaseSamplingElement
import numpy as np
class MCMCSampler(BaseSamplingElement, sampler_name='mcmc_sampler'):
    """A Metropolis-Hastings MCMC Sampler.

    The sampler is driven in batches: iterating over it yields one sample per
    chain and then stops; `update` must be called with the results of those
    runs before the next batch of proposals can be drawn.

    Parameters
    ----------
    init: dict
       Initial values for each input parameter, one value per chain. Of the
       form {'input1': [value_chain_0, value_chain_1, ...], ...}
    q: function
       A function of one argument X (dictionary) that returns the proposal
       distribution conditional on the `X`. The returned object is used
       through its `sample()` and `pdf()` methods.
    qoi: str
       Name of the quantity of interest
    n_chains: int
       Number of MCMC chains to run in parallel.
    likelihood: function
       Called with the array of QoI values of a run. Its result is passed
       through `np.exp` before being used in the acceptance ratio, i.e. it is
       treated as a log-likelihood. By default the first QoI value is used.
    estimator: function
       To be used with replica_col argument. Outputs an estimate of some
       parameter when given a sample array.

    Raises
    ------
    RuntimeError
       If an entry of `init` is not iterable or does not hold exactly
       `n_chains` values.
    """

    def __init__(self, init, q, qoi, n_chains=1, likelihood=lambda x: x[0], estimator=None):
        # Every input needs one starting value per chain.
        for param in init:
            if not hasattr(init[param], '__iter__'):
                raise RuntimeError(
                    'all input intializations should be iterables of same length as there are chains')
            if len(init[param]) != n_chains:
                raise RuntimeError(
                    'initialization dictionary should have separate values for each chain')
        self.init = dict(init)
        self.inputs = list(self.init.keys())
        self.n_chains = n_chains
        self.q = q
        self.qoi = qoi
        self.current_chain = 0
        # Current state of each chain, tagged with the chain it belongs to.
        self.x = []
        for chain in range(self.n_chains):
            state = {key: self.init[key][chain] for key in self.inputs}
            state['chain_id'] = chain
            self.x.append(state)
        # Likelihood of the current state of each chain; None until the
        # chain's first run has been evaluated by `update`.
        self.f_x = [None] * n_chains
        # Set once a full batch (one sample per chain) has been handed out;
        # cleared again by `update`.
        self.stop = False
        # The user supplied callable is exponentiated before use.
        self.likelihood = lambda x: np.exp(likelihood(x))
        self.n_replicas = None
        self.estimator = estimator
        # Initialised empty; not modified anywhere in this class.
        self.acceptance_ratios = []
        # Number of completed `update` calls.
        self.iteration = 0

    def is_finite(self):
        """Report this sampler as finite (always returns True)."""
        return True

    def n_samples(self):
        """Return the number of samples in one batch, i.e. the number of chains."""
        return self.n_chains

    def __iter__(self):
        # Start a new pass at the first chain. `self.stop` is deliberately not
        # touched here: it is only cleared by `update`.
        self.current_chain = 0
        return self

    def _advance_chain(self):
        """Move on to the next chain and end the batch after the last one."""
        self.current_chain = (self.current_chain + 1) % self.n_chains
        if self.current_chain == 0:
            self.stop = True

    def __next__(self):
        """Returns next MCMC sample.

        For a chain whose current state has not been evaluated yet the current
        state itself is returned; otherwise a new proposal is drawn from
        `q(current state)`.

        Returns
        -------
        dict
            A dictionary where keys are input variables names and values are input values.

        Raises
        ------
        StopIteration
            Once every chain has produced a sample and `update` has not been
            called since.
        """
        if self.stop:
            raise StopIteration
        chain = self.current_chain
        if self.f_x[chain] is None:
            sample = self.x[chain]
        else:
            proposal = self.q(self.x[chain]).sample()
            sample = {key: proposal[i][0] for i, key in enumerate(self.inputs)}
            sample['chain_id'] = chain
        self._advance_chain()
        return sample

    def update(self, result, invalid):
        """Performs the MCMC sampling procedure on the campaign.

        Each evaluated proposal is accepted or rejected with the
        Metropolis-Hastings rule; chains with a rejected or invalid run are
        collected and all of their runs in `result` and `invalid` are reported
        back as ignored.

        Parameters
        ----------
        result: pandas DataFrame
            run information from previous iteration (same as collation DataFrame)
        invalid: pandas DataFrame
            invalid run information (runs that cannot be executed for some reason)

        Returns
        -------
        list of rejected run ids
        """
        # Allow the next batch of samples to be drawn.
        self.stop = False
        if (self.estimator is not None) and (len(result) > 0):
            result_grouped = result.groupby(('chain_id', 0)).apply(self.estimator)
        else:
            result_grouped = result
        if (self.estimator is not None) and (len(invalid) > 0):
            invalid_grouped = invalid.groupby(('chain_id', 0)).apply(lambda x: x.mean())
        else:
            invalid_grouped = invalid
        ignored_chains = []
        # process normal runs
        for _, row in result_grouped.iterrows():
            chain_id = int(row['chain_id'].values[0])
            y = {key: row[key][0] for key in self.inputs}
            f_y = self.likelihood(row[self.qoi].values)
            if self.f_x[chain_id] is None:
                # First evaluation of this chain: just record the likelihood
                # of its initial state.
                self.f_x[chain_id] = f_y
                continue
            x = self.x[chain_id]
            q_xy = self.q(y).pdf([x[key] for key in self.inputs])
            q_yx = self.q(x).pdf([y[key] for key in self.inputs])
            if self.f_x[chain_id] == 0.0:
                # Avoid dividing by zero: always accept when the current
                # state has zero likelihood.
                r = 1.0
            else:
                r = min(1.0, (f_y / self.f_x[chain_id]) * (q_xy / q_yx))
            if np.random.random() < r:
                self.x[chain_id] = y
                self.f_x[chain_id] = f_y
            else:
                ignored_chains.append(chain_id)
        # process invalid runs
        for _, row in invalid_grouped.iterrows():
            ignored_chains.append(int(row['chain_id'].values[0]))
        ignored_runs = []
        # A chain can be recorded more than once (several rejected rows, or a
        # rejected row plus an invalid one); visit each chain only once so
        # that no run id is reported twice.
        for chain_id in dict.fromkeys(ignored_chains):
            for frame in (result, invalid):
                try:
                    matching = frame.loc[frame[('chain_id', 0)] == chain_id]['run_id'].values
                except KeyError:
                    # Best effort: presumably the frame is empty and lacks
                    # these columns, so it has no runs to ignore.
                    continue
                ignored_runs.extend(run[0] for run in matching)
        self.iteration += 1
        return ignored_runs

    @property
    def analysis_class(self):
        """Returns a corresponding analysis class for this sampler.

        Returns
        -------
        class
        """
        # Imported lazily inside the property, as in the original code
        # (presumably to avoid a circular import - TODO confirm).
        from easyvvuq.analysis import MCMCAnalysis
        return MCMCAnalysis