autoppl v0.8
A C++ template library for probabilistic programming
unary.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/reverse/core/unary.hpp>
4 
5 namespace ppl {
6 namespace expr {
7 namespace var {
8 
19 template <class UnaryOp
20  , class VarExprType>
21 struct UnaryNode:
22  util::VarExprBase<UnaryNode<UnaryOp, VarExprType>>
23 {
24 private:
25  using expr_t = VarExprType;
26 
27  static_assert(util::is_var_expr_v<expr_t>);
28 
29 public:
32  static constexpr bool has_param = expr_t::has_param;
33 
34  UnaryNode(const expr_t& expr)
35  : expr_{expr}
36  {}
37 
38  template <class Func>
39  void traverse(Func&&) const {}
40 
41  auto get() const {
42  return eval_helper(expr_.get());
43  }
44 
45  auto eval() {
46  return eval_helper(expr_.eval());
47  }
48 
49  constexpr size_t size() const { return expr_.size(); }
50  constexpr size_t rows() const { return expr_.rows(); }
51  constexpr size_t cols() const { return expr_.cols(); }
52 
53  template <class PtrPackType>
54  auto ad(const PtrPackType& pack) const
55  {
56  return UnaryOp::fmap(expr_.ad(pack));
57  }
58 
59  template <class PtrPackType>
60  void bind(const PtrPackType& pack)
61  {
62  if constexpr (expr_t::has_param) {
63  expr_.bind(pack);
64  }
65  }
66 
67  void activate_refcnt() const {
68  expr_.activate_refcnt();
69  }
70 
71 private:
72 
73  template <class T>
74  auto eval_helper(const T& x) const {
75  if constexpr (util::is_scl_v<expr_t>) {
76  return UnaryOp::fmap(x);
77  } else {
78  return UnaryOp::fmap(x.array()).matrix();
79  }
80  }
81 
82  expr_t expr_;
83 };
84 
85 } // namespace var
86 } // namespace expr
87 
/*
 * PPL_UNARY_FUNC(name, strct) defines the function template `name`. It
 * wraps its argument in expr::var::UnaryNode<ad::core::strct, ...>.
 *
 * The generated overload takes part in overload resolution only for argument
 * types that satisfy util::is_valid_op_param_v and are not arithmetic types.
 * NOTE(review): excluding arithmetic types presumably leaves plain numbers
 * to the std:: math overloads; confirm intent.
 *
 * The argument is first copy-initialized into util::convert_to_param_t<ExprType>.
 * That converted value is then passed to the node's constructor.
 */
#define PPL_UNARY_FUNC(name, strct)                                          \
    template <class ExprType                                                 \
            , class = std::enable_if_t<                                      \
                util::is_valid_op_param_v<ExprType> &&                       \
                !std::is_arithmetic_v<ExprType>                              \
            > >                                                              \
    inline constexpr auto name(const ExprType& expr)                         \
    {                                                                        \
        using param_t = util::convert_to_param_t<ExprType>;                 \
        using node_t = expr::var::UnaryNode<ad::core::strct, param_t>;      \
        param_t operand = expr;                                              \
        return node_t(operand);                                              \
    }
101 
102 PPL_UNARY_FUNC(sin, Sin)
103 PPL_UNARY_FUNC(cos, Cos)
104 PPL_UNARY_FUNC(tan, Tan)
105 PPL_UNARY_FUNC(asin, Arcsin)
106 PPL_UNARY_FUNC(acos, Arccos)
107 PPL_UNARY_FUNC(atan, Arctan)
108 PPL_UNARY_FUNC(exp, Exp)
109 PPL_UNARY_FUNC(log, Log)
110 PPL_UNARY_FUNC(sqrt, Sqrt)
111 
112 } // namespace ppl
ppl::expr::var::UnaryNode::UnaryNode
UnaryNode(const expr_t &expr)
Definition: unary.hpp:34
PPL_UNARY_FUNC
#define PPL_UNARY_FUNC(name, strct)
Definition: unary.hpp:88
ppl::expr::var::UnaryNode::bind
void bind(const PtrPackType &pack)
Definition: unary.hpp:60
ppl::expr::var::UnaryNode::ad
auto ad(const PtrPackType &pack) const
Definition: unary.hpp:54
ppl::util::var_expr_traits::value_t
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
ppl::expr::var::UnaryNode::shape_t
typename util::shape_traits< expr_t >::shape_t shape_t
Definition: unary.hpp:31
ppl::util::shape_traits
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
ppl::expr::var::UnaryNode::value_t
typename util::var_expr_traits< expr_t >::value_t value_t
Definition: unary.hpp:30
ppl::expr::var::UnaryNode
Definition: unary.hpp:23
ppl::expr::var::UnaryNode::cols
constexpr size_t cols() const
Definition: unary.hpp:51
ppl::expr::var::UnaryNode::has_param
static constexpr bool has_param
Definition: unary.hpp:32
ppl::expr::var::UnaryNode::traverse
void traverse(Func &&) const
Definition: unary.hpp:39
ppl
Definition: bounded.hpp:11
ppl::util::VarExprBase
Definition: var_expr_traits.hpp:20
ppl::expr::var::UnaryNode::get
auto get() const
Definition: unary.hpp:41
ppl::expr::var::UnaryNode::eval
auto eval()
Definition: unary.hpp:45
ppl::expr::var::UnaryNode::rows
constexpr size_t rows() const
Definition: unary.hpp:50
traits.hpp
ppl::expr::var::UnaryNode::size
constexpr size_t size() const
Definition: unary.hpp:49
ppl::expr::var::UnaryNode::activate_refcnt
void activate_refcnt() const
Definition: unary.hpp:67