// mh_: Metropolis-Hastings random-walk sampler over a program's continuous and
// discrete unconstrained parameters. Runs config.warmup + config.samples
// iterations and writes the current state of every post-warmup iteration into
// res.cont_samples / res.disc_samples, plus warmup and sampling times.
//
// NOTE(review): this excerpt is lossy. The embedded numbering skips several
// lines. For example, 31 and 33-34 presumably declare `config` and `res` and
// open the body, and 136-139 presumably close it. In the visible code,
// `config` is read for alpha, sigma, seed, prune, samples and warmup, and
// `res` receives cont_samples, disc_samples, warmup_time and sampling_time.
// Confirm against the full file.
//
// Arguments visible here:
//   program - prototype program. It is copied into a "current" and a
//             "candidate" instance so each can be bound to its own
//             parameter buffers.
//   pack    - tuple-like. std::get<0>(pack) supplies the continuous buffer
//             sizes (uc_offset, tp_offset, c_offset, v_offset), and
//             std::get<1>(pack) supplies the discrete ones (uc_offset,
//             tp_offset).
27 template <
class ProgramType
28 ,
class OffsetPackType
29 ,
class MCMCResultType>
30 inline void mh_(
const ProgramType& program,
32 const OffsetPackType& pack,
// Two independent copies: one tracks the current state, the other evaluates
// proposals.
35 ProgramType program_curr = program;
36 ProgramType program_cand = program;
39 using cont_vec_t = Eigen::Matrix<util::cont_param_t, Eigen::Dynamic, 1>;
40 using disc_vec_t = Eigen::Matrix<util::disc_param_t, Eigen::Dynamic, 1>;
41 using visit_vec_t = Eigen::Matrix<size_t, Eigen::Dynamic, 1>;
// Parameter buffers, sized from the offsets in `pack`. The *_curr and *_cand
// buffers hold the unconstrained values of the current and candidate states.
// The tp, constrained and visit buffers are allocated once and bound to both
// programs below.
43 cont_vec_t cont_curr(std::get<0>(pack).uc_offset);
44 cont_vec_t cont_cand(std::get<0>(pack).uc_offset);
45 cont_vec_t cont_tp(std::get<0>(pack).tp_offset);
46 cont_vec_t cont_constrained(std::get<0>(pack).c_offset);
47 visit_vec_t cont_visit(std::get<0>(pack).v_offset);
// Only the constrained buffer is explicitly zeroed here.
49 cont_constrained.setZero();
52 disc_vec_t disc_curr(std::get<1>(pack).uc_offset);
53 disc_vec_t disc_cand(std::get<1>(pack).uc_offset);
54 disc_vec_t disc_tp(std::get<1>(pack).tp_offset);
// Bind program_curr to the *_curr unconstrained buffers plus the shared
// constrained, visit and tp buffers.
// NOTE(review): the declarations of cont_ptr_pack and disc_ptr_pack are not
// in this excerpt.
57 cont_ptr_pack.
uc_val = cont_curr.data();
58 cont_ptr_pack.
c_val = cont_constrained.data();
59 cont_ptr_pack.
v_val = cont_visit.data();
60 cont_ptr_pack.
tp_val = cont_tp.data();
63 disc_ptr_pack.
uc_val = disc_curr.data();
64 disc_ptr_pack.
tp_val = disc_tp.data();
66 program_curr.bind(cont_ptr_pack);
67 program_curr.bind(disc_ptr_pack);
// Re-point only uc_val, then bind program_cand. The two programs therefore
// differ only in which unconstrained buffers they read; the tp, constrained
// and visit buffers are shared.
// NOTE(review): presumably this sharing is safe because only one program is
// evaluated at a time. Confirm.
69 cont_ptr_pack.
uc_val = cont_cand.data();
70 disc_ptr_pack.
uc_val = disc_cand.data();
71 program_cand.bind(cont_ptr_pack);
72 program_cand.bind(disc_ptr_pack);
// Random sources:
//   metrop_sampler - u in [0, 1) for the accept/reject test.
//   disc_sampler   - index 0/1/2 with weights alpha, 1-2*alpha, alpha. It is
//                    shifted by -1 below into a step of -1/0/+1, which is
//                    symmetric because the -1 and +1 weights are equal.
//                    NOTE(review): discrete_distribution weights must be
//                    non-negative, so this presumably requires
//                    0 <= config.alpha <= 0.5. Confirm it is validated
//                    upstream.
//   norm_sampler   - N(0, config.sigma) step for continuous parameters.
//   gen            - Mersenne Twister engine seeded from config.seed.
74 std::uniform_real_distribution metrop_sampler(0., 1.);
75 std::discrete_distribution disc_sampler({config.
alpha, 1-2*config.
alpha, config.
alpha});
76 std::normal_distribution norm_sampler(0., config.
sigma);
77 std::mt19937 gen(config.
seed);
// Swappable handles to the two programs. At the swap in the loop, the
// "current" and "candidate" roles are exchanged instead of copying program
// state.
80 std::reference_wrapper<ProgramType> program_curr_ref(program_curr);
81 std::reference_wrapper<ProgramType> program_cand_ref(program_cand);
// Initialize the starting state and cache its log density. In the lines
// visible here, curr_log_pdf is only updated by taking over cand_log_pdf, not
// by re-evaluating the current program.
// NOTE(review): init_params presumably writes the initial values into the
// buffers bound above. Confirm.
83 program_curr_ref.get().init_params(gen, config.
prune);
84 double curr_log_pdf = program_curr.log_pdf();
// Wall-clock timing. The warmup stopwatch runs until iteration config.warmup,
// when the sampling stopwatch takes over until the loop ends.
// NOTE(review): if config.samples == 0, the iter == config.warmup branch never
// executes. stopwatch_warmup is then never stopped, and stopwatch_sampling is
// stopped without having been started. Confirm the stopwatch type tolerates
// this. The stopwatch and logger declarations are not in this excerpt.
92 stopwatch_warmup.
start();
94 for (
size_t iter = 0; iter < config.
samples + config.
warmup; ++iter) {
97 if (iter == config.
warmup) {
98 stopwatch_warmup.
stop();
99 stopwatch_sampling.
start();
102 logger.printProgress(iter);
// Log acceptance ratio for a symmetric proposal:
//   log(p(cand) / p(curr)) = cand_log_pdf - curr_log_pdf.
104 double log_alpha = -curr_log_pdf;
// Random-walk proposal, written into the cand buffers that program_cand_ref
// reads: Gaussian steps for continuous parameters and -1/0/+1 steps for
// discrete ones.
107 cont_cand = cont_curr + cont_vec_t::NullaryExpr(cont_cand.size(),
108 [&]() { return norm_sampler(gen); });
109 disc_cand = disc_curr + disc_vec_t::NullaryExpr(disc_cand.size(),
110 [&]() { return disc_sampler(gen) - 1; });
113 double cand_log_pdf = program_cand_ref.get().log_pdf();
114 log_alpha += cand_log_pdf;
// Accept iff log(u) <= log_alpha, with u ~ U[0, 1).
// NOTE(review): u can be exactly 0.0. Then log(u) is -inf, which accepts even
// a candidate whose log_pdf is -inf. If both log densities are -inf,
// log_alpha is NaN and the comparison is false.
115 bool accept = (std::log(metrop_sampler(gen)) <= log_alpha);
// NOTE(review): numbering 116-117 is absent from this excerpt. Presumably it
// opens an `if (accept)` guard around the four lines below. As shown, `accept`
// is unused and every proposal would be taken. Confirm against the full file.
//
// Promote the candidate by swapping the buffers and the program handles
// together. Afterwards, the buffers now named *_curr are the ones bound to the
// program now referenced by program_curr_ref.
// NOTE(review): this relies on Eigen's swap of two same-type dynamic vectors
// exchanging their storage rather than copying values. A value-copying swap
// would leave each program reading the wrong buffer. Confirm.
118 cont_curr.swap(cont_cand);
119 disc_curr.swap(disc_cand);
120 std::swap(program_curr_ref, program_cand_ref);
121 curr_log_pdf = cand_log_pdf;
// After warmup, record the current state as row iter - config.warmup.
// NOTE(review): assumes res.cont_samples and res.disc_samples have at least
// config.samples rows.
124 if (iter >= config.
warmup) {
125 res.cont_samples.row(iter-config.
warmup) = cont_curr;
126 res.disc_samples.row(iter-config.
warmup) = disc_curr;
131 stopwatch_sampling.
stop();
// Report the measured phase durations in the result.
134 res.warmup_time = stopwatch_warmup.
elapsed();
135 res.sampling_time = stopwatch_sampling.
elapsed();
140 template <
class ExprType>
141 inline auto mh(
const ExprType& expr,
145 [](
const auto& program,
const auto& config,
146 const auto& pack,
auto& res) {