autoppl  v0.8
A C++ template library for probabilistic programming
param.hpp
Go to the documentation of this file.
1 #pragma once
7 #include <fastad_bits/reverse/core/var_view.hpp>
8 
9 #define PPL_PARAMVIEW_SHAPE_UNSUPPORTED \
10  "Unsupported shape for ParamView. "
11 #define PPL_PARAM_SHAPE_UNSUPPORTED \
12  "Unsupported shape for Param. "
13 
14 namespace ppl {
15 namespace details {
16 
18 {
19  size_t refcnt = 0; // total reference count
21 };
22 
23 } // namespace details
24 
41 template <class ValueType
42  , class ShapeType = ppl::scl
43  , class ConstraintType = expr::constraint::Unconstrained>
44 struct ParamView:
45  util::VarExprBase<ParamView<ValueType, ShapeType, ConstraintType>>,
46  util::ParamBase<ParamView<ValueType, ShapeType, ConstraintType>>
47 {
48  using value_t = ValueType;
49  using shape_t = ShapeType;
50  using constraint_t = ConstraintType;
52  using id_t = const void*;
53  static constexpr bool has_param = true;
54 
56  size_t rows=1,
57  size_t cols=1,
58  constraint_t c = constraint_t()) noexcept
59  : i_pack_{i_pack}
60  , transformer_{rows, cols, c}
61  , id_{this}
62  {}
63 
// Traversal hook that does nothing: the callable argument is accepted
// (unnamed) and never invoked, and the body is empty.
64  template <class Func>
65  void traverse(Func&&) const {}
66 
// Evaluates this parameter.
// First asks the transformer to run inv_transform, passing the reference
// count stored in the shared info pack, then returns a const reference to
// whatever the transformer exposes through get_c().
// NOTE(review): presumably inv_transform maps unconstrained values to
// constrained ones and uses refcnt to avoid redundant work - confirm in
// the Transformer implementation.
76  const var_t& eval() {
77  transformer_.inv_transform(i_pack_->refcnt);
78  return transformer_.get_c();
79  }
80 
// Forwards to transformer_.transform(); takes no arguments and returns nothing.
// NOTE(review): presumably the counterpart of eval() - confirm in Transformer.
86  void inv_eval() { transformer_.transform(); }
87 
// Builds and returns the transformer's inv_transform_ad(...) result for
// this parameter.
//
// pack: pointer pack taken by const reference; it is copied, never modified.
//
// The local copy has uc_val and uc_adj advanced by this parameter's
// uc_offset, c_val by c_offset and v_val by v_offset (all read from the
// shared info pack). The transformer receives the shifted copy, the
// original unshifted pack, and the current reference count.
88  template <class TPValPtrType
89  , class TPAdjPtrType>
90  auto ad(const util::PtrPack<value_t*,
91  value_t*,
92  TPValPtrType,
93  TPAdjPtrType,
94  value_t*>& pack) const {
95  auto curr_pack = pack;
96  curr_pack.uc_val += i_pack_->off_pack.uc_offset;
97  curr_pack.uc_adj += i_pack_->off_pack.uc_offset;
98  curr_pack.c_val += i_pack_->off_pack.c_offset;
99  curr_pack.v_val += i_pack_->off_pack.v_offset;
100  return transformer_.inv_transform_ad(curr_pack, pack,
101  i_pack_->refcnt);
102  }
103 
104  template <class TPValPtrType
105  , class TPAdjPtrType>
107  value_t*,
108  TPValPtrType,
109  TPAdjPtrType,
110  value_t*>& pack) const {
111  auto curr_pack = pack;
112  curr_pack.uc_val += i_pack_->off_pack.uc_offset;
113  curr_pack.uc_adj += i_pack_->off_pack.uc_offset;
114  curr_pack.c_val += i_pack_->off_pack.c_offset;
115  curr_pack.v_val += i_pack_->off_pack.v_offset;
116  return transformer_.logj_inv_transform_ad(curr_pack, pack);
117  }
118 
// Initialization hook: forwards the generator and the distribution, both
// taken by non-const reference, straight to transformer_.init(gen, dist).
// NOTE(review): presumably draws initial values from dist using gen -
// confirm in Transformer::init.
124  template <class GenType, class ContDist>
125  void init(GenType& gen, ContDist& dist)
126  { transformer_.init(gen, dist); }
127 
// Records this parameter's offsets and reserves its space in the caller's
// running offset pack.
//
// pack: in-out. Its current value is stored in the shared info pack as
//       this parameter's offsets; it is then advanced by the transformer's
//       bind_size_uc(), bind_size_c() and bind_size_v().
//
// Also resets the shared reference count to 0. The method is const, but it
// writes through the i_pack_ pointer, so the pointed-to info pack changes.
136  void activate(util::OffsetPack& pack) const {
137  i_pack_->off_pack = pack;
138  i_pack_->refcnt = 0;
139  pack.uc_offset += transformer_.bind_size_uc();
140  pack.c_offset += transformer_.bind_size_c();
141  pack.v_offset += transformer_.bind_size_v();
142  }
143 
// Increments the reference count held in the shared info pack, then passes
// the updated count to transformer_.activate_refcnt(...).
// The method is const; the increment goes through the i_pack_ pointer.
152  void activate_refcnt() const {
153  ++i_pack_->refcnt;
154  transformer_.activate_refcnt(i_pack_->refcnt);
155  }
156 
// Binds the transformer to the storage described by pack.
//
// The work is done only when both PtrPackType::uc_val_ptr_t and
// PtrPackType::c_val_ptr_t are convertible to value_t*; for any other pack
// type the if-constexpr branch is discarded and the call does nothing.
163  template <class PtrPackType>
164  void bind(const PtrPackType& pack)
165  {
// Marks pack as used so the discarded-branch case has no unused parameter.
166  static_cast<void>(pack);
167  if constexpr (std::is_convertible_v<typename PtrPackType::uc_val_ptr_t, value_t*> &&
168  std::is_convertible_v<typename PtrPackType::c_val_ptr_t, value_t*>) {
169  value_t* ucp = pack.uc_val;
170  value_t* cp = pack.c_val;
171  size_t* vp = pack.v_val;
// Local pack carrying only uc_val, c_val and v_val; the three slots in
// between are passed as nullptr.
172  util::PtrPack curr_pack(ucp, nullptr, nullptr, nullptr, cp, vp);
// Shift each pointer by this parameter's offset from the shared info pack.
173  curr_pack.uc_val += i_pack_->off_pack.uc_offset;
174  curr_pack.c_val += i_pack_->off_pack.c_offset;
175  curr_pack.v_val += i_pack_->off_pack.v_offset;
// The transformer gets both the shifted local pack and the original pack.
176  transformer_.bind(curr_pack, pack);
177  }
178  }
179 
// Accessors. get() (mutable and const overloads) returns the transformer's
// get_c() object by reference.
180  var_t& get() { return transformer_.get_c(); }
181  const var_t& get() const { return transformer_.get_c(); }
// size(), rows() and cols() report the transformer's size_c(), rows_c()
// and cols_c() respectively.
182  constexpr size_t size() const { return transformer_.size_c(); }
183  constexpr size_t rows() const { return transformer_.rows_c(); }
184  constexpr size_t cols() const { return transformer_.cols_c(); }
// Returns the stored id_ (a const void*, per the id_t alias).
185  id_t id() const { return id_; }
186 
187  // API specific to ParamView
// Non-const overload: returns a reference to the offset pack stored in the
// shared info pack, so callers can modify it in place.
188  auto& offset() { return i_pack_->off_pack; }
// Const overload: the return type is plain auto, so this returns a copy.
189  auto offset() const { return i_pack_->off_pack; }
// Forward to the transformer's size_uc() and size_c().
190  constexpr size_t size_uc() const { return transformer_.size_uc(); }
191  constexpr size_t size_c() const { return transformer_.size_c(); }
192 
193 private:
194  details::ParamInfoPack* const i_pack_;
196  const id_t id_;
197 };
198 
208 template <class ValueType
209  , class ShapeType = ppl::scl
210  , class ConstraintType = expr::constraint::Unconstrained>
211 struct Param
212 {
213  static_assert(util::is_shape_v<ShapeType>,
215 };
216 
217 template <class ValueType, class ConstraintType>
218 struct Param<ValueType, ppl::scl, ConstraintType>:
219  ParamView<ValueType, ppl::scl, ConstraintType>,
220  util::ParamBase<Param<ValueType, ppl::scl, ConstraintType>>
221 {
223  using typename base_t::constraint_t;
224 
225  Param(size_t=1,
226  size_t=1,
227  constraint_t c = constraint_t()) noexcept
228  : base_t(&i_pack_, 1, 1, c)
229  {}
230 
231 private:
232  details::ParamInfoPack i_pack_;
233 };
234 
235 template <class ValueType, class ConstraintType>
236 struct Param<ValueType, ppl::vec, ConstraintType> :
237  ParamView<ValueType, ppl::vec, ConstraintType>,
238  util::ParamBase<Param<ValueType, ppl::vec, ConstraintType>>
239 {
241  using typename base_t::constraint_t;
242 
243  Param(size_t n,
244  size_t=1,
246  : base_t(&i_pack_, n, 1, c)
247  {}
248 
249 private:
250  details::ParamInfoPack i_pack_;
251 };
252 
253 template <class ValueType, class ConstraintType>
254 struct Param<ValueType, ppl::mat, ConstraintType> :
255  ParamView<ValueType, ppl::mat, ConstraintType>,
256  util::ParamBase<Param<ValueType, ppl::mat, ConstraintType>>
257 {
259  using typename base_t::constraint_t;
260 
261  Param(size_t rows, size_t cols, constraint_t c = constraint_t())
262  : base_t(&i_pack_, rows, cols, c)
263  {}
264 
265 private:
266  details::ParamInfoPack i_pack_;
267 };
268 
269 
270 // Helper function to create a Param object and deduce the constraint expression.
271 
272 template <class ValueType
273  , class ShapeType = scl
274  , class ConstraintType = expr::constraint::Unconstrained>
275 constexpr inline auto make_param(size_t rows,
276  size_t cols,
277  const ConstraintType& c = ConstraintType())
278 {
280 }
281 
282 template <class ValueType
283  , class ShapeType = scl
284  , class ConstraintType = expr::constraint::Unconstrained>
285 constexpr inline auto make_param(size_t rows,
286  const ConstraintType& c = ConstraintType())
287 {
289 }
290 
291 template <class ValueType
292  , class ShapeType = scl
293  , class ConstraintType = expr::constraint::Unconstrained>
294 constexpr inline auto make_param(const ConstraintType& c = ConstraintType())
295 {
297 }
298 
299 } // namespace ppl
300 
301 #undef PPL_PARAMVIEW_SHAPE_UNSUPPORTED
302 #undef PPL_PARAM_SHAPE_UNSUPPORTED
ppl::Param< ValueType, ppl::vec, ConstraintType >::Param
Param(size_t n, size_t=1, constraint_t c=constraint_t())
Definition: param.hpp:243
ppl::util::OffsetPack::c_offset
index_t c_offset
Definition: offset_pack.hpp:12
ppl::util::ParamBase
Definition: var_traits.hpp:16
ppl::ParamView::activate_refcnt
void activate_refcnt() const
Definition: param.hpp:152
ppl::expr::constraint::Transformer< value_t, shape_t, constraint_t >
var_traits.hpp
ppl::Param
Definition: param.hpp:212
ppl::ParamView::size_uc
constexpr size_t size_uc() const
Definition: param.hpp:190
ppl::ParamView::id_t
const void * id_t
Definition: param.hpp:52
ppl::ParamView< ValueType, ppl::vec, ConstraintType >::var_t
util::var_t< value_t, shape_t > var_t
Definition: param.hpp:51
ppl::util::PtrPack::c_val
CValPtrType c_val
Definition: ptr_pack.hpp:37
ppl::make_param
constexpr auto make_param(size_t rows, size_t cols, const ConstraintType &c=ConstraintType())
Definition: param.hpp:275
ppl::ParamView::offset
auto & offset()
Definition: param.hpp:188
ppl::ParamView::traverse
void traverse(Func &&) const
Definition: param.hpp:65
ppl::ParamView< ValueType, ppl::vec, ConstraintType >::shape_t
ppl::vec shape_t
Definition: param.hpp:49
ppl::details::ParamInfoPack::off_pack
util::OffsetPack off_pack
Definition: param.hpp:20
shape_traits.hpp
ppl::ParamView::eval
const var_t & eval()
Definition: param.hpp:76
PPL_PARAM_SHAPE_UNSUPPORTED
#define PPL_PARAM_SHAPE_UNSUPPORTED
Definition: param.hpp:11
ppl::vec
ad::vec vec
Definition: shape_traits.hpp:17
ppl::expr::constraint::Unconstrained
Definition: unconstrained.hpp:11
ppl::ParamView::init
void init(GenType &gen, ContDist &dist)
Definition: param.hpp:125
ptr_pack.hpp
offset_pack.hpp
ppl::ParamView::offset
auto offset() const
Definition: param.hpp:189
unconstrained.hpp
ppl::util::OffsetPack::v_offset
index_t v_offset
Definition: offset_pack.hpp:13
ppl::ParamView::size_c
constexpr size_t size_c() const
Definition: param.hpp:191
ppl::ParamView::ad
auto ad(const util::PtrPack< value_t *, value_t *, TPValPtrType, TPAdjPtrType, value_t * > &pack) const
Definition: param.hpp:90
ppl::util::cols
constexpr size_t cols(const T &x)
Definition: value.hpp:42
ppl::ParamView::logj_ad
auto logj_ad(const util::PtrPack< value_t *, value_t *, TPValPtrType, TPAdjPtrType, value_t * > &pack) const
Definition: param.hpp:106
ppl::scl
ad::scl scl
Definition: shape_traits.hpp:16
ppl::details::ParamInfoPack
Definition: param.hpp:18
ppl::ParamView::has_param
static constexpr bool has_param
Definition: param.hpp:53
ppl::ParamView::ParamView
ParamView(details::ParamInfoPack *i_pack, size_t rows=1, size_t cols=1, constraint_t c=constraint_t()) noexcept
Definition: param.hpp:55
ppl::ParamView::cols
constexpr size_t cols() const
Definition: param.hpp:184
ppl::util::PtrPack::uc_val
UCValPtrType uc_val
Definition: ptr_pack.hpp:33
ppl::util::var_t
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
ppl::ParamView::inv_eval
void inv_eval()
Definition: param.hpp:86
ppl::ParamView::activate
void activate(util::OffsetPack &pack) const
Definition: param.hpp:136
ppl::Param< ValueType, ppl::scl, ConstraintType >::Param
Param(size_t=1, size_t=1, constraint_t c=constraint_t()) noexcept
Definition: param.hpp:225
ppl::ParamView::get
const var_t & get() const
Definition: param.hpp:181
ppl::ParamView::rows
constexpr size_t rows() const
Definition: param.hpp:183
ppl::util::OffsetPack::uc_offset
index_t uc_offset
Definition: offset_pack.hpp:11
ppl::ParamView::size
constexpr size_t size() const
Definition: param.hpp:182
ppl::ParamView< ValueType, ppl::vec, ConstraintType >::constraint_t
ConstraintType constraint_t
Definition: param.hpp:50
ppl
Definition: bounded.hpp:11
ppl::util::VarExprBase
Definition: var_expr_traits.hpp:20
ppl::ParamView< ValueType, ppl::vec, ConstraintType >::value_t
ValueType value_t
Definition: param.hpp:48
ppl::mat
ad::mat mat
Definition: shape_traits.hpp:18
ppl::Param< ValueType, ppl::mat, ConstraintType >::Param
Param(size_t rows, size_t cols, constraint_t c=constraint_t())
Definition: param.hpp:261
ppl::ParamView
Definition: param.hpp:47
ppl::ParamView::get
var_t & get()
Definition: param.hpp:180
ppl::util::rows
constexpr size_t rows(const T &x)
Definition: value.hpp:32
ppl::util::OffsetPack
Definition: offset_pack.hpp:9
ppl::details::ParamInfoPack::refcnt
size_t refcnt
Definition: param.hpp:19
ppl::util::PtrPack
Definition: ptr_pack.hpp:14
ppl::ParamView::id
id_t id() const
Definition: param.hpp:185
ppl::ParamView::bind
void bind(const PtrPackType &pack)
Definition: param.hpp:164
ppl::util::PtrPack::v_val
size_t * v_val
Definition: ptr_pack.hpp:38