autoppl v0.8
A C++ template library for probabilistic programming
glue.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <type_traits>
4 
5 namespace ppl {
6 namespace expr {
7 namespace model {
8 
13 template <class LHSNodeType
14  , class RHSNodeType>
15 struct GlueNode:
16  util::ModelExprBase<GlueNode<LHSNodeType, RHSNodeType>>
17 {
18  static_assert(util::is_model_expr_v<LHSNodeType>);
19  static_assert(util::is_model_expr_v<RHSNodeType>);
20 
21  using lhs_t = LHSNodeType;
22  using rhs_t = RHSNodeType;
23 
24  using dist_value_t = std::common_type_t<
27  >;
28 
29  GlueNode(const lhs_t& lhs,
30  const rhs_t& rhs) noexcept
31  : lhs_{lhs}
32  , rhs_{rhs}
33  {}
34 
39  template <class EqNodeFunc>
40  void traverse(EqNodeFunc&& eq_f)
41  {
42  lhs_.traverse(eq_f);
43  rhs_.traverse(eq_f);
44  }
45 
46  template <class EqNodeFunc>
47  void traverse(EqNodeFunc&& eq_f) const
48  {
49  lhs_.traverse(eq_f);
50  rhs_.traverse(eq_f);
51  }
52 
57  auto pdf() { return lhs_.pdf() * rhs_.pdf(); }
58 
63  auto log_pdf() { return lhs_.log_pdf() + rhs_.log_pdf(); }
64 
69  template <class PtrPackType>
70  auto ad_log_pdf(const PtrPackType& pack) const
71  {
72  return (lhs_.ad_log_pdf(pack) +
73  rhs_.ad_log_pdf(pack));
74  }
75 
76  template <class PtrPackType>
77  void bind(const PtrPackType& pack)
78  {
79  lhs_.bind(pack);
80  rhs_.bind(pack);
81  }
82 
83  void activate_refcnt() const {
84  lhs_.activate_refcnt();
85  rhs_.activate_refcnt();
86  }
87 
88 private:
89  lhs_t lhs_;
90  rhs_t rhs_;
91 };
92 
93 } // namespace model
94 } // namespace expr
95 } // namespace ppl
ppl::util::ModelExprBase
Definition: model_expr_traits.hpp:19
ppl::expr::model::GlueNode::bind
void bind(const PtrPackType &pack)
Definition: glue.hpp:77
ppl::expr::model::GlueNode::rhs_t
RHSNodeType rhs_t
Definition: glue.hpp:22
ppl::expr::model::GlueNode::GlueNode
GlueNode(const lhs_t &lhs, const rhs_t &rhs) noexcept
Definition: glue.hpp:29
model_expr_traits.hpp
ppl::expr::model::GlueNode::activate_refcnt
void activate_refcnt() const
Definition: glue.hpp:83
ppl::expr::model::GlueNode::ad_log_pdf
auto ad_log_pdf(const PtrPackType &pack) const
Definition: glue.hpp:70
ppl::expr::model::GlueNode::traverse
void traverse(EqNodeFunc &&eq_f) const
Definition: glue.hpp:47
ppl::expr::model::GlueNode::traverse
void traverse(EqNodeFunc &&eq_f)
Definition: glue.hpp:40
ppl::expr::model::GlueNode
Definition: glue.hpp:17
ppl::expr::model::GlueNode::pdf
auto pdf()
Definition: glue.hpp:57
ppl::expr::model::GlueNode::dist_value_t
std::common_type_t< typename util::model_expr_traits< lhs_t >::dist_value_t, typename util::model_expr_traits< rhs_t >::dist_value_t > dist_value_t
Definition: glue.hpp:27
ppl
Definition: bounded.hpp:11
ppl::expr::model::GlueNode::lhs_t
LHSNodeType lhs_t
Definition: glue.hpp:21
ppl::util::model_expr_traits::dist_value_t
typename T::dist_value_t dist_value_t
Definition: model_expr_traits.hpp:31
ppl::expr::model::GlueNode::log_pdf
auto log_pdf()
Definition: glue.hpp:63