4 #include <fastad_bits/reverse/core/var_view.hpp>
5 #include <fastad_bits/reverse/core/eval.hpp>
// Trajectory-continuation criterion: true only when `rho` has a strictly
// positive dot product with BOTH `p_beg_scaled` and `p_end_scaled`.
// Presumably the no-U-turn test used by the NUTS code in this file; confirm.
// NOTE(review): the declaration line(s) between the template header and the
// parameter list (function name, type of `rho`) are not visible in this
// excerpt; confirm against the full file.
32 template <
class MatType1,
class MatType2,
class MatType3>
34 const MatType2& p_beg_scaled,
35 const MatType3& p_end_scaled)
// Both boundary (scaled) momenta must point "along" rho for the check to pass;
// a zero dot product counts as failure.
37 return rho.dot(p_beg_scaled) > 0 &&
38 rho.dot(p_end_scaled) > 0;
// Recursive trajectory-subtree builder, apparently for the NUTS sampler in
// this file. Two paths are visible in this excerpt:
//  * base case: one leapfrog step of size `input.v * input.epsilon` from the
//    state referenced by `input`, then bookkeeping of the subtree's log
//    weight, summed Metropolis acceptance probability, proposal, boundary
//    momenta and summed momentum `rho`;
//  * recursive case: two child subtrees built into scratch memory carved out
//    of `tree_cache` (3*n_params doubles for the first child and 4*n_params
//    for the second, i.e. 7*n_params per recursion level), then merged.
// All outputs are written through the `*_ref` members of `input`.
// NOTE(review): parts of the signature and body (including the depth test
// and the recursive calls themselves) are not visible in this excerpt.
60 template <
class InputType
61 ,
class UniformDistType
63 ,
class MomentumHandlerType
68 UniformDistType& unif_sampler,
70 MomentumHandlerType& momentum_handler,
// Energy-error threshold: the base case flags its step as acceptable only
// while new_ham - input.ham <= delta_max (see the returned flag below).
73 constexpr
double delta_max = 1000;
// Base case: advance position/momentum by one leapfrog step. The step size is
// scaled by the direction input.v, so the same code extends either end.
77 double new_potential =
leapfrog(input.ad_expr_ref.get(),
78 input.theta_ref.get(),
79 input.theta_adj_ref.get(),
80 input.tp_adj_ref.get(),
81 input.p_most_ref.get(),
83 input.v * input.epsilon,
// Energy of the new state = potential from leapfrog + kinetic of new momentum.
86 double new_kinetic = momentum_handler.kinetic(input.p_most_ref.get());
87 double new_ham =
hamiltonian(new_potential, new_kinetic);
// Count every leapfrog step; the caller averages acceptance probabilities
// over this count.
90 ++(input.n_leapfrog_ref.get());
// A NaN energy is treated as +inf, so exp(ham - new_ham) below becomes 0
// instead of propagating NaN into the running sums.
93 if (std::isnan(new_ham)) { new_ham = math::inf<double>; }
// Fold this state's log weight into the running log-sum-exp total (the
// second lse argument is not visible in this excerpt).
94 input.log_sum_weight_ref.get() =
math::lse(
95 input.log_sum_weight_ref.get(),
// Accumulate min(1, exp(input.ham - new_ham)): the Metropolis acceptance
// probability of this state relative to the reference energy input.ham.
99 input.sum_metro_prob_ref.get() += (input.ham - new_ham > 0) ?
100 1 : std::exp(input.ham - new_ham);
// A single-step subtree proposes the state it just reached.
103 input.theta_prime_ref.get() = input.theta_ref.get();
// With one step, the subtree's beginning and end momenta (raw and
// dkinetic_dr-scaled) are the same vector.
106 input.p_beg_ref.get() = input.p_most_ref.get();
107 input.p_beg_scaled_ref.get() =
108 momentum_handler.dkinetic_dr(input.p_most_ref.get());
111 input.p_end_ref.get() = input.p_beg_ref.get();
112 input.p_end_scaled_ref.get() = input.p_beg_scaled_ref.get();
// rho accumulates (+=) the momenta of all states in the subtree, so the
// buffer behind rho_ref must be zeroed before the subtree is built.
115 input.rho_ref.get() += input.p_most_ref.get();
// Presumably the `valid` field of the returned result: false once the energy
// error exceeds delta_max (a divergent step).
119 (new_ham - input.ham <= delta_max),
// Recursive case, first child: scratch for its end momentum, scaled end
// momentum and rho, taken from the front of tree_cache.
// NOTE(review): rho_first is accumulated with += by the base case, but a
// rho_first.setZero() is not visible in this excerpt (rho_second is zeroed
// explicitly below) -- confirm the missing line zeroes it.
125 Eigen::Map<Eigen::VectorXd> p_end_inner(tree_cache, n_params);
126 Eigen::Map<Eigen::VectorXd> p_end_scaled_inner(tree_cache + n_params, n_params);
127 Eigen::Map<Eigen::VectorXd> rho_first(tree_cache + 2*n_params, n_params);
129 double log_sum_weight_first = math::neg_inf<double>;
// Advance past the first child's scratch so deeper recursion levels do not
// overwrite it.
131 tree_cache += 3 * n_params;
// The first child shares every other reference with the parent (position,
// proposal, beginning momenta, counters). Only its end momenta, rho and log
// weight are redirected into local scratch.
135 InputType first_input = input;
136 first_input.p_end_ref = p_end_inner;
137 first_input.p_end_scaled_ref = p_end_scaled_inner;
138 first_input.rho_ref = rho_first;
139 first_input.log_sum_weight_ref = log_sum_weight_first;
144 unif_sampler, gen, momentum_handler,
// An invalid first child aborts the whole subtree.
150 if (!first_output.
valid) {
return first_output; }
// Second child: scratch for its own proposal, beginning momenta (raw and
// scaled) and rho. Its end momenta are written straight into the parent's
// end slots, because p_end_ref is not redirected.
153 Eigen::Map<Eigen::VectorXd> theta_double_prime(tree_cache, n_params);
154 Eigen::Map<Eigen::VectorXd> p_beg_inner(tree_cache + n_params, n_params);
155 Eigen::Map<Eigen::VectorXd> p_beg_scaled_inner(tree_cache + 2*n_params, n_params);
156 Eigen::Map<Eigen::VectorXd> rho_second(tree_cache + 3*n_params, n_params);
157 rho_second.setZero();
158 double log_sum_weight_second = math::neg_inf<double>;
160 tree_cache += 4 * n_params;
163 InputType second_input = input;
164 second_input.theta_prime_ref = theta_double_prime;
165 second_input.p_beg_ref = p_beg_inner;
166 second_input.p_beg_scaled_ref = p_beg_scaled_inner;
167 second_input.rho_ref = rho_second;
168 second_input.log_sum_weight_ref = log_sum_weight_second;
173 unif_sampler, gen, momentum_handler,
// An invalid second child invalidates the combined result.
179 if (!second_output.
valid) {
180 first_output.
valid =
false;
// Combine both children's log weights (into log_sum_weight_curr, whose
// declaration is not visible) and add them to the parent's running total.
189 log_sum_weight_first, log_sum_weight_second
191 input.log_sum_weight_ref.get() =
math::lse(
192 input.log_sum_weight_ref.get(), log_sum_weight_curr
// Probability of replacing the proposal with the second child's proposal,
// proportional to its share of the combined weight. Presumably compared
// against a uniform draw in the lines not visible here.
196 double accept_prob = std::exp(log_sum_weight_second - log_sum_weight_curr);
199 input.theta_prime_ref.get() =
200 second_input.theta_prime_ref.get();
// Subtree rho = sum of both children's rho, added to the parent's
// accumulator. `auto` holds a lazy Eigen expression; it is evaluated on the
// next line while both maps are still alive.
208 auto rho_curr = rho_first + rho_second;
209 input.rho_ref.get() += rho_curr;
// Validity additionally requires three positive-dot-product checks combined
// with && (over the merged subtree and across the boundary between the two
// children). The called function and the rho arguments are not visible here.
212 input.p_beg_scaled_ref.get(),
213 input.p_end_scaled_ref.get()) &&
215 input.p_beg_scaled_ref.get(),
216 p_beg_scaled_inner) &&
219 input.p_end_scaled_ref.get());
// Step-size search heuristic; returns a step size eps.
// It takes one leapfrog step of size eps from a freshly sampled momentum and
// compares the Hamiltonian change against log(0.8). eps is then repeatedly
// doubled (a == 1) or halved (a == -1) until that comparison flips.
// theta_adj is restored to its saved value before every retry and at the end.
// NOTE(review): parts of the signature, the loop header, the restoration of
// theta and the final return are not visible in this excerpt.
234 template <
class ADExprType
237 ,
class MomentumHandlerType>
244 MomentumHandlerType& momentum_handler)
// Out-of-range starting values are returned unchanged without searching.
// NOTE(review): a NaN eps passes this guard, because both comparisons are
// false for NaN.
247 if (eps <= 0 || eps > 1e7)
return eps;
// ham_orig - ham_curr > log(0.8)  <=>  exp(ham_orig - ham_curr) > 0.8, i.e.
// the one-step Metropolis acceptance ratio exceeds 0.8.
249 const double diff_bound = std::log(0.8);
251 size_t n_params = theta.rows();
// One n_params x 3 buffer, split by column into the sampled momentum r and
// saved copies of theta and theta_adj.
253 Eigen::MatrixXd
mat(n_params, 3);
254 Eigen::Map<Eigen::VectorXd> r(
mat.col(0).data(), n_params);
255 Eigen::Map<Eigen::VectorXd> theta_orig(
mat.col(1).data(), n_params);
256 Eigen::Map<Eigen::VectorXd> theta_adj_orig(
mat.col(2).data(), n_params);
259 momentum_handler.sample(r, gen);
// Potential energy is the negated value of the AD expression (presumably a
// log-density -- confirm). autodiff presumably also refreshes theta_adj,
// which is saved below.
262 const double potential_orig = -ad::autodiff(ad_expr);
263 double kinetic_orig = momentum_handler.kinetic(r);
264 double ham_orig =
hamiltonian(potential_orig, kinetic_orig);
268 theta_adj_orig = theta_adj;
// Trial leapfrog step of size eps (the call head is not visible).
272 ad_expr, theta, theta_adj, tp_adj,
273 r, momentum_handler, eps,
true);
274 double kinetic_curr = momentum_handler.kinetic(r);
275 double ham_curr =
hamiltonian(potential_curr, kinetic_curr);
// Search direction: +1 grows eps (acceptance ratio above 0.8), -1 shrinks it.
277 int a = (ham_orig - ham_curr > diff_bound) ? 1 : -1;
// Stop once the acceptance ratio has crossed the 0.8 threshold in the
// search direction.
282 if ( ((a == 1) && !(ham_orig - ham_curr > diff_bound)) ||
283 ((a == -1) && !(ham_orig - ham_curr < diff_bound)) ) {
288 eps *= (a == -1) ? 0.5 : 2;
// Undo the trial step's adjoint changes before retrying with a fresh
// momentum. potential_orig is reused, which assumes theta is restored from
// theta_orig as well (in lines not visible here).
292 theta_adj = theta_adj_orig;
295 momentum_handler.sample(r, gen);
296 kinetic_orig = momentum_handler.kinetic(r);
297 ham_orig =
hamiltonian(potential_orig, kinetic_orig);
301 ad_expr, theta, theta_adj, tp_adj,
302 r, momentum_handler, eps,
true);
303 kinetic_curr = momentum_handler.kinetic(r);
304 ham_curr =
hamiltonian(potential_curr, kinetic_curr);
// Leave theta_adj as it was on entry.
310 theta_adj = theta_adj_orig;
// Main NUTS driver, as far as visible in this excerpt. It runs
// config.warmup + config.samples iterations. Each iteration:
//  * samples a momentum;
//  * doubles the trajectory in a random direction up to config.max_depth times;
//  * may move to the new subtree's proposal after each doubling;
//  * during warmup, adapts the step size and, for diag/dense variance
//    policies, the mass matrix.
// Post-warmup states are stored in res.cont_samples, and warmup/sampling wall
// times are stored in res.
// NOTE(review): many lines of this function are not visible in this excerpt.
331 template <
class ProgramType
332 ,
class OffsetPackType
333 ,
class MCMCResultType
336 const NUTSConfigType& config,
337 const OffsetPackType& pack,
// Only the first offset pack is used; the second must have all offsets zero.
340 assert(std::get<1>(pack).uc_offset == 0);
341 assert(std::get<1>(pack).tp_offset == 0);
342 assert(std::get<1>(pack).c_offset == 0);
343 assert(std::get<1>(pack).v_offset == 0);
345 auto& offset_pack = std::get<0>(pack);
// Dimension of the sampled space (presumably the number of unconstrained
// parameters -- confirm).
346 size_t n_params = offset_pack.uc_offset;
// Seeded RNG, a fair 0/1 draw for the direction, and uniform draws in [0, 1).
349 std::mt19937 gen(config.seed);
350 std::uniform_int_distribution direction_sampler(0, 1);
351 std::uniform_real_distribution unif_sampler(0., 1.);
// Value/adjoint storage for transformed parameters, plus constrained-value
// and visit buffers, all sized from the offset pack.
356 Eigen::MatrixXd tp_mat(offset_pack.tp_offset, 2);
357 Eigen::Map<Eigen::VectorXd> tp_val(tp_mat.col(0).data(), offset_pack.tp_offset);
358 Eigen::Map<Eigen::VectorXd> tp_adj(tp_mat.col(1).data(), offset_pack.tp_offset);
359 Eigen::VectorXd constrained(offset_pack.c_offset);
360 Eigen::Matrix<size_t, Eigen::Dynamic, 1> visit(offset_pack.v_offset);
362 constrained.setZero();
// A single n_params x 18 matrix backs every per-iteration vector. Naming:
//  * p_* are momenta, and *_scaled are their dkinetic_dr images;
//  * theta_* are positions, with *_adj their adjoints;
//  * rho_* are summed momenta.
// The two-letter suffixes (bb, bf, fb, ff) name trajectory end points --
// presumably {backward, forward} subtree x {backward, forward} end; confirm.
371 Eigen::MatrixXd cache_mat(n_params, 18);
373 Eigen::Map<Eigen::VectorXd> p_bb(cache_mat.col(0).data(), n_params);
374 Eigen::Map<Eigen::VectorXd> p_bb_scaled(cache_mat.col(1).data(), n_params);
375 Eigen::Map<Eigen::VectorXd> p_bf(cache_mat.col(2).data(), n_params);
376 Eigen::Map<Eigen::VectorXd> p_bf_scaled(cache_mat.col(3).data(), n_params);
377 Eigen::Map<Eigen::VectorXd> p_fb(cache_mat.col(4).data(), n_params);
378 Eigen::Map<Eigen::VectorXd> p_fb_scaled(cache_mat.col(5).data(), n_params);
379 Eigen::Map<Eigen::VectorXd> p_ff(cache_mat.col(6).data(), n_params);
380 Eigen::Map<Eigen::VectorXd> p_ff_scaled(cache_mat.col(7).data(), n_params);
383 Eigen::Map<Eigen::VectorXd> theta_bb(cache_mat.col(8).data(), n_params);
384 Eigen::Map<Eigen::VectorXd> theta_bb_adj(cache_mat.col(9).data(), n_params);
385 Eigen::Map<Eigen::VectorXd> theta_ff(cache_mat.col(10).data(), n_params);
386 Eigen::Map<Eigen::VectorXd> theta_ff_adj(cache_mat.col(11).data(), n_params);
387 Eigen::Map<Eigen::VectorXd> theta_curr(cache_mat.col(12).data(), n_params);
388 Eigen::Map<Eigen::VectorXd> theta_curr_adj(cache_mat.col(13).data(), n_params);
389 Eigen::Map<Eigen::VectorXd> theta_prime(cache_mat.col(14).data(), n_params);
395 Eigen::Map<Eigen::VectorXd> rho_f(cache_mat.col(15).data(), n_params);
396 Eigen::Map<Eigen::VectorXd> rho_b(cache_mat.col(16).data(), n_params);
397 Eigen::Map<Eigen::VectorXd> rho(cache_mat.col(17).data(), n_params);
// Scratch for the recursive tree builder: 3*n_params + 4*n_params doubles
// per depth level.
400 Eigen::VectorXd tree_cache(n_params * 7 * config.max_depth);
401 tree_cache.setZero();
// Three AD expressions bound to the backward-end, forward-end and current
// positions/adjoints. The transformed-parameter, constrained and visit
// buffers are shared between them.
406 theta_bb.data(), theta_bb_adj.data(),
407 tp_val.data(), tp_adj.data(),
408 constrained.data(), visit.data() ));
410 theta_ff.data(), theta_ff_adj.data(),
411 tp_val.data(), tp_adj.data(),
412 constrained.data(), visit.data() ));
414 theta_curr.data(), theta_curr_adj.data(),
415 tp_val.data(), tp_adj.data(),
416 constrained.data(), visit.data() ));
// All three expressions share one value/adjoint cache, sized from the first.
// This relies on them being evaluated one at a time and (presumably) having
// identical cache requirements.
419 auto size_pack = theta_bb_ad_expr.bind_cache_size();
420 Eigen::VectorXd ad_val_buf(size_pack(0));
421 Eigen::VectorXd ad_adj_buf(size_pack(1));
422 theta_bb_ad_expr.bind_cache({ad_val_buf.data(), ad_adj_buf.data()});
423 theta_ff_ad_expr.bind_cache({ad_val_buf.data(), ad_adj_buf.data()});
424 theta_curr_ad_expr.bind_cache({ad_val_buf.data(), ad_adj_buf.data()});
// Another binding of the program to theta_curr, this time without adjoint
// storage (nullptr), followed by parameter initialisation.
429 theta_curr.data(),
nullptr,
430 tp_val.data(),
nullptr,
431 constrained.data(), visit.data()));
432 program.init_params(gen, config.prune);
// Potential energy at the initial point (negated expression value).
435 double potential_prev = -ad::evaluate(theta_curr_ad_expr);
438 using var_adapter_policy_t =
typename
// Initial log step size from the step-size search started at theta_curr.
443 const double log_eps = std::log(
446 theta_curr_ad_expr, theta_curr,
447 theta_curr_adj, tp_adj,
448 gen, momentum_handler));
// Variance-adapter window configuration taken from config.var_config
// (the constructor call head is not visible).
454 n_params, config.warmup, config.var_config.init_buffer,
455 config.var_config.term_buffer, config.var_config.window_base
464 stopwatch_warmup.
start();
466 for (
size_t i = 0; i < config.samples + config.warmup; ++i) {
// Switch timers at the warmup/sampling boundary.
469 if (i == config.warmup) {
470 stopwatch_warmup.
stop();
471 stopwatch_sampling.
start();
474 logger.printProgress(i);
// Trajectory ends start from the current state (positions and adjoints).
477 theta_bb = theta_curr;
480 theta_ff_adj = theta_bb_adj;
// The starting state has weight 1, i.e. log weight 0.
484 double log_sum_weight = 0.;
487 size_t n_leapfrog = 0;
488 double sum_metro_prob = 0.;
// Fresh momentum; all four scaled end momenta start out equal to its
// dkinetic_dr image.
491 momentum_handler.sample(p_bb, gen);
497 p_bb_scaled = momentum_handler.dkinetic_dr(p_bb);
498 p_bf_scaled = p_bb_scaled;
499 p_fb_scaled = p_bb_scaled;
500 p_ff_scaled = p_bb_scaled;
505 const double kinetic = momentum_handler.kinetic(p_bb);
// Trajectory doubling, at most config.max_depth times.
512 for (
size_t depth = 0; depth < config.max_depth; ++depth) {
518 double log_sum_weight_subtree = math::neg_inf<double>;
// Direction in {-1, +1}. Presumably -1 extends the backward end (theta_bb*)
// and +1 the forward end (theta_ff*); the branching lines are not visible.
520 int8_t v = 2 * direction_sampler(gen) - 1;
524 theta_bb_ad_expr, theta_bb, theta_bb_adj, tp_adj,
527 p_bf, p_bb, p_bf_scaled, p_bb_scaled, rho_b,
529 n_leapfrog, log_sum_weight_subtree, sum_metro_prob,
531 v, std::exp(step_adapter.
log_eps), ham_prev
535 p_fb_scaled = p_bb_scaled;
538 unif_sampler, gen, momentum_handler,
543 theta_ff_ad_expr, theta_ff, theta_ff_adj, tp_adj,
546 p_fb, p_ff, p_fb_scaled, p_ff_scaled, rho_f,
548 n_leapfrog, log_sum_weight_subtree, sum_metro_prob,
550 v, std::exp(step_adapter.
log_eps), ham_prev
554 p_bf_scaled = p_ff_scaled;
557 unif_sampler, gen, momentum_handler,
// An invalid new subtree ends the doubling. The break precedes the
// acceptance step, so its proposal is never used.
562 if (!output.
valid)
break;
// Move to the new subtree's proposal:
//  * always, if the subtree's weight exceeds the existing trajectory's;
//  * otherwise with probability exp(log_sum_weight_subtree - log_sum_weight),
//    presumably tested against unif_sampler in lines not visible here.
566 if (log_sum_weight_subtree > log_sum_weight) {
567 theta_curr = theta_prime;
570 double p = std::exp(log_sum_weight_subtree - log_sum_weight);
572 theta_curr = theta_prime;
// The trajectory's total weight now includes the new subtree.
578 log_sum_weight =
math::lse(log_sum_weight, log_sum_weight_subtree);
// Warmup: adapt the step size using this iteration's mean per-leapfrog-step
// acceptance probability.
// NOTE(review): if config.max_depth == 0, no leapfrog step is taken,
// n_leapfrog stays 0 and this evaluates 0.0 / 0.0 (NaN).
600 if (i < config.warmup) {
603 step_adapter.
adapt(sum_metro_prob /
static_cast<double>(n_leapfrog));
// For diagonal/dense variance policies, also adapt the inverse mass matrix.
// When the adapter reports an update, the step-size search is re-run from the
// current state and step-size adaptation is restarted.
606 if constexpr (std::is_same_v<var_adapter_policy_t, diag_var> ||
607 std::is_same_v<var_adapter_policy_t, dense_var>) {
608 const bool update = var_adapter.adapt(theta_curr, momentum_handler.get_m_inverse());
611 std::exp(step_adapter.
log_eps),
612 theta_curr_ad_expr, theta_curr,
613 theta_curr_adj, tp_adj,
614 gen, momentum_handler) );
615 step_adapter.
reset();
616 step_adapter.
init(log_eps);
// Last warmup iteration (body not visible in this excerpt).
621 if (i == config.warmup - 1) {
// Post-warmup iterations record the current state as one sample row.
627 if (i >= config.warmup) {
628 res.cont_samples.row(i-config.warmup) = theta_curr;
// Stop timing and report warmup/sampling durations.
634 stopwatch_sampling.
stop();
637 res.warmup_time = stopwatch_warmup.
elapsed();
638 res.sampling_time = stopwatch_sampling.
elapsed();
643 template <
class ExprType
644 ,
class NUTSConfigType = NUTSConfig<>>
645 inline auto nuts(
const ExprType& expr,
646 const NUTSConfigType& config = NUTSConfigType())
649 [](
auto& program,
const auto& config,
650 const auto& pack,
auto& res) {