MCMC in EasyVVUQ
EasyVVUQ provides support for MCMC sampling with multiple chains in parallel.
[ ]:
import os
import easyvvuq as uq
import numpy as np
import chaospy as cp
import json
import matplotlib.pyplot as plt
import sys
We define a Rosenbrock function in 2 dimensions for testing purposes. This will be a stand-in for our probability density.
[ ]:
def rosenbrock(directory):
    """Evaluate the 2-D Rosenbrock function for one run directory.

    Reads ``x1``, ``x2`` and ``outfile`` from ``<directory>/input.json``,
    computes ``(1 - x1)**2 + 100 * (x2 - x1**2)**2`` and writes the
    *negated* result as ``{"value": ...}`` to ``<directory>/<outfile>``.

    Parameters
    ----------
    directory : str
        Directory that contains ``input.json`` and receives the output file.

    Raises
    ------
    SystemExit
        If ``input.json`` is not present in ``directory``.
    """
    input_path = os.path.join(directory, 'input.json')
    if not os.path.isfile(input_path):
        sys.exit(input_path + " does not exist.")

    with open(input_path, "r") as handle:
        settings = json.load(handle)

    a = float(settings['x1'])
    b = float(settings['x2'])
    result_path = os.path.join(directory, settings['outfile'])

    # Same two terms as the textbook Rosenbrock function, named separately.
    offset_term = (1.0 - a) ** 2
    valley_term = 100.0 * (b - a ** 2) ** 2
    payload = {'value': -(offset_term + valley_term)}

    with open(result_path, 'w') as handle:
        handle.write(json.dumps(payload))
Next we define a helper function to create a campaign, sample the search space and return the corresponding DataFrame.
[ ]:
def mcmc(tmp_path='.'):
    """Build an EasyVVUQ campaign, run 1000 sampler iterations and collate.

    Parameters
    ----------
    tmp_path : str
        Passed to ``uq.Campaign`` as its ``work_dir``.

    Returns
    -------
    Whatever ``campaign.get_collation_result()`` returns after the run
    (the notebook treats it as a DataFrame).
    """
    def q(x, b=1):
        # Joint of two normals centred on the point ``x`` with scale ``b``.
        # Presumably the sampler's proposal distribution -- confirm against
        # the MCMCSampler documentation.
        return cp.J(cp.Normal(x['x1'], b), cp.Normal(x['x2'], b))

    # One starting value per chain for each varied parameter (five chains).
    start_points = {
        "x1": [-1.0, 0.0, 1.0, 0.5, 0.1],
        "x2": [1.0, 0.0, 0.5, 1.0, 0.2]
    }

    campaign = uq.Campaign(name="mcmc", work_dir=tmp_path)

    app_params = {
        "x1": dict(type="float", default=0.0),
        "x2": dict(type="float", default=0.0),
        "out_file": dict(type="string", default="output.json"),
        "chain_id": dict(type="integer", default=0),
    }
    template_path = os.path.abspath("rosenbrock.template")
    input_encoder = uq.encoders.GenericEncoder(
        template_fname=template_path,
        delimiter="$",
        target_filename="input.json")
    output_decoder = uq.decoders.JSONDecoder("output.json", ["value"])
    campaign.add_app(
        name="mcmc",
        params=app_params,
        encoder=input_encoder,
        decoder=output_decoder)

    # 'value' is the key the decoder extracts from output.json.
    campaign.set_sampler(
        uq.sampling.MCMCSampler(start_points, q, 'value', n_chains=5))

    run_model = uq.actions.ExecutePython(rosenbrock)
    steps = campaign.iterate(run_model, mark_invalid=True)
    for _step in range(1000):
        batch = next(steps)
        batch.start()

    return campaign.get_collation_result()
[ ]:
# Run the campaign in the default working directory ('.') and keep the
# collated results for the plots below.
df = mcmc()
Finally we plot the five different chains.
[7]:
import matplotlib.pyplot as plt

# Draw one x1-x2 trace per chain. Chain ids 0-4 match the five chains
# (n_chains=5) started in mcmc().
for chain_id in range(5):
    # Build the row filter once and reuse it for both coordinates instead of
    # re-evaluating the comparison for x1 and again for x2.
    in_chain = (df['chain_id'] == chain_id).values
    chain = df[in_chain]
    plt.plot(chain['x1'], chain['x2'], alpha=0.5)
[7]:
[<matplotlib.lines.Line2D at 0x124a943a0>]
Finally let us plot a histogram of this data.
[6]:
# 2-D histogram of all sampled (x1, x2) points, 20 bins per axis.
# The transpose-then-[0] indexing pulls the first row of each transposed
# column block, i.e. a flat array of samples.
x1_samples = df['x1'].T.values[0]
x2_samples = df['x2'].T.values[0]
hist = plt.hist2d(x1_samples, x2_samples, bins=20)
[ ]:
[ ]: