autoppl  v0.8
A C++ template library for probabilistic programming
init_params.hpp
Go to the documentation of this file.
#pragma once
#include <cassert>
#include <random>
#include <type_traits>
#include <autoppl/math/math.hpp>

6 namespace ppl {
7 namespace expr {
8 
20 template <class ProgramType
21  , class GenType>
22 inline void init_params(ProgramType& program,
23  GenType& gen,
24  bool prune = true,
25  double radius = 2.)
26 {
27  auto& model = program.get_model();
28 
29  // default initialization method
30  std::uniform_real_distribution cont_dist(-radius, radius);
31  auto init_params__ = [&](auto& eq_node) {
32  auto& var = eq_node.get_variable();
33  using var_t = std::decay_t<decltype(var)>;
34  if constexpr (util::is_param_v<var_t>) {
35  var.init(gen, cont_dist);
36  }
37  };
38  model.traverse(init_params__);
39 
40  // prune if set to true
41  if (!prune) return;
42 
43  int n_param_entities = 0;
44  auto n_param_entities__ = [&](const auto& eq_node) {
45  auto& var = eq_node.get_variable();
46  using var_t = std::decay_t<decltype(var)>;
47  if constexpr (util::is_param_v<var_t>) {
48  ++n_param_entities;
49  }
50  };
51  model.traverse(n_param_entities__);
52 
53  auto log_pdf = program.log_pdf();
54  for (int i = 0; i < n_param_entities && log_pdf == math::neg_inf<double>; ++i) {
55 
56  // evaluate log-pdf for 2 purposes:
57  // 1) evaluates unconstrained to constrained automatically
58  // 2) log-pdf needs to be checked that it is not -inf later
59  bool modified = false;
60  auto prune_params__ = [&](auto& eq_node) {
61  auto& var = eq_node.get_variable();
62  const auto& dist = eq_node.get_distribution();
63  using var_t = std::decay_t<decltype(var)>;
64  if constexpr (util::is_param_v<var_t>) {
65  bool curr_modified = dist.prune(var, gen);
66  if (curr_modified) {
67  var.inv_eval();
68  }
69  modified = modified || curr_modified;
70  }
71  };
72  model.traverse(prune_params__);
73 
74  // if no unconstrained parameters were modified, log_pdf won't change anymore
75  // can early exit
76  if (!modified) break;
77 
78  log_pdf = program.log_pdf();
79  }
80 
81  assert(log_pdf != math::neg_inf<double>);
82 }
83 
84 } // namespace expr
85 } // namespace ppl
var_traits.hpp
ppl::expr::init_params
void init_params(ProgramType &program, GenType &gen, bool prune=true, double radius=2.)
Definition: init_params.hpp:22
ppl::util::var_t
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
ppl
Definition: bounded.hpp:11
math.hpp