autoppl  v0.8
A C++ template library for probabilistic programming
constant.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/reverse/core/constant.hpp>
4 
5 #define PPL_CONSTANT_SHAPE_UNSUPPORTED \
6  "Unsupported shape for constants. "
7 
8 namespace ppl {
9 namespace expr {
10 namespace var {
11 
12 template <class ValueType
13  , class ShapeType=ppl::scl>
14 struct Constant
15 {
16  static_assert(!util::is_shape_v<ShapeType>,
18 };
19 
20 template <class ValueType>
21 struct Constant<ValueType, ppl::scl>:
22  util::VarExprBase<Constant<ValueType, ppl::scl>>
23 {
24  using value_t = ValueType;
25  using shape_t = ppl::scl;
26  static constexpr bool has_param = false;
27 
28  Constant(value_t c) : c_{c} {}
29 
30  template <class Func>
31  void traverse(Func&&) const {}
32 
33  value_t eval() const { return c_; }
34  value_t get() const { return c_; }
35  constexpr size_t size() const { return 1; }
36 
37  template <class PtrPackType>
38  auto ad(const PtrPackType&) const
39  { return ad::constant(c_); }
40 
41  void activate_refcnt() const {}
42 
43 private:
44  value_t c_;
45 };
46 
47 template <class ValueType>
48 struct Constant<ValueType, ppl::vec>:
49  util::VarExprBase<Constant<ValueType, ppl::vec>>
50 {
51  using value_t = ValueType;
52  using shape_t = ppl::vec;
53  static constexpr bool has_param = false;
54 
55  template <class T>
56  Constant(const Eigen::EigenBase<T>& c) : c_{c} {}
57 
58  template <class Func>
59  void traverse(Func&&) const {}
60 
61  const auto& eval() const { return c_; }
62  const auto& get() const { return c_; }
63  size_t size() const { return c_.size(); }
64  size_t rows() const { return c_.rows(); }
65  constexpr size_t cols() const { return 1; }
66 
67  template <class PtrPackType>
68  auto ad(const PtrPackType&) const
69  { return ad::constant_view(c_.data(), rows()); }
70 
71  void activate_refcnt() const {}
72 
73 private:
74  Eigen::Matrix<value_t, Eigen::Dynamic, 1> c_;
75 };
76 
77 template <class ValueType>
78 struct Constant<ValueType, ppl::mat>:
79  util::VarExprBase<Constant<ValueType, ppl::mat>>
80 {
81  using value_t = ValueType;
82  using shape_t = ppl::mat;
83  static constexpr bool has_param = false;
84 
85  template <class T>
86  Constant(const Eigen::EigenBase<T>& c) : c_{c} {}
87 
88  template <class Func>
89  void traverse(Func&&) const {}
90 
91  const auto& eval() const { return c_; }
92  const auto& get() const { return c_; }
93  size_t size() const { return c_.size(); }
94  size_t rows() const { return c_.rows(); }
95  size_t cols() const { return c_.cols(); }
96 
97  template <class PtrPackType>
98  auto ad(const PtrPackType&) const
99  { return ad::constant_view(c_.data(), rows(), cols()); }
100 
101  void activate_refcnt() const {}
102 
103 private:
104  Eigen::Matrix<value_t, Eigen::Dynamic, Eigen::Dynamic> c_;
105 };
106 
107 } // namespace var
108 } // namespace expr
109 } // namespace ppl
110 
// Undefine the message macro defined at the top of this header so it does
// not leak into translation units that include it. The previous #undefs
// named PPL_CONSTANT_VEC_UNSUPPORTED / PPL_CONSTANT_MAT_UNSUPPORTED, which
// this header never defines, so the real macro was never cleaned up.
#undef PPL_CONSTANT_SHAPE_UNSUPPORTED
ppl::expr::var::Constant< ValueType, ppl::mat >::activate_refcnt
void activate_refcnt() const
Definition: constant.hpp:101
ppl::expr::var::Constant< ValueType, ppl::mat >::shape_t
ppl::mat shape_t
Definition: constant.hpp:82
ppl::expr::var::Constant< ValueType, ppl::mat >::traverse
void traverse(Func &&) const
Definition: constant.hpp:89
ppl::expr::var::Constant< ValueType, ppl::scl >::activate_refcnt
void activate_refcnt() const
Definition: constant.hpp:41
ppl::expr::var::Constant< ValueType, ppl::vec >::rows
size_t rows() const
Definition: constant.hpp:64
var_expr_traits.hpp
ppl::expr::var::Constant< ValueType, ppl::mat >::cols
size_t cols() const
Definition: constant.hpp:95
ppl::expr::var::Constant< ValueType, ppl::vec >::cols
constexpr size_t cols() const
Definition: constant.hpp:65
ppl::expr::var::Constant< ValueType, ppl::mat >::size
size_t size() const
Definition: constant.hpp:93
ppl::expr::var::Constant< ValueType, ppl::scl >::ad
auto ad(const PtrPackType &) const
Definition: constant.hpp:38
ppl::expr::var::Constant< ValueType, ppl::scl >::eval
value_t eval() const
Definition: constant.hpp:33
ppl::expr::var::Constant< ValueType, ppl::vec >::activate_refcnt
void activate_refcnt() const
Definition: constant.hpp:71
ppl::expr::var::Constant< ValueType, ppl::vec >::traverse
void traverse(Func &&) const
Definition: constant.hpp:59
ppl::vec
ad::vec vec
Definition: shape_traits.hpp:17
ppl::expr::var::Constant< ValueType, ppl::scl >::value_t
ValueType value_t
Definition: constant.hpp:24
ppl::expr::var::Constant< ValueType, ppl::mat >::value_t
ValueType value_t
Definition: constant.hpp:81
ppl::expr::var::Constant< ValueType, ppl::mat >::eval
const auto & eval() const
Definition: constant.hpp:91
ppl::expr::var::Constant< ValueType, ppl::mat >::Constant
Constant(const Eigen::EigenBase< T > &c)
Definition: constant.hpp:86
ppl::expr::var::Constant< ValueType, ppl::scl >::get
value_t get() const
Definition: constant.hpp:34
ppl::util::cols
constexpr size_t cols(const T &x)
Definition: value.hpp:42
ppl::expr::var::Constant< ValueType, ppl::vec >::Constant
Constant(const Eigen::EigenBase< T > &c)
Definition: constant.hpp:56
ppl::expr::var::Constant
Definition: constant.hpp:15
ppl::scl
ad::scl scl
Definition: shape_traits.hpp:16
ppl::expr::var::Constant< ValueType, ppl::mat >::ad
auto ad(const PtrPackType &) const
Definition: constant.hpp:98
ppl::expr::var::Constant< ValueType, ppl::mat >::get
const auto & get() const
Definition: constant.hpp:92
ppl::expr::var::Constant< ValueType, ppl::scl >::Constant
Constant(value_t c)
Definition: constant.hpp:28
ppl::expr::var::Constant< ValueType, ppl::vec >::shape_t
ppl::vec shape_t
Definition: constant.hpp:52
ppl::expr::var::Constant< ValueType, ppl::vec >::eval
const auto & eval() const
Definition: constant.hpp:61
ppl::expr::var::Constant< ValueType, ppl::mat >::rows
size_t rows() const
Definition: constant.hpp:94
ppl::expr::var::Constant< ValueType, ppl::scl >::traverse
void traverse(Func &&) const
Definition: constant.hpp:31
ppl::expr::var::Constant< ValueType, ppl::vec >::size
size_t size() const
Definition: constant.hpp:63
ppl::expr::var::Constant< ValueType, ppl::vec >::value_t
ValueType value_t
Definition: constant.hpp:51
PPL_CONSTANT_SHAPE_UNSUPPORTED
#define PPL_CONSTANT_SHAPE_UNSUPPORTED
Definition: constant.hpp:5
ppl::expr::var::Constant< ValueType, ppl::vec >::ad
auto ad(const PtrPackType &) const
Definition: constant.hpp:68
ppl::expr::var::Constant< ValueType, ppl::scl >::size
constexpr size_t size() const
Definition: constant.hpp:35
ppl
Definition: bounded.hpp:11
ppl::util::VarExprBase
Definition: var_expr_traits.hpp:20
ppl::mat
ad::mat mat
Definition: shape_traits.hpp:18
ppl::expr::var::Constant< ValueType, ppl::scl >::shape_t
ppl::scl shape_t
Definition: constant.hpp:25
ppl::util::rows
constexpr size_t rows(const T &x)
Definition: value.hpp:32
ppl::expr::var::Constant< ValueType, ppl::vec >::get
const auto & get() const
Definition: constant.hpp:62