autoppl  v0.8
A C++ template library for probabilistic programming
op_eq.hpp
Go to the documentation of this file.
1 #pragma once
3 
4 namespace ppl {
5 namespace expr {
6 namespace var {
7 
/// Plain-assignment policy: evaluates `lhs = rhs`.
struct Eq
{
    /// Mutable left operand: performs the assignment in place and returns
    /// the reference yielded by the assignment expression.
    template <class LHSType, class RHSType>
    static constexpr auto eval(LHSType& lhs, const RHSType& rhs) -> LHSType&
    {
        return lhs = rhs;
    }

    /// Const (or temporary) left operand: requires that `LHSType` provide a
    /// const-callable `operator=`. Whatever that operator yields is returned
    /// by value (plain `auto` drops references and cv-qualifiers).
    template <class LHSType, class RHSType>
    static constexpr auto eval(const LHSType& lhs, const RHSType& rhs)
    {
        return lhs = rhs;
    }
};
20 
/// Add-assign policy: evaluates `lhs += rhs`.
struct AddEq
{
    /// Mutable left operand: accumulates in place and returns the reference
    /// yielded by the compound assignment.
    template <class LHSType, class RHSType>
    static constexpr auto eval(LHSType& lhs, const RHSType& rhs) -> LHSType&
    {
        return lhs += rhs;
    }

    /// Const (or temporary) left operand: requires that `LHSType` provide a
    /// const-callable `operator+=`. Its result is returned by value
    /// (plain `auto` drops references and cv-qualifiers).
    template <class LHSType, class RHSType>
    static constexpr auto eval(const LHSType& lhs, const RHSType& rhs)
    {
        return lhs += rhs;
    }
};
33 
/// Subtract-assign policy: evaluates `lhs -= rhs`.
struct SubEq
{
    /// Mutable left operand: subtracts in place and returns the reference
    /// yielded by the compound assignment.
    template <class LHSType, class RHSType>
    static constexpr auto eval(LHSType& lhs, const RHSType& rhs) -> LHSType&
    {
        return lhs -= rhs;
    }

    /// Const (or temporary) left operand: requires that `LHSType` provide a
    /// const-callable `operator-=`. Its result is returned by value
    /// (plain `auto` drops references and cv-qualifiers).
    template <class LHSType, class RHSType>
    static constexpr auto eval(const LHSType& lhs, const RHSType& rhs)
    {
        return lhs -= rhs;
    }
};
46 
/// Multiply-assign policy: evaluates `lhs *= rhs`.
struct MulEq
{
    /// Mutable left operand: scales in place and returns the reference
    /// yielded by the compound assignment.
    template <class LHSType, class RHSType>
    static constexpr auto eval(LHSType& lhs, const RHSType& rhs) -> LHSType&
    {
        return lhs *= rhs;
    }

    /// Const (or temporary) left operand: requires that `LHSType` provide a
    /// const-callable `operator*=`. Its result is returned by value
    /// (plain `auto` drops references and cv-qualifiers).
    template <class LHSType, class RHSType>
    static constexpr auto eval(const LHSType& lhs, const RHSType& rhs)
    {
        return lhs *= rhs;
    }
};
59 
/// Divide-assign policy: evaluates `lhs /= rhs`.
struct DivEq
{
    /// Mutable left operand: divides in place and returns the reference
    /// yielded by the compound assignment.
    template <class LHSType, class RHSType>
    static constexpr auto eval(LHSType& lhs, const RHSType& rhs) -> LHSType&
    {
        return lhs /= rhs;
    }

    /// Const (or temporary) left operand: requires that `LHSType` provide a
    /// const-callable `operator/=`. Its result is returned by value
    /// (plain `auto` drops references and cv-qualifiers).
    template <class LHSType, class RHSType>
    static constexpr auto eval(const LHSType& lhs, const RHSType& rhs)
    {
        return lhs /= rhs;
    }
};
72 
// Expression node for an (op-)assignment onto a parameter view.
// It stores a copy of the parameter view (tp_view_) and of the right-hand-side
// variable expression (expr_), and combines them through the policy type Op,
// which must expose a static eval(lhs, rhs) (e.g. Eq, AddEq, SubEq, MulEq,
// DivEq defined earlier in this file).
// The node passes its own type as the template argument of util::VarExprBase.
//
// Template parameters:
//   Op             - assignment policy providing static eval(lhs, rhs)
//   TParamViewType - type of the left-hand side; must satisfy util::is_tparam_v
//   VarExprType    - type of the right-hand side; must satisfy util::is_var_expr_v
73 template <class Op
74  , class TParamViewType
75  , class VarExprType>
76 struct OpEqNode:
77  util::VarExprBase<OpEqNode<Op, TParamViewType, VarExprType>>
78 {
79 private:
80  using op_t = Op;
81  using tp_view_t = TParamViewType;
82  using var_expr_t = VarExprType;
83 
// Compile-time checks on the two operand types.
84  static_assert(util::is_tparam_v<tp_view_t>);
85  static_assert(util::is_var_expr_v<var_expr_t>);
86 
// NOTE(review): the two static_asserts below are cut off in this listing
// (listing lines 88-89 and 93-94 are missing), so their full conditions
// cannot be read here. Recover them from the original header before editing.
87  static_assert(std::is_same_v<
90 
91  static_assert(util::is_scl_v<var_expr_t> ||
92  std::is_same_v<
95 
96 public:
// NOTE(review): listing lines 97-98 are missing here. The symbol index at the
// end of this listing names `value_t` and `shape_t` aliases at those lines.
// This node always reports that it contains a parameter.
99  static constexpr bool has_param = true;
100 
// Stores copies of the parameter view and of the expression.
101  OpEqNode(const tp_view_t& tp_view,
102  const var_expr_t& expr)
103  : tp_view_{tp_view}, expr_{expr}
104  {}
105 
// Calls f(*this) only when Op is exactly Eq; for any other Op this is a no-op.
// The static_cast<void>(f) presumably silences an unused-parameter warning in
// the instantiations where the `if constexpr` branch is discarded.
106  template <class Func>
107  void traverse(Func&& f)
108  {
109  static_cast<void>(f);
110  if constexpr (std::is_same_v<Op, Eq>) {
111  f(*this);
112  }
113  }
114 
// Const counterpart of traverse(): f receives a const reference to this node.
115  template <class Func>
116  void traverse(Func&& f) const
117  {
118  static_cast<void>(f);
119  if constexpr (std::is_same_v<Op, Eq>) {
120  f(*this);
121  }
122  }
123 
// Returns tp_view_.get(); the plain `auto` return type means the result is
// returned by value.
124  auto get() const { return tp_view_.get(); }
125 
// Applies Op to the parameter view and the evaluated expression.
// Three cases, selected at compile time via util::is_scl_v:
//   1. both operands scalar: Op on tp_view_.get() and expr_.eval();
//   2. view non-scalar, expression scalar: Op on tp_view_.get().array()
//      and expr_.eval();
//   3. otherwise: Op on tp_view_.get().array() and expr_.eval().array().
// In cases 2 and 3 the .array() result is first stored in the local `tpa`,
// presumably so that an lvalue is passed and the mutable-lhs overload of
// Op::eval is selected - TODO confirm.
126  auto eval() {
127  if constexpr (util::is_scl_v<tp_view_t> &&
128  util::is_scl_v<var_expr_t>) {
129  return op_t::eval(tp_view_.get(), expr_.eval());
130  } else if constexpr (!util::is_scl_v<tp_view_t> &&
131  util::is_scl_v<var_expr_t>) {
132  auto tpa = tp_view_.get().array();
133  return op_t::eval(tpa, expr_.eval());
134  } else {
135  auto tpa = tp_view_.get().array();
136  return op_t::eval(tpa, expr_.eval().array());
137  }
138  }
139 
// Size/shape queries are forwarded to the parameter view.
140  constexpr size_t size() const { return tp_view_.size(); }
141  constexpr size_t rows() const { return tp_view_.rows(); }
142  constexpr size_t cols() const { return tp_view_.cols(); }
143 
// Applies Op to the results of tp_view_.ad(pack) and expr_.ad(pack) and
// returns what Op::eval yields.
144  template <class PtrPackType>
145  auto ad(const PtrPackType& pack) const
146  {
147  return op_t::eval(tp_view_.ad(pack), expr_.ad(pack));
148  }
149 
// Forwards `pack` to bind() of each sub-expression whose type reports
// has_param; sub-expressions without parameters are skipped at compile time.
150  template <class PtrPackType>
151  void bind(const PtrPackType& pack)
152  {
153  if constexpr (tp_view_t::has_param) {
154  tp_view_.bind(pack);
155  }
156  if constexpr (var_expr_t::has_param) {
157  expr_.bind(pack);
158  }
159  }
160 
// Forwards activate_refcnt() to both the parameter view and the expression.
161  void activate_refcnt() const {
162  tp_view_.activate_refcnt();
163  expr_.activate_refcnt();
164  }
165 
// Access to the stored parameter view (the left-hand side of the assignment).
166  auto& get_variable() { return tp_view_; }
167  const auto& get_variable() const { return tp_view_; }
168 
169 private:
// Left-hand side: the parameter view being assigned to.
170  tp_view_t tp_view_;
// Right-hand side: the expression whose value is assigned.
171  var_expr_t expr_;
172 };
173 
174 namespace details {
175 
// Builds an OpEqNode<Op, tp_view_t, expr_t> from a parameter view and an
// expression. Both arguments are first copy-initialized into locals of type
// tp_view_t / expr_t, and the node is then constructed from those locals.
//
// Template parameters:
//   Op             - assignment policy forwarded to OpEqNode
//   TParamViewType - type of the left-hand-side argument
//   VarExprType    - type of the right-hand-side argument
// Returns the constructed OpEqNode by value.
176 template <class Op
177  , class TParamViewType
178  , class VarExprType>
179 constexpr inline auto opeq_helper(const TParamViewType& tp_view,
180  const VarExprType& expr)
181 {
// NOTE(review): listing lines 182-183 are missing here; they presumably
// declare the aliases tp_view_t and expr_t used below. The symbol index at the
// end of this listing references util::convert_to_param_t, which suggests the
// aliases wrap the argument types with it - confirm against the original header.
184 
185  tp_view_t wrap_tp_view = tp_view;
186  expr_t wrap_expr = expr;
187 
188  return OpEqNode<Op, tp_view_t, expr_t>(wrap_tp_view, wrap_expr);
189 }
190 
191 } // namespace details
192 } // namespace var
193 } // namespace expr
194 } // namespace ppl
ppl::expr::var::OpEqNode::traverse
void traverse(Func &&f)
Definition: op_eq.hpp:107
ppl::expr::var::Eq::eval
constexpr static auto eval(const LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:17
ppl::expr::var::details::opeq_helper
constexpr auto opeq_helper(const TParamViewType &tp_view, const VarExprType &expr)
Definition: op_eq.hpp:179
ppl::expr::var::OpEqNode::eval
auto eval()
Definition: op_eq.hpp:126
ppl::expr::var::DivEq
Definition: op_eq.hpp:61
ppl::expr::var::OpEqNode::traverse
void traverse(Func &&f) const
Definition: op_eq.hpp:116
ppl::expr::var::OpEqNode::size
constexpr size_t size() const
Definition: op_eq.hpp:140
ppl::expr::var::MulEq::eval
constexpr static auto eval(const LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:56
ppl::expr::var::OpEqNode::has_param
static constexpr bool has_param
Definition: op_eq.hpp:99
ppl::expr::var::OpEqNode::ad
auto ad(const PtrPackType &pack) const
Definition: op_eq.hpp:145
ppl::expr::var::DivEq::eval
constexpr static LHSType & eval(LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:64
ppl::expr::var::OpEqNode::OpEqNode
OpEqNode(const tp_view_t &tp_view, const var_expr_t &expr)
Definition: op_eq.hpp:101
ppl::expr::var::AddEq::eval
constexpr static LHSType & eval(LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:25
ppl::expr::var::OpEqNode::shape_t
typename util::shape_traits< tp_view_t >::shape_t shape_t
Definition: op_eq.hpp:98
ppl::expr::var::SubEq::eval
constexpr static auto eval(const LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:43
ppl::expr::var::Eq
Definition: op_eq.hpp:9
ppl::expr::var::OpEqNode::get
auto get() const
Definition: op_eq.hpp:124
ppl::expr::var::AddEq
Definition: op_eq.hpp:22
ppl::util::var_expr_traits::value_t
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
ppl::expr::var::OpEqNode::rows
constexpr size_t rows() const
Definition: op_eq.hpp:141
ppl::expr::var::OpEqNode::get_variable
const auto & get_variable() const
Definition: op_eq.hpp:167
ppl::util::shape_traits
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
ppl::expr::var::OpEqNode
Definition: op_eq.hpp:78
ppl::expr::var::DivEq::eval
constexpr static auto eval(const LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:69
ppl::expr::var::OpEqNode::activate_refcnt
void activate_refcnt() const
Definition: op_eq.hpp:161
ppl
Definition: bounded.hpp:11
ppl::expr::var::MulEq::eval
constexpr static LHSType & eval(LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:51
ppl::expr::var::Eq::eval
constexpr static LHSType & eval(LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:12
ppl::expr::var::OpEqNode::cols
constexpr size_t cols() const
Definition: op_eq.hpp:142
ppl::util::VarExprBase
Definition: var_expr_traits.hpp:20
ppl::expr::var::OpEqNode::value_t
typename util::var_expr_traits< tp_view_t >::value_t value_t
Definition: op_eq.hpp:97
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::expr::var::OpEqNode::get_variable
auto & get_variable()
Definition: op_eq.hpp:166
ppl::expr::var::SubEq::eval
constexpr static LHSType & eval(LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:38
ppl::expr::var::MulEq
Definition: op_eq.hpp:48
traits.hpp
ppl::expr::var::OpEqNode::bind
void bind(const PtrPackType &pack)
Definition: op_eq.hpp:151
ppl::expr::var::AddEq::eval
constexpr static auto eval(const LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:30
ppl::expr::var::SubEq
Definition: op_eq.hpp:35