autoppl  v0.8
A C++ template library for probabilistic programming
program.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <tuple>
6 
7 namespace ppl {
8 namespace expr {
9 namespace prog {
10 
11 template <class ModelExpr>
13 {
14  using model_t = ModelExpr;
15  ProgramNodeBase(const model_t& model)
16  : model_(model)
17  {}
18 
19  auto& get_model() { return model_; }
20  const auto& get_model() const { return model_; }
21 
22 protected:
24 };
25 
37 template <class TupExprType
38  , class = void>
39 struct ProgramNode;
40 
41 template <class ModelExpr>
42 struct ProgramNode<std::tuple<ModelExpr>,
43  std::enable_if_t<util::is_model_expr_v<ModelExpr>
44  > >:
46  ProgramNode<std::tuple<ModelExpr>,
47  std::enable_if_t<util::is_model_expr_v<ModelExpr>> >
48  >,
49  ProgramNodeBase<ModelExpr>
50 {
52  using typename base_t::model_t;
53  using base_t::model_;
54 
55  ProgramNode(const model_t& model)
56  : base_t(model)
57  {}
58 
59  auto log_pdf() { return model_.log_pdf(); }
60 
61  template <class PtrPackType>
62  auto ad_log_pdf(const PtrPackType& pack) const {
63  return model_.ad_log_pdf(pack);
64  }
65 
66  auto activate() const {
67  auto res = expr::activate(model_);
68  model_.activate_refcnt();
69  return res;
70  }
71 
72  template <class PtrPackType>
73  void bind(const PtrPackType& pack) {
74  model_.bind(pack);
75  }
76 
77  template <class GenType>
78  void init_params(GenType& gen,
79  bool prune = true,
80  double radius = 2.) {
81  expr::init_params(*this, gen, prune, radius);
82  }
83 };
84 
85 template <class TPExpr, class ModelExpr>
86 struct ProgramNode<std::tuple<TPExpr, ModelExpr>,
87  std::enable_if_t<
88  util::is_var_expr_v<TPExpr> &&
89  util::is_model_expr_v<ModelExpr>
90  > >:
92  ProgramNode<std::tuple<TPExpr, ModelExpr>,
93  std::enable_if_t<
94  util::is_var_expr_v<TPExpr> &&
95  util::is_model_expr_v<ModelExpr>> >
96  >,
97  ProgramNodeBase<ModelExpr>
98 {
100  using tp_expr_t = TPExpr;
101  using typename base_t::model_t;
102  using base_t::model_;
103 
104  ProgramNode(const tp_expr_t& tp_expr,
105  const model_t& model)
106  : base_t(model)
107  , tp_expr_(tp_expr)
108  {}
109 
110  auto log_pdf() {
111  tp_expr_.eval();
112  return model_.log_pdf();
113  }
114 
115  template <class PtrPackType>
116  auto ad_log_pdf(const PtrPackType& pack) const {
117  return (tp_expr_.ad(pack), model_.ad_log_pdf(pack));
118  }
119 
120  auto activate() const {
121  auto tp_res = expr::activate(tp_expr_);
122  auto model_res = expr::activate(model_);
123  tp_expr_.activate_refcnt();
124  model_.activate_refcnt();
125 
126  util::OffsetPack cont_res;
127  util::OffsetPack disc_res;
128  cont_res = std::get<0>(model_res);
129  cont_res.tp_offset = std::get<0>(tp_res).tp_offset;
130  disc_res = std::get<1>(model_res);
131  disc_res.tp_offset = std::get<1>(tp_res).tp_offset;
132  return std::make_pair(cont_res, disc_res);
133  }
134 
135  template <class PtrPackType>
136  void bind(const PtrPackType& pack) {
137  tp_expr_.bind(pack);
138  model_.bind(pack);
139  }
140 
141  template <class GenType>
142  void init_params(GenType& gen,
143  bool prune = true,
144  double radius = 2.) {
145  expr::init_params(*this, gen, prune, radius);
146  }
147 
148 private:
149  tp_expr_t tp_expr_;
150 };
151 
152 } // namespace prog
153 } // namespace expr
154 } // namespace ppl
ppl::expr::prog::ProgramNode< std::tuple< TPExpr, ModelExpr >, std::enable_if_t< util::is_var_expr_v< TPExpr > &&util::is_model_expr_v< ModelExpr > > >::bind
void bind(const PtrPackType &pack)
Definition: program.hpp:136
ppl::expr::prog::ProgramNodeBase::model_
model_t model_
Definition: program.hpp:23
ppl::expr::prog::ProgramNode< std::tuple< ModelExpr >, std::enable_if_t< util::is_model_expr_v< ModelExpr > > >::log_pdf
auto log_pdf()
Definition: program.hpp:59
ppl::util::ProgramExprBase
Definition: program_expr_traits.hpp:9
ppl::expr::prog::ProgramNode< std::tuple< ModelExpr >, std::enable_if_t< util::is_model_expr_v< ModelExpr > > >::activate
auto activate() const
Definition: program.hpp:66
init_params.hpp
ppl::expr::prog::ProgramNodeBase::get_model
auto & get_model()
Definition: program.hpp:19
activate.hpp
ppl::expr::prog::ProgramNode
Definition: program.hpp:39
ppl::expr::prog::ProgramNodeBase
Definition: program.hpp:13
ppl::expr::prog::ProgramNode< std::tuple< ModelExpr >, std::enable_if_t< util::is_model_expr_v< ModelExpr > > >::init_params
void init_params(GenType &gen, bool prune=true, double radius=2.)
Definition: program.hpp:78
ppl::expr::prog::ProgramNode< std::tuple< ModelExpr >, std::enable_if_t< util::is_model_expr_v< ModelExpr > > >::bind
void bind(const PtrPackType &pack)
Definition: program.hpp:73
ppl::expr::init_params
void init_params(ProgramType &program, GenType &gen, bool prune=true, double radius=2.)
Definition: init_params.hpp:22
ppl::expr::prog::ProgramNode< std::tuple< ModelExpr >, std::enable_if_t< util::is_model_expr_v< ModelExpr > > >::ProgramNode
ProgramNode(const model_t &model)
Definition: program.hpp:55
ppl::expr::prog::ProgramNode< std::tuple< TPExpr, ModelExpr >, std::enable_if_t< util::is_var_expr_v< TPExpr > &&util::is_model_expr_v< ModelExpr > > >::ad_log_pdf
auto ad_log_pdf(const PtrPackType &pack) const
Definition: program.hpp:116
ppl::expr::prog::ProgramNodeBase::model_t
ModelExpr model_t
Definition: program.hpp:14
ppl::expr::prog::ProgramNode< std::tuple< TPExpr, ModelExpr >, std::enable_if_t< util::is_var_expr_v< TPExpr > &&util::is_model_expr_v< ModelExpr > > >::tp_expr_t
TPExpr tp_expr_t
Definition: program.hpp:100
ppl::expr::prog::ProgramNode< std::tuple< ModelExpr >, std::enable_if_t< util::is_model_expr_v< ModelExpr > > >::ad_log_pdf
auto ad_log_pdf(const PtrPackType &pack) const
Definition: program.hpp:62
ppl::expr::prog::ProgramNode< std::tuple< TPExpr, ModelExpr >, std::enable_if_t< util::is_var_expr_v< TPExpr > &&util::is_model_expr_v< ModelExpr > > >::init_params
void init_params(GenType &gen, bool prune=true, double radius=2.)
Definition: program.hpp:142
ppl::expr::prog::ProgramNodeBase::get_model
const auto & get_model() const
Definition: program.hpp:20
ppl::expr::prog::ProgramNodeBase::ProgramNodeBase
ProgramNodeBase(const model_t &model)
Definition: program.hpp:15
ppl::util::OffsetPack::tp_offset
index_t tp_offset
Definition: offset_pack.hpp:14
ppl
Definition: bounded.hpp:11
ppl::util::OffsetPack
Definition: offset_pack.hpp:9
ppl::expr::prog::ProgramNode< std::tuple< TPExpr, ModelExpr >, std::enable_if_t< util::is_var_expr_v< TPExpr > &&util::is_model_expr_v< ModelExpr > > >::log_pdf
auto log_pdf()
Definition: program.hpp:110
traits.hpp
ppl::expr::activate
auto activate(ExprType &&expr)
Definition: activate.hpp:23
ppl::expr::prog::ProgramNode< std::tuple< TPExpr, ModelExpr >, std::enable_if_t< util::is_var_expr_v< TPExpr > &&util::is_model_expr_v< ModelExpr > > >::activate
auto activate() const
Definition: program.hpp:120
ppl::expr::prog::ProgramNode< std::tuple< TPExpr, ModelExpr >, std::enable_if_t< util::is_var_expr_v< TPExpr > &&util::is_model_expr_v< ModelExpr > > >::ProgramNode
ProgramNode(const tp_expr_t &tp_expr, const model_t &model)
Definition: program.hpp:104