autoppl  v0.8
A C++ template library for probabilistic programming
mh.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <random>
11 
12 namespace ppl {
13 namespace mcmc {
14 
27 template <class ProgramType
28  , class OffsetPackType
29  , class MCMCResultType>
30 inline void mh_(const ProgramType& program,
31  const MHConfig& config,
32  const OffsetPackType& pack,
33  MCMCResultType& res)
34 {
35  ProgramType program_curr = program; // will be bound to curr
36  ProgramType program_cand = program; // will be bound to cand
37 
38  // data structure to keep track of param candidates
39  using cont_vec_t = Eigen::Matrix<util::cont_param_t, Eigen::Dynamic, 1>;
40  using disc_vec_t = Eigen::Matrix<util::disc_param_t, Eigen::Dynamic, 1>;
41  using visit_vec_t = Eigen::Matrix<size_t, Eigen::Dynamic, 1>;
42 
43  cont_vec_t cont_curr(std::get<0>(pack).uc_offset); // total number of offsets is the number of parameters
44  cont_vec_t cont_cand(std::get<0>(pack).uc_offset);
45  cont_vec_t cont_tp(std::get<0>(pack).tp_offset);
46  cont_vec_t cont_constrained(std::get<0>(pack).c_offset);
47  visit_vec_t cont_visit(std::get<0>(pack).v_offset);
48  cont_tp.setZero();
49  cont_constrained.setZero();
50  cont_visit.setZero();
51 
52  disc_vec_t disc_curr(std::get<1>(pack).uc_offset);
53  disc_vec_t disc_cand(std::get<1>(pack).uc_offset);
54  disc_vec_t disc_tp(std::get<1>(pack).tp_offset);
55 
56  util::cont_ptr_pack_t cont_ptr_pack;
57  cont_ptr_pack.uc_val = cont_curr.data();
58  cont_ptr_pack.c_val = cont_constrained.data();
59  cont_ptr_pack.v_val = cont_visit.data();
60  cont_ptr_pack.tp_val = cont_tp.data();
61 
62  util::disc_ptr_pack_t disc_ptr_pack;
63  disc_ptr_pack.uc_val = disc_curr.data();
64  disc_ptr_pack.tp_val = disc_tp.data();
65 
66  program_curr.bind(cont_ptr_pack);
67  program_curr.bind(disc_ptr_pack);
68 
69  cont_ptr_pack.uc_val = cont_cand.data();
70  disc_ptr_pack.uc_val = disc_cand.data();
71  program_cand.bind(cont_ptr_pack);
72  program_cand.bind(disc_ptr_pack);
73 
74  std::uniform_real_distribution metrop_sampler(0., 1.);
75  std::discrete_distribution disc_sampler({config.alpha, 1-2*config.alpha, config.alpha});
76  std::normal_distribution norm_sampler(0., config.sigma);
77  std::mt19937 gen(config.seed);
78 
79  // references avoid making copies when swapping at the end of for-loop
80  std::reference_wrapper<ProgramType> program_curr_ref(program_curr);
81  std::reference_wrapper<ProgramType> program_cand_ref(program_cand);
82 
83  program_curr_ref.get().init_params(gen, config.prune);
84  double curr_log_pdf = program_curr.log_pdf();
85 
86  // construct miscellaneous objects
87  auto logger = util::ProgressLogger(config.samples + config.warmup, "Metropolis-Hastings");
88  util::StopWatch<> stopwatch_warmup;
89  util::StopWatch<> stopwatch_sampling;
90 
91  // start timing warmup
92  stopwatch_warmup.start();
93 
94  for (size_t iter = 0; iter < config.samples + config.warmup; ++iter) {
95 
96  // if warmup is finished, stop timing warmup and start timing sampling
97  if (iter == config.warmup) {
98  stopwatch_warmup.stop();
99  stopwatch_sampling.start();
100  }
101 
102  logger.printProgress(iter);
103 
104  double log_alpha = -curr_log_pdf;
105 
106  // generate next candidates
107  cont_cand = cont_curr + cont_vec_t::NullaryExpr(cont_cand.size(),
108  [&]() { return norm_sampler(gen); });
109  disc_cand = disc_curr + disc_vec_t::NullaryExpr(disc_cand.size(),
110  [&]() { return disc_sampler(gen) - 1; });
111 
112  // compute next candidate log pdf and update log_alpha
113  double cand_log_pdf = program_cand_ref.get().log_pdf();
114  log_alpha += cand_log_pdf;
115  bool accept = (std::log(metrop_sampler(gen)) <= log_alpha);
116 
117  if (accept) {
118  cont_curr.swap(cont_cand);
119  disc_curr.swap(disc_cand);
120  std::swap(program_curr_ref, program_cand_ref);
121  curr_log_pdf = cand_log_pdf;
122  }
123 
124  if (iter >= config.warmup) {
125  res.cont_samples.row(iter-config.warmup) = cont_curr;
126  res.disc_samples.row(iter-config.warmup) = disc_curr;
127  }
128  }
129 
130  // stop timing sampling
131  stopwatch_sampling.stop();
132 
133  // save output results
134  res.warmup_time = stopwatch_warmup.elapsed();
135  res.sampling_time = stopwatch_sampling.elapsed();
136 }
137 
138 } // namespace mcmc
139 
140 template <class ExprType>
141 inline auto mh(const ExprType& expr,
142  const MHConfig& config = MHConfig())
143 {
144  return mcmc::base_mcmc(expr, config,
145  [](const auto& program, const auto& config,
146  const auto& pack, auto& res) {
147  res.name = "mh";
148  mcmc::mh_(program, config, pack, res);
149  });
150 }
151 
152 } // namespace ppl
ppl::mh
auto mh(const ExprType &expr, const MHConfig &config=MHConfig())
Definition: mh.hpp:141
ppl::util::StopWatch::stop
void stop()
Definition: stopwatch.hpp:11
ppl::mcmc::base_mcmc
MCMCResult base_mcmc(const ExprType &expr, const ConfigType &config, Sampler f)
Definition: base_mcmc.hpp:13
ppl::util::StopWatch::elapsed
double elapsed() const
Definition: stopwatch.hpp:13
ppl::util::PtrPack::c_val
CValPtrType c_val
Definition: ptr_pack.hpp:37
config.hpp
ppl::MHConfig
Definition: config.hpp:7
ppl::MHConfig::sigma
double sigma
Definition: config.hpp:8
ppl::util::StopWatch::start
void start()
Definition: stopwatch.hpp:10
ptr_pack.hpp
ppl::ConfigBase::seed
size_t seed
Definition: config_base.hpp:11
logging.hpp
ppl::mcmc::mh_
void mh_(const ProgramType &program, const MHConfig &config, const OffsetPackType &pack, MCMCResultType &res)
Definition: mh.hpp:30
ppl::util::PtrPack::uc_val
UCValPtrType uc_val
Definition: ptr_pack.hpp:33
ppl::util::StopWatch
Definition: stopwatch.hpp:9
stopwatch.hpp
ppl::util::ProgressLogger
Definition: logging.hpp:19
sampler_tools.hpp
ppl::util::PtrPack::tp_val
TPValPtrType tp_val
Definition: ptr_pack.hpp:35
ppl::ConfigBase::prune
bool prune
Definition: config_base.hpp:12
ppl
Definition: bounded.hpp:11
base_mcmc.hpp
ppl::MHConfig::alpha
double alpha
Definition: config.hpp:9
ppl::util::PtrPack
Definition: ptr_pack.hpp:14
traits.hpp
ppl::ConfigBase::warmup
size_t warmup
Definition: config_base.hpp:9
result.hpp
ppl::ConfigBase::samples
size_t samples
Definition: config_base.hpp:10
ppl::util::PtrPack::v_val
size_t * v_val
Definition: ptr_pack.hpp:38