autoppl v0.8
A C++ template library for probabilistic programming
bar_eq.hpp
Go to the documentation of this file.
1 #pragma once
3 
// Diagnostic text for a variable/distribution pairing whose kinds disagree:
// a continuous variable must be paired with a continuous distribution and a
// discrete variable with a discrete one. It is a macro (not a constexpr
// variable) so that it expands to a string literal, which is what a
// static_assert message must be in C++17.
#define PPL_VAR_DIST_CONT_DISC_MATCH                                           \
    "A continuous variable can only be assigned to a continuous distribution. " \
    "A discrete variable can only be assigned to a discrete distribution. "
7 
8 namespace ppl {
9 namespace expr {
10 namespace model {
11 
17 template <class VarType
18  , class DistType>
19 struct BarEqNode: util::ModelExprBase<BarEqNode<VarType, DistType>>
20 {
21  using var_t = VarType;
22  using dist_t = DistType;
23 
24  static_assert(util::is_dist_assignable_v<var_t>);
25  static_assert(util::is_dist_expr_v<dist_t>);
26 
27  static_assert((util::var_traits<var_t>::is_cont_v &&
32 
33  using dist_value_t = typename
35 
36  BarEqNode(const var_t& var,
37  const dist_t& dist) noexcept
38  : var_{var}
39  , dist_{dist}
40  {}
41 
47  template <class BarEqNodeFunc>
48  void traverse(BarEqNodeFunc&& eq_f)
49  {
50  using this_t = BarEqNode<VarType, DistType>;
51  eq_f(static_cast<this_t&>(*this));
52  }
53 
54  template <class BarEqNodeFunc>
55  void traverse(BarEqNodeFunc&& eq_f) const
56  {
57  using this_t = BarEqNode<VarType, DistType>;
58  eq_f(static_cast<const this_t&>(*this));
59  }
60 
61  auto pdf() {
62  var_.eval();
63  return dist_.pdf(var_);
64  }
65 
66  auto log_pdf() {
67  var_.eval();
68  return dist_.log_pdf(var_);
69  }
70 
71  template <class PtrPackType>
72  auto ad_log_pdf(const PtrPackType& pack) const
73  {
74  if constexpr (util::is_param_v<var_t>) {
75  return dist_.ad_log_pdf(var_, pack) +
76  var_.logj_ad(pack);
77  } else {
78  return dist_.ad_log_pdf(var_, pack);
79  }
80  }
81 
82  template <class PtrPackType>
83  void bind(const PtrPackType& pack)
84  {
85  if constexpr (var_t::has_param) {
86  var_.bind(pack);
87  }
88  dist_.bind(pack);
89  }
90 
91  void activate_refcnt() const {
92  var_.activate_refcnt();
93  dist_.activate_refcnt();
94  }
95 
96  var_t& get_variable() { return var_; }
97  const var_t& get_variable() const { return var_; }
98  dist_t& get_distribution() { return dist_; }
99  const dist_t& get_distribution() const { return dist_; }
100 
101 private:
102  var_t var_;
103  dist_t dist_;
104 };
105 
106 } // namespace model
107 } // namespace expr
108 } // namespace ppl
ppl::expr::model::BarEqNode::var_t
VarType var_t
Definition: bar_eq.hpp:21
ppl::expr::model::BarEqNode::BarEqNode
BarEqNode(const var_t &var, const dist_t &dist) noexcept
Definition: bar_eq.hpp:36
ppl::expr::model::BarEqNode::pdf
auto pdf()
Definition: bar_eq.hpp:61
ppl::util::ModelExprBase
Definition: model_expr_traits.hpp:19
ppl::expr::model::BarEqNode::ad_log_pdf
auto ad_log_pdf(const PtrPackType &pack) const
Definition: bar_eq.hpp:72
ppl::util::var_traits
Definition: var_traits.hpp:40
ppl::util::dist_expr_traits
Definition: dist_expr_traits.hpp:40
ppl::expr::model::BarEqNode::log_pdf
auto log_pdf()
Definition: bar_eq.hpp:66
ppl::expr::model::BarEqNode::dist_value_t
typename util::dist_expr_traits< dist_t >::dist_value_t dist_value_t
Definition: bar_eq.hpp:34
ppl::expr::model::BarEqNode::traverse
void traverse(BarEqNodeFunc &&eq_f)
Definition: bar_eq.hpp:48
ppl::expr::model::BarEqNode::dist_t
DistType dist_t
Definition: bar_eq.hpp:22
ppl::expr::model::BarEqNode::bind
void bind(const PtrPackType &pack)
Definition: bar_eq.hpp:83
ppl::util::dist_expr_traits::dist_value_t
typename DistExprType::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:42
ppl::expr::model::BarEqNode
Definition: bar_eq.hpp:20
ppl::expr::model::BarEqNode::get_variable
const var_t & get_variable() const
Definition: bar_eq.hpp:97
ppl::expr::model::BarEqNode::traverse
void traverse(BarEqNodeFunc &&eq_f) const
Definition: bar_eq.hpp:55
ppl::expr::model::BarEqNode::get_distribution
dist_t & get_distribution()
Definition: bar_eq.hpp:98
ppl
Definition: bounded.hpp:11
PPL_VAR_DIST_CONT_DISC_MATCH
#define PPL_VAR_DIST_CONT_DISC_MATCH
Definition: bar_eq.hpp:4
ppl::expr::model::BarEqNode::get_distribution
const dist_t & get_distribution() const
Definition: bar_eq.hpp:99
traits.hpp
ppl::expr::model::BarEqNode::get_variable
var_t & get_variable()
Definition: bar_eq.hpp:96
ppl::expr::model::BarEqNode::activate_refcnt
void activate_refcnt() const
Definition: bar_eq.hpp:91