autoppl v0.8
A C++ template library for probabilistic programming
binary.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/reverse/core/binary.hpp>
4 
// Message text presumably intended for static_assert diagnostics on binary
// expression nodes. Both macros are #undef'd at the bottom of this header,
// so they never leak to includers.
// NOTE(review): neither macro is referenced by the code shown in this
// listing; confirm whether the checks they describe are missing.
#define PPL_BINOP_EQUAL_FIXED_SIZE \
    "If both lhs and rhs are of fixed size, then they must have the same size. "
#define PPL_BINOP_NO_MAT_SUPPORT \
    "Binary operations with matrices are not supported yet. "
10 
11 namespace ppl {
12 namespace expr {
13 namespace var {
14 
31 template <class BinaryOp
32  , class LHSVarExprType
33  , class RHSVarExprType>
34 struct BinaryNode:
35  util::VarExprBase<BinaryNode<BinaryOp, LHSVarExprType, RHSVarExprType>>
36 {
37 private:
38  using lhs_t = LHSVarExprType;
39  using rhs_t = RHSVarExprType;
40 
41  static_assert(util::is_var_expr_v<lhs_t>);
42  static_assert(util::is_var_expr_v<rhs_t>);
43 
44 public:
45  using value_t = std::common_type_t<
48  >;
49  using shape_t = ad::util::max_shape_t<
52  >;
53  static constexpr bool has_param =
54  lhs_t::has_param || rhs_t::has_param;
55 
56  BinaryNode(const lhs_t& lhs,
57  const rhs_t& rhs)
58  : lhs_{lhs}, rhs_{rhs}
59  {}
60 
61  template <class Func>
62  void traverse(Func&&) const {}
63 
64  auto get() const {
65  auto&& lhs = lhs_.get();
66  auto&& rhs = rhs_.get();
67  return eval_helper(lhs, rhs);
68  }
69 
70  auto eval()
71  {
72  auto&& lhs = lhs_.eval();
73  auto&& rhs = rhs_.eval();
74  return eval_helper(lhs, rhs);
75  }
76 
77  size_t size() const { return std::max(lhs_.size(), rhs_.size()); }
78  size_t rows() const { return std::max(lhs_.rows(), rhs_.rows()); }
79  size_t cols() const { return std::max(lhs_.cols(), rhs_.cols()); }
80 
81  template <class PtrPackType>
82  auto ad(const PtrPackType& pack) const
83  {
84  return BinaryOp::fmap(lhs_.ad(pack),
85  rhs_.ad(pack));
86  }
87 
88  template <class PtrPackType>
89  void bind(const PtrPackType& pack)
90  {
91  if constexpr (lhs_t::has_param) {
92  lhs_.bind(pack);
93  }
94  if constexpr (rhs_t::has_param) {
95  rhs_.bind(pack);
96  }
97  }
98 
99  void activate_refcnt() const {
100  lhs_.activate_refcnt();
101  rhs_.activate_refcnt();
102  }
103 
104 private:
105 
106  template <class LHSType, class RHSType>
107  auto eval_helper(const LHSType& lhs,
108  const RHSType& rhs) const {
109  if constexpr (util::is_scl_v<lhs_t> &&
110  util::is_scl_v<rhs_t>) {
111  return BinaryOp::fmap(lhs, rhs);
112  } else if constexpr (util::is_scl_v<lhs_t>) {
113  return BinaryOp::fmap(lhs, rhs.array()).matrix();
114  } else if constexpr (util::is_scl_v<rhs_t>) {
115  return BinaryOp::fmap(lhs.array(), rhs).matrix();
116  } else {
117  return BinaryOp::fmap(lhs.array(), rhs.array()).matrix();
118  }
119  }
120 
121  lhs_t lhs_;
122  rhs_t rhs_;
123 };
124 
125 namespace details {
126 
127 template <class Op, class LHSType, class RHSType>
128 inline constexpr auto operator_helper(const LHSType& lhs,
129  const RHSType& rhs)
130 {
131  // note: may be reference types if converted to itself
132  using lhs_t = util::convert_to_param_t<LHSType>;
133  using rhs_t = util::convert_to_param_t<RHSType>;
134 
135  lhs_t wrap_lhs_expr = lhs;
136  rhs_t wrap_rhs_expr = rhs;
137 
138  using binary_t = BinaryNode<Op, lhs_t, rhs_t>;
139 
140  return binary_t(wrap_lhs_expr, wrap_rhs_expr);
141 }
142 
143 } // namespace details
144 } // namespace var
145 } // namespace expr
146 } // namespace ppl
147 
// Keep the diagnostic-message macros private to this header.
#undef PPL_BINOP_NO_MAT_SUPPORT
#undef PPL_BINOP_EQUAL_FIXED_SIZE
ppl::expr::var::BinaryNode::has_param
static constexpr bool has_param
Definition: binary.hpp:53
ppl::expr::var::BinaryNode::cols
size_t cols() const
Definition: binary.hpp:79
ppl::expr::var::BinaryNode::rows
size_t rows() const
Definition: binary.hpp:78
ppl::expr::var::BinaryNode::ad
auto ad(const PtrPackType &pack) const
Definition: binary.hpp:82
ppl::expr::var::details::operator_helper
constexpr auto operator_helper(const LHSType &lhs, const RHSType &rhs)
Definition: binary.hpp:128
ppl::expr::var::BinaryNode::traverse
void traverse(Func &&) const
Definition: binary.hpp:62
ppl::expr::var::BinaryNode::shape_t
ad::util::max_shape_t< typename util::shape_traits< lhs_t >::shape_t, typename util::shape_traits< rhs_t >::shape_t > shape_t
Definition: binary.hpp:52
ppl::expr::var::BinaryNode
Definition: binary.hpp:36
ppl::expr::var::BinaryNode::get
auto get() const
Definition: binary.hpp:64
ppl::expr::var::BinaryNode::activate_refcnt
void activate_refcnt() const
Definition: binary.hpp:99
ppl::expr::var::BinaryNode::bind
void bind(const PtrPackType &pack)
Definition: binary.hpp:89
ppl::util::var_expr_traits::value_t
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
ppl::util::shape_traits
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
ppl::expr::var::BinaryNode::value_t
std::common_type_t< typename util::var_expr_traits< lhs_t >::value_t, typename util::var_expr_traits< rhs_t >::value_t > value_t
Definition: binary.hpp:48
ppl
Definition: bounded.hpp:11
ppl::util::VarExprBase
Definition: var_expr_traits.hpp:20
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::expr::var::BinaryNode::BinaryNode
BinaryNode(const lhs_t &lhs, const rhs_t &rhs)
Definition: binary.hpp:56
ppl::expr::var::BinaryNode::size
size_t size() const
Definition: binary.hpp:77
traits.hpp
ppl::expr::var::BinaryNode::eval
auto eval()
Definition: binary.hpp:70