autoppl  v0.8
A C++ template library for probabilistic programming
dot.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/reverse/core/dot.hpp>
4 
5 #define PPL_DOT_MAT_VEC \
6  "Dot product is only supported for matrix as lhs argument " \
7  "and a matrix or vector as rhs argument. "
8 
9 namespace ppl {
10 namespace expr {
11 namespace var {
12 
23 template <class LHSVarExprType
24  , class RHSVarExprType>
25 class DotNode:
26  util::VarExprBase<DotNode<LHSVarExprType, RHSVarExprType>>
27 {
28  using lhs_t = LHSVarExprType;
29  using rhs_t = RHSVarExprType;
30 
31  static_assert(util::is_var_expr_v<lhs_t>);
32  static_assert(util::is_var_expr_v<rhs_t>);
33  static_assert(util::is_mat_v<lhs_t> &&
34  (util::is_vec_v<rhs_t> || util::is_mat_v<rhs_t>),
36 
37 public:
38  using value_t = std::common_type_t<
41  >;
42  using shape_t = ad::core::details::dot_shape_t<lhs_t, rhs_t>;
43  static constexpr bool has_param =
44  lhs_t::has_param || rhs_t::has_param;
45 
46  DotNode(const lhs_t& lhs,
47  const rhs_t& rhs)
48  : lhs_{lhs}
49  , rhs_{rhs}
50  {}
51 
52  template <class Func>
53  void traverse(Func&&) const {}
54 
55  auto eval() { return lhs_.eval() * rhs_.eval(); }
56  auto get() { return lhs_.get() * rhs_.get(); }
57  size_t size() const { return rows() * cols(); }
58  size_t rows() const { return lhs_.rows(); }
59  size_t cols() const { return rhs_.cols(); }
60 
61  template <class PtrPackType>
62  auto ad(const PtrPackType& pack) const
63  {
64  return ad::dot(lhs_.ad(pack),
65  rhs_.ad(pack));
66  }
67 
68  template <class PtrPackType>
69  void bind(const PtrPackType& pack)
70  {
71  if constexpr (lhs_t::has_param) {
72  lhs_.bind(pack);
73  }
74  if constexpr (rhs_t::has_param) {
75  rhs_.bind(pack);
76  }
77  }
78 
79  void activate_refcnt() const {
80  lhs_.activate_refcnt();
81  rhs_.activate_refcnt();
82  }
83 
84 private:
85  lhs_t lhs_;
86  rhs_t rhs_;
87 };
88 
89 } // namespace var
90 } // namespace expr
91 
95 template <class LHSVarExprType
96  , class RHSVarExprType
97  , class = std::enable_if_t<
98  (util::is_var_v<LHSVarExprType> ||
99  util::is_var_expr_v<LHSVarExprType>) &&
100  (util::is_var_v<RHSVarExprType> ||
101  util::is_var_expr_v<RHSVarExprType>)
102  > >
103 inline constexpr auto dot(const LHSVarExprType& lhs,
104  const RHSVarExprType& rhs)
105 {
108 
109  lhs_t wrap_lhs_expr = lhs;
110  rhs_t wrap_rhs_expr = rhs;
111 
113  wrap_lhs_expr, wrap_rhs_expr);
114 }
115 
116 } // namespace ppl
117 
118 #undef PPL_DOT_MAT_VEC
ppl::expr::var::DotNode
Definition: dot.hpp:27
PPL_DOT_MAT_VEC
#define PPL_DOT_MAT_VEC
Definition: dot.hpp:5
ppl::expr::var::DotNode::has_param
static constexpr bool has_param
Definition: dot.hpp:43
ppl::expr::var::DotNode::ad
auto ad(const PtrPackType &pack) const
Definition: dot.hpp:62
ppl::util::var_expr_traits::value_t
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
ppl::expr::var::DotNode::activate_refcnt
void activate_refcnt() const
Definition: dot.hpp:79
ppl::expr::var::DotNode::value_t
std::common_type_t< typename util::var_expr_traits< lhs_t >::value_t, typename util::var_expr_traits< rhs_t >::value_t > value_t
Definition: dot.hpp:41
ppl::expr::var::DotNode::shape_t
ad::core::details::dot_shape_t< lhs_t, rhs_t > shape_t
Definition: dot.hpp:42
ppl::expr::var::DotNode::get
auto get()
Definition: dot.hpp:56
ppl
Definition: bounded.hpp:11
ppl::util::VarExprBase
Definition: var_expr_traits.hpp:20
ppl::expr::var::DotNode::cols
size_t cols() const
Definition: dot.hpp:59
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::expr::var::DotNode::traverse
void traverse(Func &&) const
Definition: dot.hpp:53
ppl::expr::var::DotNode::bind
void bind(const PtrPackType &pack)
Definition: dot.hpp:69
ppl::expr::var::DotNode::size
size_t size() const
Definition: dot.hpp:57
traits.hpp
ppl::dot
constexpr auto dot(const LHSVarExprType &lhs, const RHSVarExprType &rhs)
Definition: dot.hpp:103
ppl::expr::var::DotNode::rows
size_t rows() const
Definition: dot.hpp:58
ppl::expr::var::DotNode::eval
auto eval()
Definition: dot.hpp:55
ppl::expr::var::DotNode::DotNode
DotNode(const lhs_t &lhs, const rhs_t &rhs)
Definition: dot.hpp:46