autoppl  v0.8
A C++ template library for probabilistic programming
glue.hpp
Go to the documentation of this file.
1 #pragma once
3 
4 namespace ppl {
5 namespace expr {
6 namespace var {
7 
// GlueNode: a variable-expression node that holds two sub-expressions (lhs, rhs)
// and forwards each operation to them, lhs first and then rhs.
// Value-like queries (get, eval's result, size, rows, cols) come from rhs only.
//
// Template parameters:
//   LHSExprType - type of the first sub-expression; must satisfy util::is_var_expr_v.
//   RHSExprType - type of the second sub-expression; must satisfy util::is_var_expr_v.
//
// The struct passes its own type as the template argument of util::VarExprBase.
8 template <class LHSExprType
9  , class RHSExprType>
10 struct GlueNode:
11  util::VarExprBase<GlueNode<LHSExprType, RHSExprType>>
12 {
13 private:
14  using lhs_t = LHSExprType;
15  using rhs_t = RHSExprType;
16 
 // Reject instantiations where either operand is not a variable expression.
17  static_assert(util::is_var_expr_v<lhs_t>);
18  static_assert(util::is_var_expr_v<rhs_t>);
19 
20 public:
 // NOTE(review): this listing is incomplete here. The embedded numbering skips
 // 21-22 and 24-25, and the initializer of has_param is not shown.
 // The cross-reference list at the end of the page gives listing line 21 as
 //   value_t = typename util::var_expr_traits< rhs_t >::value_t
 // and listing line 22 as
 //   shape_t = typename util::shape_traits< rhs_t >::shape_t.
 // The has_param initializer is presumably a combination of lhs_t::has_param and
 // rhs_t::has_param (bind() below consults both) - confirm against the real header.
23  static constexpr bool has_param =
26 
 // Initializes the stored sub-expressions from lhs and rhs.
27  GlueNode(const lhs_t& lhs,
28  const rhs_t& rhs)
29  : lhs_{lhs}, rhs_{rhs}
30  {}
31 
 // Passes f to lhs_.traverse and then to rhs_.traverse (non-const overload).
 // f is passed on as an lvalue both times; it is not std::forward-ed.
32  template <class Func>
33  void traverse(Func&& f)
34  {
35  lhs_.traverse(f);
36  rhs_.traverse(f);
37  }
38 
 // Const overload of traverse: same order, lhs_ first and then rhs_.
39  template <class Func>
40  void traverse(Func&& f) const
41  {
42  lhs_.traverse(f);
43  rhs_.traverse(f);
44  }
45 
 // Returns whatever rhs_.get() returns; lhs_ is not consulted.
46  auto get() const { return rhs_.get(); }
47 
 // Evaluates lhs_ first and discards its result, then returns rhs_.eval().
 // The order matters if rhs_ depends on state that lhs_.eval() updates.
 // NOTE(review): whether such a dependency exists is not visible here.
48  auto eval() {
49  lhs_.eval();
50  return rhs_.eval();
51  }
52 
 // Dimension queries are delegated to rhs_ only.
53  constexpr size_t size() const { return rhs_.size(); }
54  constexpr size_t rows() const { return rhs_.rows(); }
55  constexpr size_t cols() const { return rhs_.cols(); }
56 
 // Calls ad(pack) on both sub-expressions and joins the results with the comma
 // operator. With the built-in comma operator this yields rhs_.ad(pack) after
 // evaluating lhs_.ad(pack). If the returned types overload operator, the result
 // is whatever that overload produces.
 // NOTE(review): confirm which of the two applies for the AD expression types.
57  template <class PtrPackType>
58  auto ad(const PtrPackType& pack) const
59  {
60  return (lhs_.ad(pack), rhs_.ad(pack));
61  }
62 
 // Calls bind(pack) on each sub-expression whose type reports has_param == true.
 // A sub-expression with has_param == false is skipped at compile time.
63  template <class PtrPackType>
64  void bind(const PtrPackType& pack)
65  {
66  if constexpr (lhs_t::has_param) {
67  lhs_.bind(pack);
68  }
69  if constexpr (rhs_t::has_param) {
70  rhs_.bind(pack);
71  }
72  }
73 
 // Calls activate_refcnt() on lhs_ and then on rhs_.
74  void activate_refcnt() const {
75  lhs_.activate_refcnt();
76  rhs_.activate_refcnt();
77  }
78 
79 private:
 // The two glued sub-expressions, stored as members of the node.
80  lhs_t lhs_;
81  rhs_t rhs_;
82 };
83 
84 } // namespace var
85 } // namespace expr
86 } // namespace ppl
ppl::expr::var::GlueNode::activate_refcnt
void activate_refcnt() const
Definition: glue.hpp:74
ppl::expr::var::GlueNode
Definition: glue.hpp:12
ppl::expr::var::GlueNode::eval
auto eval()
Definition: glue.hpp:48
ppl::expr::var::GlueNode::traverse
void traverse(Func &&f)
Definition: glue.hpp:33
ppl::util::var_expr_traits
Definition: var_expr_traits.hpp:28
ppl::expr::var::GlueNode::shape_t
typename util::shape_traits< rhs_t >::shape_t shape_t
Definition: glue.hpp:22
ppl::expr::var::GlueNode::has_param
static constexpr bool has_param
Definition: glue.hpp:23
ppl::util::var_expr_traits::value_t
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
ppl::expr::var::GlueNode::ad
auto ad(const PtrPackType &pack) const
Definition: glue.hpp:58
ppl::util::shape_traits
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
ppl::expr::var::GlueNode::cols
constexpr size_t cols() const
Definition: glue.hpp:55
ppl::expr::var::GlueNode::get
auto get() const
Definition: glue.hpp:46
ppl::expr::var::GlueNode::rows
constexpr size_t rows() const
Definition: glue.hpp:54
ppl::expr::var::GlueNode::GlueNode
GlueNode(const lhs_t &lhs, const rhs_t &rhs)
Definition: glue.hpp:27
ppl
Definition: bounded.hpp:11
ppl::util::VarExprBase
Definition: var_expr_traits.hpp:20
ppl::expr::var::GlueNode::bind
void bind(const PtrPackType &pack)
Definition: glue.hpp:64
ppl::expr::var::GlueNode::value_t
typename util::var_expr_traits< rhs_t >::value_t value_t
Definition: glue.hpp:21
traits.hpp
ppl::expr::var::GlueNode::traverse
void traverse(Func &&f) const
Definition: glue.hpp:40
ppl::expr::var::GlueNode::size
constexpr size_t size() const
Definition: glue.hpp:53