autoppl  v0.8
A C++ template library for probabilistic programming
nuts.hpp
Go to the documentation of this file.
1 #pragma once
#include <cmath>
#include <limits>
2 #include <type_traits>
3 #include <Eigen/Dense>
4 #include <fastad_bits/reverse/core/var_view.hpp>
5 #include <fastad_bits/reverse/core/eval.hpp>
10 #include <autoppl/math/math.hpp>
12 #include <autoppl/mcmc/result.hpp>
18 
19 namespace ppl {
20 namespace mcmc {
21 
/**
 * Generalized no-U-turn ("entropy") criterion.
 *
 * The trajectory segment is still considered to be exploring as long as the
 * integrated momentum rho points forward relative to the (scaled) momenta at
 * both of its ends.
 *
 * @param rho           integrated momentum over the segment
 * @param p_beg_scaled  scaled momentum at the beginning of the segment
 * @param p_end_scaled  scaled momentum at the end of the segment
 * @return true iff rho . p_beg_scaled > 0 and rho . p_end_scaled > 0
 *         (the second dot product is only computed if the first passes)
 */
template <class MatType1, class MatType2, class MatType3>
bool check_entropy(const MatType1& rho,
                   const MatType2& p_beg_scaled,
                   const MatType3& p_end_scaled)
{
    const bool forward_at_beg = rho.dot(p_beg_scaled) > 0;
    if (!forward_at_beg) {
        return false;
    }
    return rho.dot(p_end_scaled) > 0;
}
40 
/**
 * Recursively builds a NUTS subtree of (up to) 2^depth leapfrog steps in
 * direction input.v, starting from the state referenced by input.
 *
 * - depth == 0: takes one leapfrog step of size input.v * input.epsilon and
 *   updates the shared statistics (n_leapfrog, log_sum_weight,
 *   sum_metro_prob), the subtree's beginning/end momenta (and their
 *   dkinetic_dr-scaled versions), the integrated momentum rho, and copies
 *   the new position into theta_prime.
 * - depth > 0: builds two subtrees of depth - 1 back to back, then calls
 *   accept_or_reject with exp(lsw_second - lse(lsw_first, lsw_second)) to
 *   decide whether the second subtree's proposal replaces the first's, and
 *   applies the generalized U-turn check (check_entropy) to the merged
 *   subtree and across the boundary between the two halves.
 *
 * @param n_params         length of every position/momentum vector handled here
 * @param input            references to the state and statistics to read/update;
 *                         copies of it with some references rebound are passed
 *                         to the recursive calls
 * @param depth            remaining tree depth
 * @param unif_sampler     uniform distribution used for the proposal choice
 * @param gen              random number generator
 * @param momentum_handler provides kinetic(p) and dkinetic_dr(p)
 * @param tree_cache       scratch buffer: an internal level uses 7 * n_params
 *                         doubles starting at this pointer and hands the space
 *                         after that to its recursive calls
 *                         (nuts_ allocates 7 * n_params * max_depth doubles)
 * @return TreeOutput(valid, potential). valid is false on divergence
 *         (new_ham - ham > delta_max, with a NaN hamiltonian counted as +inf)
 *         or when a U-turn check fails; potential is the value returned by
 *         leapfrog for the selected proposal. On an early (invalid) exit the
 *         caller's momentum bookkeeping is left only partially updated,
 *         because the caller stops doubling anyway.
 */
template <class InputType
        , class UniformDistType
        , class GenType
        , class MomentumHandlerType
        >
TreeOutput build_tree(size_t n_params,
                      InputType& input,
                      uint8_t depth,
                      UniformDistType& unif_sampler,
                      GenType& gen,
                      MomentumHandlerType& momentum_handler,
                      double* tree_cache)
{
    // maximum allowed increase of the hamiltonian before a leaf is "divergent"
    constexpr double delta_max = 1000; // suggested by Gelman

    // base case
    if (depth == 0) {
        // One signed leapfrog step: input.v is -1 or +1 (see nuts_).
        double new_potential = leapfrog(input.ad_expr_ref.get(),
                                        input.theta_ref.get(),
                                        input.theta_adj_ref.get(),
                                        input.tp_adj_ref.get(),
                                        input.p_most_ref.get(),
                                        momentum_handler,
                                        input.v * input.epsilon,
                                        true // always reuse previous adjoint
                                        );
        double new_kinetic = momentum_handler.kinetic(input.p_most_ref.get());
        double new_ham = hamiltonian(new_potential, new_kinetic);

        // update number of leapfrogs
        ++(input.n_leapfrog_ref.get());

        // update LSE of weights
        // The leaf's log-weight is ham - new_ham. A NaN energy is mapped to +inf,
        // giving the leaf zero weight and making it fail the validity test below.
        if (std::isnan(new_ham)) { new_ham = math::inf<double>; }
        input.log_sum_weight_ref.get() = math::lse(
                input.log_sum_weight_ref.get(),
                input.ham - new_ham);

        // update sum of probabilities
        // Adds min(1, exp(ham - new_ham)); nuts_ divides the total by n_leapfrog
        // to drive step-size adaptation.
        input.sum_metro_prob_ref.get() += (input.ham - new_ham > 0) ?
            1 : std::exp(input.ham - new_ham);

        // always copy into theta_prime
        input.theta_prime_ref.get() = input.theta_ref.get();

        // update momenta of beginning of subtree (moving in the direction of input.v)
        // NOTE(review): p_most, p_beg and p_end may refer to the same vector
        // (nuts_ appears to pass p_bb / p_ff for more than one of them), so the
        // order of these assignments must be kept.
        input.p_beg_ref.get() = input.p_most_ref.get();
        input.p_beg_scaled_ref.get() =
            momentum_handler.dkinetic_dr(input.p_most_ref.get());

        // update momenta of end of subtree (moving in the direction of input.v)
        // (a single leaf begins and ends at the same momentum)
        input.p_end_ref.get() = input.p_beg_ref.get();
        input.p_end_scaled_ref.get() = input.p_beg_scaled_ref.get();

        // update integrated momentum
        input.rho_ref.get() += input.p_most_ref.get();

        // return validity and new potential
        return TreeOutput(
                (new_ham - input.ham <= delta_max),
                new_potential
                );
    }

    // recursion
    // Scratch for the first subtree: its end momentum (scaled and unscaled)
    // and its own integrated momentum / log-sum-weight, kept separate so the
    // boundary U-turn checks below can use them.
    Eigen::Map<Eigen::VectorXd> p_end_inner(tree_cache, n_params);
    Eigen::Map<Eigen::VectorXd> p_end_scaled_inner(tree_cache + n_params, n_params);
    Eigen::Map<Eigen::VectorXd> rho_first(tree_cache + 2*n_params, n_params);
    rho_first.setZero();
    double log_sum_weight_first = math::neg_inf<double>;

    tree_cache += 3 * n_params; // update position of tree cache

    // create a new input for first recursion
    // some references have to rebound
    // (the first subtree starts where the parent starts, so p_beg stays bound
    //  to the parent's p_beg; only its end momentum and stats are redirected)
    InputType first_input = input;
    first_input.p_end_ref = p_end_inner;
    first_input.p_end_scaled_ref = p_end_scaled_inner;
    first_input.rho_ref = rho_first;
    first_input.log_sum_weight_ref = log_sum_weight_first;

    // build first subtree
    TreeOutput first_output =
        build_tree(n_params, first_input, depth - 1,
                   unif_sampler, gen, momentum_handler,
                   tree_cache);

    // if first subtree is already invalid, early exit
    // note that caller will break out of doubling process now,
    // so we do not have to worry about updating the other momentum vectors
    if (!first_output.valid) { return first_output; }

    // second recursion
    // The first subtree's scratch space is no longer needed past this point
    // (only p_end_inner/p_end_scaled_inner/rho_first, which lie before
    // tree_cache, are still read), so it is reused from tree_cache onward.
    Eigen::Map<Eigen::VectorXd> theta_double_prime(tree_cache, n_params);
    Eigen::Map<Eigen::VectorXd> p_beg_inner(tree_cache + n_params, n_params);
    Eigen::Map<Eigen::VectorXd> p_beg_scaled_inner(tree_cache + 2*n_params, n_params);
    Eigen::Map<Eigen::VectorXd> rho_second(tree_cache + 3*n_params, n_params);
    rho_second.setZero();
    double log_sum_weight_second = math::neg_inf<double>;

    tree_cache += 4 * n_params;

    // create a new input for second recursion
    // (the second subtree ends where the parent ends, so p_end stays bound to
    //  the parent's p_end; its proposal goes to theta_double_prime)
    InputType second_input = input;
    second_input.theta_prime_ref = theta_double_prime;
    second_input.p_beg_ref = p_beg_inner;
    second_input.p_beg_scaled_ref = p_beg_scaled_inner;
    second_input.rho_ref = rho_second;
    second_input.log_sum_weight_ref = log_sum_weight_second;

    // build second subtree
    TreeOutput second_output =
        build_tree(n_params, second_input, depth - 1,
                   unif_sampler, gen, momentum_handler,
                   tree_cache);

    // if second subtree is invalid, early exit
    // note that we must return first output since it has the potential
    // of the first proposal and we ignore the second proposal
    if (!second_output.valid) {
        first_output.valid = false;
        return first_output;
    }

    // create output to return at the end
    TreeOutput output;

    // sample proposal and update corresponding potential
    double log_sum_weight_curr = math::lse(
            log_sum_weight_first, log_sum_weight_second
            );
    input.log_sum_weight_ref.get() = math::lse(
            input.log_sum_weight_ref.get(), log_sum_weight_curr
            );

    // note: accept_prob is mathematically guaranteed to be <= 1
    double accept_prob = std::exp(log_sum_weight_second - log_sum_weight_curr);
    bool accept = accept_or_reject(accept_prob, unif_sampler, gen);
    if (accept) {
        input.theta_prime_ref.get() =
            second_input.theta_prime_ref.get();
        output.potential = second_output.potential;
    } else {
        output.potential = first_output.potential;
    }

    // check if current subtree is still valid based
    // on entropy condition
    // Three checks: across the merged subtree, from the subtree's beginning to
    // the start of the second half, and from the end of the first half to the
    // subtree's end.
    auto rho_curr = rho_first + rho_second;
    input.rho_ref.get() += rho_curr;
    output.valid =
        check_entropy(rho_curr,
                      input.p_beg_scaled_ref.get(),
                      input.p_end_scaled_ref.get()) &&
        check_entropy(rho_first + p_beg_inner,
                      input.p_beg_scaled_ref.get(),
                      p_beg_scaled_inner) &&
        check_entropy(p_end_inner + rho_second,
                      p_end_scaled_inner,
                      input.p_end_scaled_ref.get());

    return output;
}
223 
234 template <class ADExprType
235  , class MatType
236  , class GenType
237  , class MomentumHandlerType>
238 double find_reasonable_epsilon(double eps,
239  ADExprType& ad_expr,
240  MatType& theta,
241  MatType& theta_adj,
242  MatType& tp_adj,
243  GenType& gen,
244  MomentumHandlerType& momentum_handler)
245 {
246  // See (STAN) for reference: if epsilon is way out of bounds, just return eps
247  if (eps <= 0 || eps > 1e7) return eps;
248 
249  const double diff_bound = std::log(0.8);
250 
251  size_t n_params = theta.rows(); // theta is expected to be vector-like
252 
253  Eigen::MatrixXd mat(n_params, 3);
254  Eigen::Map<Eigen::VectorXd> r(mat.col(0).data(), n_params);
255  Eigen::Map<Eigen::VectorXd> theta_orig(mat.col(1).data(), n_params);
256  Eigen::Map<Eigen::VectorXd> theta_adj_orig(mat.col(2).data(), n_params);
257 
258  // sample momentum vector based on handler
259  momentum_handler.sample(r, gen);
260 
261  // differentiate first to get adjoints and hamiltonian
262  const double potential_orig = -ad::autodiff(ad_expr);
263  double kinetic_orig = momentum_handler.kinetic(r);
264  double ham_orig = hamiltonian(potential_orig, kinetic_orig);
265 
266  // save original value and adjoint
267  theta_orig = theta;
268  theta_adj_orig = theta_adj;
269 
270  // get current hamiltonian after leapfrog
271  double potential_curr = leapfrog(
272  ad_expr, theta, theta_adj, tp_adj,
273  r, momentum_handler, eps, true);
274  double kinetic_curr = momentum_handler.kinetic(r);
275  double ham_curr = hamiltonian(potential_curr, kinetic_curr);
276 
277  int a = (ham_orig - ham_curr > diff_bound) ? 1 : -1;
278 
279  while (1) {
280 
281  // check if break condition holds
282  if ( ((a == 1) && !(ham_orig - ham_curr > diff_bound)) ||
283  ((a == -1) && !(ham_orig - ham_curr < diff_bound)) ) {
284  break;
285  }
286 
287  // update epsilon
288  eps *= (a == -1) ? 0.5 : 2;
289 
290  // copy back original value and adjoint
291  theta = theta_orig;
292  theta_adj = theta_adj_orig;
293 
294  // recompute original hamiltonian with new momentum
295  momentum_handler.sample(r, gen);
296  kinetic_orig = momentum_handler.kinetic(r);
297  ham_orig = hamiltonian(potential_orig, kinetic_orig);
298 
299  // leapfrog and compute current hamiltonian
300  potential_curr = leapfrog(
301  ad_expr, theta, theta_adj, tp_adj,
302  r, momentum_handler, eps, true);
303  kinetic_curr = momentum_handler.kinetic(r);
304  ham_curr = hamiltonian(potential_curr, kinetic_curr);
305 
306  }
307 
308  // copy back original value and adjoint
309  theta = theta_orig;
310  theta_adj = theta_adj_orig;
311 
312  return eps;
313 }
314 
331 template <class ProgramType
332  , class OffsetPackType
333  , class MCMCResultType
334  , class NUTSConfigType = NUTSConfig<>>
335 void nuts_(ProgramType& program,
336  const NUTSConfigType& config,
337  const OffsetPackType& pack,
338  MCMCResultType& res)
339 {
340  assert(std::get<1>(pack).uc_offset == 0);
341  assert(std::get<1>(pack).tp_offset == 0);
342  assert(std::get<1>(pack).c_offset == 0);
343  assert(std::get<1>(pack).v_offset == 0);
344 
345  auto& offset_pack = std::get<0>(pack);
346  size_t n_params = offset_pack.uc_offset;
347 
348  // initialization of meta-variables
349  std::mt19937 gen(config.seed);
350  std::uniform_int_distribution direction_sampler(0, 1);
351  std::uniform_real_distribution unif_sampler(0., 1.);
352 
353  // Transformed parameters, constrained parameter, visit count cache
354  // This can be shared across all AD expressions since only one expression
355  // will be evaluated at a time.
356  Eigen::MatrixXd tp_mat(offset_pack.tp_offset, 2);
357  Eigen::Map<Eigen::VectorXd> tp_val(tp_mat.col(0).data(), offset_pack.tp_offset);
358  Eigen::Map<Eigen::VectorXd> tp_adj(tp_mat.col(1).data(), offset_pack.tp_offset);
359  Eigen::VectorXd constrained(offset_pack.c_offset);
360  Eigen::Matrix<size_t, Eigen::Dynamic, 1> visit(offset_pack.v_offset);
361  tp_mat.setZero();
362  constrained.setZero();
363  visit.setZero();
364 
365  // momentum matrix (for stability reasons we require knowing 4 momentum)
366  // left-subtree backwardmost momentum => bb
367  // left-subtree forwardmost momentum => bf
368  // right-subtree backwardmost momentum => fb
369  // right-subtree forwardmost momentum => ff
370  // scaled versions are based on hamiltonian adjusted covariance matrix
371  Eigen::MatrixXd cache_mat(n_params, 18);
372  cache_mat.setZero();
373  Eigen::Map<Eigen::VectorXd> p_bb(cache_mat.col(0).data(), n_params);
374  Eigen::Map<Eigen::VectorXd> p_bb_scaled(cache_mat.col(1).data(), n_params);
375  Eigen::Map<Eigen::VectorXd> p_bf(cache_mat.col(2).data(), n_params);
376  Eigen::Map<Eigen::VectorXd> p_bf_scaled(cache_mat.col(3).data(), n_params);
377  Eigen::Map<Eigen::VectorXd> p_fb(cache_mat.col(4).data(), n_params);
378  Eigen::Map<Eigen::VectorXd> p_fb_scaled(cache_mat.col(5).data(), n_params);
379  Eigen::Map<Eigen::VectorXd> p_ff(cache_mat.col(6).data(), n_params);
380  Eigen::Map<Eigen::VectorXd> p_ff_scaled(cache_mat.col(7).data(), n_params);
381 
382  // position matrix for thetas and adjoints
383  Eigen::Map<Eigen::VectorXd> theta_bb(cache_mat.col(8).data(), n_params);
384  Eigen::Map<Eigen::VectorXd> theta_bb_adj(cache_mat.col(9).data(), n_params);
385  Eigen::Map<Eigen::VectorXd> theta_ff(cache_mat.col(10).data(), n_params);
386  Eigen::Map<Eigen::VectorXd> theta_ff_adj(cache_mat.col(11).data(), n_params);
387  Eigen::Map<Eigen::VectorXd> theta_curr(cache_mat.col(12).data(), n_params);
388  Eigen::Map<Eigen::VectorXd> theta_curr_adj(cache_mat.col(13).data(), n_params);
389  Eigen::Map<Eigen::VectorXd> theta_prime(cache_mat.col(14).data(), n_params);
390 
391  // integrated momentum vectors (more stable than checking entropy with theta_ff - theta_bb)
392  // forward-subtree => rho_f
393  // backward-subtree => rho_b
394  // combined subtrees => rho
395  Eigen::Map<Eigen::VectorXd> rho_f(cache_mat.col(15).data(), n_params);
396  Eigen::Map<Eigen::VectorXd> rho_b(cache_mat.col(16).data(), n_params);
397  Eigen::Map<Eigen::VectorXd> rho(cache_mat.col(17).data(), n_params);
398 
399  // build-tree helper function cache line
400  Eigen::VectorXd tree_cache(n_params * 7 * config.max_depth);
401  tree_cache.setZero();
402 
403  // AD Expressions for L(theta) (log-pdf up to constant at theta)
404  // Note that these expressions are the only ones used ever.
405  auto theta_bb_ad_expr = program.ad_log_pdf(util::make_ptr_pack(
406  theta_bb.data(), theta_bb_adj.data(),
407  tp_val.data(), tp_adj.data(),
408  constrained.data(), visit.data() ));
409  auto theta_ff_ad_expr = program.ad_log_pdf(util::make_ptr_pack(
410  theta_ff.data(), theta_ff_adj.data(),
411  tp_val.data(), tp_adj.data(),
412  constrained.data(), visit.data() ));
413  auto theta_curr_ad_expr = program.ad_log_pdf(util::make_ptr_pack(
414  theta_curr.data(), theta_curr_adj.data(),
415  tp_val.data(), tp_adj.data(),
416  constrained.data(), visit.data() ));
417 
418  // bind every AD expression to the same cache line
419  auto size_pack = theta_bb_ad_expr.bind_cache_size();
420  Eigen::VectorXd ad_val_buf(size_pack(0));
421  Eigen::VectorXd ad_adj_buf(size_pack(1));
422  theta_bb_ad_expr.bind_cache({ad_val_buf.data(), ad_adj_buf.data()});
423  theta_ff_ad_expr.bind_cache({ad_val_buf.data(), ad_adj_buf.data()});
424  theta_curr_ad_expr.bind_cache({ad_val_buf.data(), ad_adj_buf.data()});
425 
426  // initializes first sample into theta_curr
427  // TODO: allow users to choose how to initialize first point?
428  program.bind(util::make_ptr_pack(
429  theta_curr.data(), nullptr,
430  tp_val.data(), nullptr,
431  constrained.data(), visit.data()));
432  program.init_params(gen, config.prune);
433 
434  // initialize current potential (will be "previous" starting in for-loop)
435  double potential_prev = -ad::evaluate(theta_curr_ad_expr);
436 
437  // initialize momentum handler
438  using var_adapter_policy_t = typename
440  mcmc::MomentumHandler<var_adapter_policy_t> momentum_handler(n_params);
441 
442  // initialize step adapter
443  const double log_eps = std::log(
445  1., // initial epsilon
446  theta_curr_ad_expr, theta_curr,
447  theta_curr_adj, tp_adj,
448  gen, momentum_handler));
449  mcmc::StepAdapter step_adapter(log_eps); // initialize step adapter with initial log-epsilon
450  step_adapter.step_config = config.step_config; // copy step configs from user
451 
452  // initialize variance adapter
454  n_params, config.warmup, config.var_config.init_buffer,
455  config.var_config.term_buffer, config.var_config.window_base
456  );
457 
458  // construct miscellaneous objects
459  auto logger = util::ProgressLogger(config.samples + config.warmup, "NUTS");
460  util::StopWatch<> stopwatch_warmup;
461  util::StopWatch<> stopwatch_sampling;
462 
463  // start timing warmup
464  stopwatch_warmup.start();
465 
466  for (size_t i = 0; i < config.samples + config.warmup; ++i) {
467 
468  // if warmup is finished, stop timing warmup and start timing sampling
469  if (i == config.warmup) {
470  stopwatch_warmup.stop();
471  stopwatch_sampling.start();
472  }
473 
474  logger.printProgress(i);
475 
476  // re-initialize vectors to current theta as the "root" of tree
477  theta_bb = theta_curr;
478  theta_ff = theta_bb;
479  mcmc::reset_autodiff(theta_bb_ad_expr, theta_bb_adj, tp_adj);
480  theta_ff_adj = theta_bb_adj; // no need to differentiate again
481 
482  // initialize values for multinomial sampling
483  // this is the total log sum weight over full tree
484  double log_sum_weight = 0.;
485 
486  // initialize values used to adapt stepsize
487  size_t n_leapfrog = 0;
488  double sum_metro_prob = 0.;
489 
490  // p ~ N(0, M) (depending on momentum handler)
491  momentum_handler.sample(p_bb, gen);
492  p_bf = p_bb;
493  p_fb = p_bb;
494  p_ff = p_bb;
495 
496  // scaled p by hamiltonian dkinetic_dr
497  p_bb_scaled = momentum_handler.dkinetic_dr(p_bb);
498  p_bf_scaled = p_bb_scaled;
499  p_fb_scaled = p_bb_scaled;
500  p_ff_scaled = p_bb_scaled;
501 
502  // re-initialize integrated momentum vectors
503  rho = p_bb;
504 
505  const double kinetic = momentum_handler.kinetic(p_bb);
506  const double ham_prev = mcmc::hamiltonian(potential_prev, kinetic);
507 
508  // Note that this object can be reused since all members
509  // are guaranteed to overwritten by build_tree.
510  mcmc::TreeOutput output;
511 
512  for (size_t depth = 0; depth < config.max_depth; ++depth) {
513 
514  // zero-out subtree integrated momentum vectors
515  rho_b.setZero();
516  rho_f.setZero();
517 
518  double log_sum_weight_subtree = math::neg_inf<double>;
519 
520  int8_t v = 2 * direction_sampler(gen) - 1; // -1 or 1
521  if (v == -1) {
522  auto input = mcmc::TreeInput(
523  // position information to update
524  theta_bb_ad_expr, theta_bb, theta_bb_adj, tp_adj,
525  theta_prime, p_bb,
526  // momentum vectors to update
527  p_bf, p_bb, p_bf_scaled, p_bb_scaled, rho_b,
528  // stats to update to adapt step size at the end
529  n_leapfrog, log_sum_weight_subtree, sum_metro_prob,
530  // other miscellaneous variables
531  v, std::exp(step_adapter.log_eps), ham_prev
532  );
533  rho_f = rho;
534  p_fb = p_bb;
535  p_fb_scaled = p_bb_scaled;
536 
537  output = mcmc::build_tree(n_params, input, depth,
538  unif_sampler, gen, momentum_handler,
539  tree_cache.data());
540  } else {
541  auto input = mcmc::TreeInput(
542  // correct position information to update
543  theta_ff_ad_expr, theta_ff, theta_ff_adj, tp_adj,
544  theta_prime, p_ff,
545  // correct momentum vectors to update
546  p_fb, p_ff, p_fb_scaled, p_ff_scaled, rho_f,
547  // stats to update to adapt step size at the end
548  n_leapfrog, log_sum_weight_subtree, sum_metro_prob,
549  // other miscellaneous variables
550  v, std::exp(step_adapter.log_eps), ham_prev
551  );
552  rho_b = rho;
553  p_bf = p_ff;
554  p_bf_scaled = p_ff_scaled;
555 
556  output = mcmc::build_tree(n_params, input, depth,
557  unif_sampler, gen, momentum_handler,
558  tree_cache.data());
559  }
560 
561  // early break if starting to U-Turn
562  if (!output.valid) break;
563 
564  // if new subtree's weight is greater than previous subtree's weight
565  // always accept!
566  if (log_sum_weight_subtree > log_sum_weight) {
567  theta_curr = theta_prime;
568  potential_prev = output.potential;
569  } else {
570  double p = std::exp(log_sum_weight_subtree - log_sum_weight);
571  if (mcmc::accept_or_reject(p, unif_sampler, gen)) {
572  theta_curr = theta_prime;
573  potential_prev = output.potential;
574  }
575  }
576 
577  // update total log_sum_weight
578  log_sum_weight = math::lse(log_sum_weight, log_sum_weight_subtree);
579 
580  // check if proposals are still
581  // - entroping in the full tree
582  // - entroping from backwards-subtree to forwards-subtree
583  // - entroping from forwards-subtree to backwards-subtree
584  // This is a much stronger than the original paper's entropy condition.
585  // This most likely reduces the depth to avoid unnecessary computation.
586 
587  rho = rho_b + rho_f;
588 
589  bool valid =
590  mcmc::check_entropy(rho, p_bb_scaled, p_ff_scaled) &&
591  mcmc::check_entropy(rho_b + p_fb, p_bb_scaled, p_fb_scaled) &&
592  mcmc::check_entropy(p_bf + rho_f, p_bf_scaled, p_ff_scaled)
593  ;
594 
595  if (!valid) break;
596 
597  } // end tree doubling for-loop
598 
599  // Warmup Adapt!
600  if (i < config.warmup) {
601 
602  // epsilon dual averaging
603  step_adapter.adapt(sum_metro_prob / static_cast<double>(n_leapfrog));
604 
605  // adapt variance only if adapting policy is diag_var or dense_var
606  if constexpr (std::is_same_v<var_adapter_policy_t, diag_var> ||
607  std::is_same_v<var_adapter_policy_t, dense_var>) {
608  const bool update = var_adapter.adapt(theta_curr, momentum_handler.get_m_inverse());
609  if (update) {
610  double log_eps = std::log( mcmc::find_reasonable_epsilon(
611  std::exp(step_adapter.log_eps),
612  theta_curr_ad_expr, theta_curr,
613  theta_curr_adj, tp_adj,
614  gen, momentum_handler) );
615  step_adapter.reset();
616  step_adapter.init(log_eps);
617  }
618  }
619 
620  // if last warmup iteration
621  if (i == config.warmup - 1) {
622  step_adapter.log_eps = step_adapter.log_eps_bar;
623  }
624  }
625 
626  // store sample theta_curr only after burning
627  if (i >= config.warmup) {
628  res.cont_samples.row(i-config.warmup) = theta_curr;
629  }
630 
631  } // end for-loop to sample 1 point
632 
633  // stop timing sampling
634  stopwatch_sampling.stop();
635 
636  // save output results
637  res.warmup_time = stopwatch_warmup.elapsed();
638  res.sampling_time = stopwatch_sampling.elapsed();
639 }
640 
641 } // namespace mcmc
642 
643 template <class ExprType
644  , class NUTSConfigType = NUTSConfig<>>
645 inline auto nuts(const ExprType& expr,
646  const NUTSConfigType& config = NUTSConfigType())
647 {
648  return mcmc::base_mcmc(expr, config,
649  [](auto& program, const auto& config,
650  const auto& pack, auto& res) {
651  res.name = "nuts";
652  mcmc::nuts_(program, config, pack, res);
653  });
654 }
655 
656 } // namespace ppl
var_traits.hpp
ppl::util::StopWatch::stop
void stop()
Definition: stopwatch.hpp:11
ppl::mcmc::StepAdapter::init
void init(double _log_eps)
Definition: step_adapter.hpp:38
ppl::mcmc::base_mcmc
MCMCResult base_mcmc(const ExprType &expr, const ConfigType &config, Sampler f)
Definition: base_mcmc.hpp:13
ppl::util::StopWatch::elapsed
double elapsed() const
Definition: stopwatch.hpp:13
hamiltonian.hpp
ppl::mcmc::StepAdapter::step_config
StepConfig step_config
Definition: step_adapter.hpp:76
ppl::mcmc::build_tree
TreeOutput build_tree(size_t n_params, InputType &input, uint8_t depth, UniformDistType &unif_sampler, GenType &gen, MomentumHandlerType &momentum_handler, double *tree_cache)
Definition: nuts.hpp:65
ppl::mcmc::reset_autodiff
double reset_autodiff(ADExprType &ad_expr, Eigen::MatrixBase< MatType > &adjoints, Eigen::MatrixBase< MatType > &tp_adjoints)
Definition: leapfrog.hpp:19
ppl::mcmc::TreeOutput
Definition: tree_utils.hpp:83
ppl::mcmc::TreeOutput::potential
double potential
Definition: tree_utils.hpp:89
ppl::util::StopWatch::start
void start()
Definition: stopwatch.hpp:10
ppl::mcmc::TreeOutput::valid
bool valid
Definition: tree_utils.hpp:88
ppl::mcmc::VarAdapter
Definition: var_adapter.hpp:31
ptr_pack.hpp
ppl::mcmc::hamiltonian
double hamiltonian(double potential, double kinetic)
Definition: hamiltonian.hpp:10
tree_utils.hpp
ppl::mcmc::TreeInput
Definition: tree_utils.hpp:16
ppl::mcmc::check_entropy
bool check_entropy(const MatType1 &rho, const MatType2 &p_beg_scaled, const MatType3 &p_end_scaled)
Definition: nuts.hpp:33
ppl::mcmc::StepAdapter
Definition: step_adapter.hpp:25
logging.hpp
ppl::nuts_config_traits::var_adapter_policy_t
typename NUTSConfigType::var_adapter_policy_t var_adapter_policy_t
Definition: configs.hpp:35
ppl::mcmc::MomentumHandler
Definition: momentum_handler.hpp:12
ppl::mcmc::leapfrog
double leapfrog(ADExprType &ad_expr, Eigen::MatrixBase< MatType > &theta, Eigen::MatrixBase< MatType > &theta_adj, Eigen::MatrixBase< MatType > &tp_adj, Eigen::MatrixBase< MatType > &r, const MomentumHandlerType &m_handler, double epsilon, bool reuse_adj)
Definition: leapfrog.hpp:52
ppl::util::StopWatch
Definition: stopwatch.hpp:9
stopwatch.hpp
ppl::mcmc::StepAdapter::reset
void reset()
Definition: step_adapter.hpp:63
ppl::util::ProgressLogger
Definition: logging.hpp:19
ppl::NUTSConfig
Definition: configs.hpp:15
ppl::util::make_ptr_pack
constexpr auto make_ptr_pack(UCValPtrType _uc_val=nullptr, UCAdjPtrType _uc_adj=nullptr, TPValPtrType _tp_val=nullptr, TPAdjPtrType _tp_adj=nullptr, CValPtrType _c_val=nullptr, size_t *_v_val=nullptr)
Definition: ptr_pack.hpp:47
sampler_tools.hpp
configs.hpp
ppl::mcmc::find_reasonable_epsilon
double find_reasonable_epsilon(double eps, ADExprType &ad_expr, MatType &theta, MatType &theta_adj, MatType &tp_adj, GenType &gen, MomentumHandlerType &momentum_handler)
Definition: nuts.hpp:238
ppl
Definition: bounded.hpp:11
ppl::mcmc::nuts_
void nuts_(ProgramType &program, const NUTSConfigType &config, const OffsetPackType &pack, MCMCResultType &res)
Definition: nuts.hpp:335
base_mcmc.hpp
ppl::mat
ad::mat mat
Definition: shape_traits.hpp:18
ppl::nuts
auto nuts(const ExprType &expr, const NUTSConfigType &config=NUTSConfigType())
Definition: nuts.hpp:645
ppl::mcmc::StepAdapter::log_eps_bar
double log_eps_bar
Definition: step_adapter.hpp:72
math.hpp
leapfrog.hpp
ppl::math::lse
T lse(T x, T y)
Definition: math.hpp:25
ppl::mcmc::StepAdapter::log_eps
double log_eps
Definition: step_adapter.hpp:71
result.hpp
ppl::mcmc::accept_or_reject
bool accept_or_reject(double p, UniformDistType &&unif_sampler, GenType &&gen)
Definition: sampler_tools.hpp:22
ppl::mcmc::StepAdapter::adapt
void adapt(double alpha_ratio)
Definition: step_adapter.hpp:47