autoppl  v0.8
A C++ template library for probabilistic programming
tparam.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <autoppl/util/value.hpp>
6 #include <fastad_bits/reverse/core/var_view.hpp>
7 
8 namespace ppl {
9 
10 // forward declaration
11 namespace expr {
12 namespace var {
13 
14 template <class Op
15  , class TParamViewType
16  , class VarExprType>
17 struct OpEqNode;
18 struct Eq;
19 
20 } // namespace var
21 } // namespace expr
22 
23 namespace details {
24 
26 {
28 };
29 
30 } // namespace details
31 
32 template <class Derived>
34 
35 template <class ValueType
36  , class ShapeType>
37 struct TParamViewBase<TParamView<ValueType, ShapeType>>
38 {
39  using value_t = ValueType;
40  using shape_t = ShapeType;
43  using id_t = const void*;
44  static constexpr bool has_param = true;
45 
47  size_t rows=1,
48  size_t cols=1) noexcept
49  : i_pack_(i_pack)
50  , var_(util::make_val<value_t, shape_t>(rows, cols))
51  , id_(this)
52  {}
53 
54  template <class VarExprType
55  , class = std::enable_if_t<
56  util::is_valid_op_param_v<VarExprType>
57  > >
58  auto operator=(const VarExprType& expr) const
59  {
61  expr_t wrap_expr = expr;
63  static_cast<const derived_t&>(*this), wrap_expr);
64  }
65 
66  template <class Func>
67  void traverse(Func&&) const {}
68 
69  const var_t& eval() { return get(); }
70 
71  template <class UCValPtrType
72  , class UCAdjPtrType
73  , class CValPtrType>
74  auto ad(const util::PtrPack<UCValPtrType,
75  UCAdjPtrType,
76  value_t*,
77  value_t*,
78  CValPtrType>& pack) const {
79  return ad::VarView<value_t, shape_t>(pack.tp_val + i_pack_->off_pack.tp_offset,
80  pack.tp_adj + i_pack_->off_pack.tp_offset,
81  rows(), cols());
82  }
83 
84  void activate(util::OffsetPack& pack) const {
85  i_pack_->off_pack = pack;
86  pack.tp_offset += size();
87  }
88 
89  void activate_refcnt() const {}
90 
91  template <class PtrPackType>
92  void bind(const PtrPackType& pack)
93  {
94  static_cast<void>(pack);
95  if constexpr (std::is_convertible_v<typename PtrPackType::tp_val_ptr_t, value_t*>) {
96  value_t* tcp = pack.tp_val;
97  util::bind(var_, tcp + i_pack_->off_pack.tp_offset,
98  rows(), cols());
99  }
100  }
101 
102  var_t& get() { return util::get(var_); }
103  const var_t& get() const { return util::get(var_); }
104  constexpr size_t size() const { return util::size(var_); }
105  constexpr size_t rows() const { return util::rows(var_); }
106  constexpr size_t cols() const { return util::cols(var_); }
107  id_t id() const { return id_; }
108 
109 protected:
110  using view_t = ad::util::shape_to_raw_view_t<value_t, shape_t>;
113  const id_t id_;
114 };
115 
116 template <class ValueType
117  , class ShapeType = ppl::scl>
118 struct TParamView:
119  TParamViewBase<TParamView<ValueType, ShapeType>>,
120  util::VarExprBase<TParamView<ValueType, ShapeType>>,
121  util::TParamBase<TParamView<ValueType, ShapeType>>
122 {
124  using base_t::operator=;
125 
127  size_t rows=1,
128  size_t cols=1) noexcept
129  : base_t(i_pack, rows, cols)
130  {}
131 };
132 
133 template <class ValueType>
134 struct TParamView<ValueType, scl>:
135  TParamViewBase<TParamView<ValueType, scl>>,
136  util::VarExprBase<TParamView<ValueType, scl>>,
137  util::TParamBase<TParamView<ValueType, scl>>
138 {
140  using base_t::operator=;
141 
143  size_t rows=1,
144  size_t cols=1,
145  size_t rel_offset = 0) noexcept
146  : base_t(i_pack, rows, cols)
147  , rel_offset_(rel_offset)
148  {}
149 
150  template <class UCValPtrType
151  , class UCAdjPtrType
152  , class CValPtrType>
153  auto ad(const util::PtrPack<UCValPtrType,
154  UCAdjPtrType,
155  typename base_t::value_t*,
156  typename base_t::value_t*,
157  CValPtrType>& pack) const {
158  base_t::i_pack_->off_pack.tp_offset += rel_offset_;
159  auto&& res = base_t::ad(pack);
160  base_t::i_pack_->off_pack.tp_offset -= rel_offset_;
161  return res;
162  }
163 
164  template <class PtrPackType>
165  void bind(const PtrPackType& pack)
166  {
167  base_t::i_pack_->off_pack.tp_offset += rel_offset_;
168  base_t::bind(pack);
169  base_t::i_pack_->off_pack.tp_offset -= rel_offset_;
170  }
171 
172 private:
173  size_t rel_offset_;
174 };
175 
176 template <class ValueType>
177 struct TParamView<ValueType, vec>:
178  TParamViewBase<TParamView<ValueType, vec>>,
179  util::VarExprBase<TParamView<ValueType, vec>>,
180  util::TParamBase<TParamView<ValueType, vec>>
181 {
183  using base_t::operator=;
184 
186  size_t rows,
187  size_t cols=1) noexcept
188  : base_t(i_pack, rows, cols)
189  {}
190 
191  auto operator[](size_t i) const {
194  }
195 };
196 
197 template <class ValueType
198  , class ShapeType = ppl::scl>
199 struct TParam;
200 
201 template <class ValueType>
202 struct TParam<ValueType, ppl::scl>:
203  TParamView<ValueType, ppl::scl>,
204  util::TParamBase<TParam<ValueType, ppl::scl>>
205 {
207  using base_t::operator=;
208 
209  TParam() noexcept
210  : base_t(&i_pack_, 1, 1)
211  {}
212 
213 private:
214  details::TParamInfoPack i_pack_;
215 };
216 
217 template <class ValueType>
218 struct TParam<ValueType, ppl::vec> :
219  TParamView<ValueType, ppl::vec>,
220  util::TParamBase<TParam<ValueType, ppl::vec>>
221 {
223  using typename base_t::value_t;
224  using base_t::operator=;
225 
226  TParam(size_t n)
227  : base_t(&i_pack_, n, 1)
228  {}
229 
230 private:
231  details::TParamInfoPack i_pack_;
232 };
233 
234 template <class ValueType>
235 struct TParam<ValueType, ppl::mat> :
236  TParamView<ValueType, ppl::mat>,
237  util::TParamBase<TParam<ValueType, ppl::mat>>
238 {
240  using base_t::operator=;
241 
242  TParam(size_t rows, size_t cols)
243  : base_t(&i_pack_, rows, cols)
244  {}
245 
246 private:
247  details::TParamInfoPack i_pack_;
248 };
249 
250 } // namespace ppl
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::size
constexpr size_t size() const
Definition: tparam.hpp:104
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::ad
auto ad(const util::PtrPack< UCValPtrType, UCAdjPtrType, value_t *, value_t *, CValPtrType > &pack) const
Definition: tparam.hpp:74
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::TParamViewBase
TParamViewBase(details::TParamInfoPack *i_pack, size_t rows=1, size_t cols=1) noexcept
Definition: tparam.hpp:46
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::bind
void bind(const PtrPackType &pack)
Definition: tparam.hpp:92
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::var_t
util::var_t< value_t, shape_t > var_t
Definition: tparam.hpp:42
ppl::util::get
auto & get(T &&x)
Definition: value.hpp:52
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::activate_refcnt
void activate_refcnt() const
Definition: tparam.hpp:89
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::cols
constexpr size_t cols() const
Definition: tparam.hpp:106
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::rows
constexpr size_t rows() const
Definition: tparam.hpp:105
ppl::TParamView< ValueType, scl >::TParamView
TParamView(details::TParamInfoPack *i_pack, size_t rows=1, size_t cols=1, size_t rel_offset=0) noexcept
Definition: tparam.hpp:142
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::view_t
ad::util::shape_to_raw_view_t< value_t, shape_t > view_t
Definition: tparam.hpp:110
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::value_t
ValueType value_t
Definition: tparam.hpp:39
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::id_t
const void * id_t
Definition: tparam.hpp:43
ppl::vec
ad::vec vec
Definition: shape_traits.hpp:17
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::get
var_t & get()
Definition: tparam.hpp:102
ptr_pack.hpp
offset_pack.hpp
ppl::TParamView< ValueType, vec >::TParamView
TParamView(details::TParamInfoPack *i_pack, size_t rows, size_t cols=1) noexcept
Definition: tparam.hpp:185
ppl::details::TParamInfoPack::off_pack
util::OffsetPack off_pack
Definition: tparam.hpp:27
ppl::util::bind
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::shape_t
ShapeType shape_t
Definition: tparam.hpp:40
ppl::TParamView< ValueType, vec >::operator[]
auto operator[](size_t i) const
Definition: tparam.hpp:191
ppl::util::cols
constexpr size_t cols(const T &x)
Definition: value.hpp:42
ppl::scl
ad::scl scl
Definition: shape_traits.hpp:16
ppl::TParamView
Definition: tparam.hpp:122
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::traverse
void traverse(Func &&) const
Definition: tparam.hpp:67
ppl::util::size
constexpr size_t size(const T &x)
Definition: value.hpp:22
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::i_pack_
details::TParamInfoPack *const i_pack_
Definition: tparam.hpp:111
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::eval
const var_t & eval()
Definition: tparam.hpp:69
ppl::TParamView< ValueType, scl >::bind
void bind(const PtrPackType &pack)
Definition: tparam.hpp:165
ppl::TParam< ValueType, ppl::mat >::TParam
TParam(size_t rows, size_t cols)
Definition: tparam.hpp:242
ppl::TParam
Definition: tparam.hpp:199
ppl::util::var_t
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
value.hpp
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::get
const var_t & get() const
Definition: tparam.hpp:103
ppl::TParamView< ValueType, scl >::ad
auto ad(const util::PtrPack< UCValPtrType, UCAdjPtrType, typename base_t::value_t *, typename base_t::value_t *, CValPtrType > &pack) const
Definition: tparam.hpp:153
ppl::util::OffsetPack::tp_offset
index_t tp_offset
Definition: offset_pack.hpp:14
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::activate
void activate(util::OffsetPack &pack) const
Definition: tparam.hpp:84
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::id
id_t id() const
Definition: tparam.hpp:107
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::var_
view_t var_
Definition: tparam.hpp:112
ppl::expr::var::OpEqNode
Definition: op_eq.hpp:78
ppl::TParam< ValueType, ppl::vec >::TParam
TParam(size_t n)
Definition: tparam.hpp:226
ppl
Definition: bounded.hpp:11
ppl::TParamViewBase< TParamView< ValueType, ShapeType > >::id_
const id_t id_
Definition: tparam.hpp:113
ppl::util::VarExprBase
Definition: var_expr_traits.hpp:20
ppl::mat
ad::mat mat
Definition: shape_traits.hpp:18
ppl::TParamViewBase
Definition: tparam.hpp:33
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::util::TParamBase
Definition: var_traits.hpp:20
ppl::details::TParamInfoPack
Definition: tparam.hpp:26
ppl::TParam< ValueType, ppl::scl >::TParam
TParam() noexcept
Definition: tparam.hpp:209
ppl::util::rows
constexpr size_t rows(const T &x)
Definition: value.hpp:32
ppl::util::OffsetPack
Definition: offset_pack.hpp:9
ppl::util::PtrPack
Definition: ptr_pack.hpp:14
traits.hpp
ppl::TParamView::TParamView
TParamView(details::TParamInfoPack *i_pack, size_t rows=1, size_t cols=1) noexcept
Definition: tparam.hpp:126
ppl::util::make_val
constexpr auto make_val(size_t rows=1, size_t cols=1)
Definition: value.hpp:9