autoppl v0.8
A C++ template library for probabilistic programming
bounded.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <cstddef>
3 #include <cmath>
4 #include <fastad_bits/util/shape_traits.hpp>
5 #include <fastad_bits/reverse/core/var_view.hpp>
9 #include <autoppl/util/value.hpp>
10 
11 namespace ppl {
12 namespace expr {
13 namespace constraint {
14 
15 template <class LowerType, class UpperType>
16 struct Bounded
17 {
18 private:
19  using lower_t = LowerType;
20  using upper_t = UpperType;
21 
22  static_assert(util::is_var_expr_v<lower_t>);
23  static_assert(util::is_var_expr_v<upper_t>);
24  static_assert(util::is_scl_v<lower_t> ||
25  util::is_scl_v<upper_t> ||
26  std::is_same_v<
29 
30 public:
31 
32  Bounded(const lower_t& lower,
33  const upper_t& upper)
34  : lower_{lower}
35  , upper_{upper}
36  {}
37 
41  template <class T>
42  constexpr void transform(const T& c,
43  T& uc) const
44  {
45  if constexpr (std::is_arithmetic_v<T>) {
46  uc = std::log((c - lower_.get()) / (upper_.get() - c));
47  } else {
48  auto ca = c.array();
49  auto alower = util::to_array(lower_.get());
50  auto aupper = util::to_array(upper_.get());
51  uc = (ca - alower) / (aupper - ca).log().matrix();
52  }
53  }
54 
58  template <class T>
59  constexpr void inv_transform(const T& uc,
60  T& c)
61  { ad::boost::bounded_inv_transform(uc, lower_.eval(), upper_.eval(), c); }
62 
63  template <class UCViewType
64  , class CurrPtrPackType
65  , class PtrPackType>
66  auto inv_transform_ad(const UCViewType& uc_view,
67  const CurrPtrPackType& curr_pack,
68  const PtrPackType& pack,
69  size_t refcnt) const
70  {
71  auto&& lower = lower_.ad(pack);
72  auto&& upper = upper_.ad(pack);
73  return ad::boost::BoundedInvTransformNode(uc_view,
74  lower,
75  upper,
76  curr_pack.c_val,
77  curr_pack.v_val,
78  refcnt);
79  }
80 
81  template <class UCViewType
82  , class CurrPtrPack
83  , class PtrPack>
84  auto logj_inv_transform_ad(const UCViewType& uc_view,
85  const CurrPtrPack& curr_pack,
86  const PtrPack& pack) const
87  {
88  auto lower = lower_.ad(pack);
89  auto upper = upper_.ad(pack);
91  lower,
92  upper,
93  curr_pack.c_val);
94  }
95 
96  void activate_refcnt() const {
97  lower_.activate_refcnt();
98  upper_.activate_refcnt();
99  }
100 
101  template <class PtrPack>
102  void bind(const PtrPack& pack) {
103  if constexpr (lower_t::has_param) {
104  lower_.bind(pack);
105  }
106  if constexpr (upper_t::has_param) {
107  upper_.bind(pack);
108  }
109  }
110 
111 private:
112  lower_t lower_;
113  upper_t upper_;
114 };
115 
116 // Specialization: Lower (scalar)
117 template <class ValueType
118  , class ShapeType
119  , class LowerType
120  , class UpperType>
121 struct Transformer<ValueType, ShapeType, Bounded<LowerType, UpperType>>
122 {
123  using value_t = ValueType;
124  using shape_t = ShapeType;
127  using uc_view_t = ad::util::shape_to_raw_view_t<value_t, shape_t>;
128  using view_t = uc_view_t;
129 
130  // only continuous value types can be constrained
131  static_assert(util::is_cont_v<value_t>);
132 
139  Transformer(size_t rows,
140  size_t cols,
141  constraint_t c)
142  : uc_val_(util::make_val<value_t, shape_t>(rows, cols))
143  , c_val_(util::make_val<value_t, shape_t>(rows, cols))
144  , v_val_(nullptr)
145  , constraint_(c)
146  {}
147 
148  void transform() {
149  constraint_.transform(util::get(c_val_), util::get(uc_val_));
150  }
151 
158  void inv_transform(size_t refcnt) {
159  ++*v_val_;
160  if (*v_val_ == 1) {
161  constraint_.inv_transform(util::get(uc_val_), util::get(c_val_));
162  }
163  *v_val_ = *v_val_ % refcnt;
164  }
165 
174  //template <class UCType, class CType>
175  //void inv_transform(const UCType& uc,
176  // CType& c) const
177  //{ constraint_.inv_transform(uc(0,0), c(0,0)); }
178 
189  template <class CurrPtrPack, class PtrPack>
190  auto inv_transform_ad(const CurrPtrPack& curr_pack,
191  const PtrPack& pack,
192  size_t refcnt) const {
193  ad::VarView<value_t, shape_t> uc_view(curr_pack.uc_val,
194  curr_pack.uc_adj,
195  rows_uc(),
196  cols_uc());
197  return constraint_.inv_transform_ad(uc_view, curr_pack, pack, refcnt);
198  }
199 
206  template <class CurrPtrPack, class PtrPack>
207  auto logj_inv_transform_ad(const CurrPtrPack& curr_pack,
208  const PtrPack& pack) const {
209  ad::VarView<value_t, shape_t> uc_view(curr_pack.uc_val,
210  curr_pack.uc_adj,
211  rows_uc(),
212  cols_uc());
213  return constraint_.logj_inv_transform_ad(uc_view, curr_pack, pack);
214  }
215 
219  template <class GenType, class ContDist>
220  void init(GenType& gen, ContDist& dist) {
221  if constexpr (std::is_same_v<shape_t, scl>) {
222  util::get(uc_val_) = dist(gen);
223  } else {
224  util::get(uc_val_) = var_t::NullaryExpr(rows_uc(), cols_uc(),
225  [&](size_t, size_t) { return dist(gen); });
226  }
227  }
228 
229  void activate_refcnt(size_t curr_refcnt) const {
230  if (curr_refcnt == 1) constraint_.activate_refcnt();
231  }
232 
233  var_t& get_c() { return util::get(c_val_); }
234  const var_t& get_c() const { return util::get(c_val_); }
235 
240  constexpr size_t size_uc() const { return util::size(uc_val_); }
241  constexpr size_t rows_uc() const { return util::rows(uc_val_); }
242  constexpr size_t cols_uc() const { return util::cols(uc_val_); }
243  constexpr size_t size_c() const { return util::size(c_val_); }
244  constexpr size_t rows_c() const { return util::rows(c_val_); }
245  constexpr size_t cols_c() const { return util::cols(c_val_); }
246 
251  constexpr size_t bind_size_uc() const { return size_uc(); }
252  constexpr size_t bind_size_c() const { return size_c(); }
253  constexpr size_t bind_size_v() const { return 1; }
254 
260  template <class CurrPtrPack, class PtrPack>
261  void bind(const CurrPtrPack& curr_pack,
262  const PtrPack& pack)
263  {
264  util::bind(uc_val_, curr_pack.uc_val, rows_uc(), cols_uc());
265  util::bind(c_val_, curr_pack.c_val, rows_c(), cols_c());
266  util::bind(v_val_, curr_pack.v_val, 1, 1);
267  constraint_.bind(pack);
268  }
269 
270 private:
271  uc_view_t uc_val_;
272  view_t c_val_;
273  size_t* v_val_;
274  constraint_t constraint_;
275 };
276 
277 } // namespace constraint
278 } // namespace expr
279 
280 
281 template <class LowerType
282  , class UpperType>
283 constexpr inline auto bounded(const LowerType& lower,
284  const UpperType& upper)
285 {
286  using lower_t = util::convert_to_param_t<LowerType>;
287  using upper_t = util::convert_to_param_t<UpperType>;
288  lower_t wrap_lower = lower;
289  upper_t wrap_upper = upper;
290  return expr::constraint::Bounded(wrap_lower, wrap_upper);
291 }
292 
293 } // namespace ppl
294 
ppl::expr::constraint::Transformer
Definition: transformer.hpp:10
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::rows_uc
constexpr size_t rows_uc() const
Definition: bounded.hpp:241
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::inv_transform
void inv_transform(size_t refcnt)
Definition: bounded.hpp:158
bounded_inv_transform.hpp
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::activate_refcnt
void activate_refcnt(size_t curr_refcnt) const
Definition: bounded.hpp:229
ppl::util::get
auto & get(T &&x)
Definition: value.hpp:52
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::uc_view_t
ad::util::shape_to_raw_view_t< value_t, shape_t > uc_view_t
Definition: bounded.hpp:127
ad::boost::bounded_inv_transform
constexpr void bounded_inv_transform(const UCType &uc, const LowerType &lower, const UpperType &upper, CType &c)
Definition: bounded_inv_transform.hpp:16
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::get_c
const var_t & get_c() const
Definition: bounded.hpp:234
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::cols_c
constexpr size_t cols_c() const
Definition: bounded.hpp:245
ppl::expr::constraint::Bounded::activate_refcnt
void activate_refcnt() const
Definition: bounded.hpp:96
ppl::util::bind
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
ppl::expr::constraint::Bounded::logj_inv_transform_ad
auto logj_inv_transform_ad(const UCViewType &uc_view, const CurrPtrPack &curr_pack, const PtrPack &pack) const
Definition: bounded.hpp:84
ppl::util::cols
constexpr size_t cols(const T &x)
Definition: value.hpp:42
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::rows_c
constexpr size_t rows_c() const
Definition: bounded.hpp:244
ppl::lower
constexpr auto lower(const LowerType &expr)
Definition: lower.hpp:256
ppl::util::size
constexpr size_t size(const T &x)
Definition: value.hpp:22
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::init
void init(GenType &gen, ContDist &dist)
Definition: bounded.hpp:220
ppl::bounded
constexpr auto bounded(const LowerType &lower, const UpperType &upper)
Definition: bounded.hpp:283
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::get_c
var_t & get_c()
Definition: bounded.hpp:233
ad::boost::BoundedInvTransformNode
Definition: bounded_inv_transform.hpp:36
ppl::expr::constraint::Bounded::bind
void bind(const PtrPack &pack)
Definition: bounded.hpp:102
ppl::util::var_t
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
value.hpp
ppl::util::shape_traits
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::size_c
constexpr size_t size_c() const
Definition: bounded.hpp:243
ppl::expr::constraint::Bounded::inv_transform
constexpr void inv_transform(const T &uc, T &c)
Definition: bounded.hpp:59
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::Transformer
Transformer(size_t rows, size_t cols, constraint_t c)
Definition: bounded.hpp:139
ppl::expr::constraint::Bounded::Bounded
Bounded(const lower_t &lower, const upper_t &upper)
Definition: bounded.hpp:32
ppl::util::to_array
constexpr auto to_array(const T &x)
Definition: value.hpp:74
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::value_t
ValueType value_t
Definition: bounded.hpp:123
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::transform
void transform()
Definition: bounded.hpp:148
ppl
Definition: bounded.hpp:11
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::inv_transform_ad
auto inv_transform_ad(const CurrPtrPack &curr_pack, const PtrPack &pack, size_t refcnt) const
Definition: bounded.hpp:190
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::util::rows
constexpr size_t rows(const T &x)
Definition: value.hpp:32
ppl::expr::constraint::Bounded::inv_transform_ad
auto inv_transform_ad(const UCViewType &uc_view, const CurrPtrPackType &curr_pack, const PtrPackType &pack, size_t refcnt) const
Definition: bounded.hpp:66
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::view_t
uc_view_t view_t
Definition: bounded.hpp:128
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::shape_t
ShapeType shape_t
Definition: bounded.hpp:124
ad::boost::LogJBoundedInvTransformNode
Definition: bounded_inv_transform.hpp:151
traits.hpp
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::bind_size_v
constexpr size_t bind_size_v() const
Definition: bounded.hpp:253
ppl::util::make_val
constexpr auto make_val(size_t rows=1, size_t cols=1)
Definition: value.hpp:9
transformer.hpp
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::cols_uc
constexpr size_t cols_uc() const
Definition: bounded.hpp:242
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::bind_size_uc
constexpr size_t bind_size_uc() const
Definition: bounded.hpp:251
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::bind_size_c
constexpr size_t bind_size_c() const
Definition: bounded.hpp:252
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::size_uc
constexpr size_t size_uc() const
Definition: bounded.hpp:240
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::bind
void bind(const CurrPtrPack &curr_pack, const PtrPack &pack)
Definition: bounded.hpp:261
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::logj_inv_transform_ad
auto logj_inv_transform_ad(const CurrPtrPack &curr_pack, const PtrPack &pack) const
Definition: bounded.hpp:207
ppl::expr::constraint::Bounded::transform
constexpr void transform(const T &c, T &uc) const
Definition: bounded.hpp:42
ppl::expr::constraint::Bounded
Definition: bounded.hpp:17
ppl::expr::constraint::Transformer< ValueType, ShapeType, Bounded< LowerType, UpperType > >::var_t
util::var_t< value_t, shape_t > var_t
Definition: bounded.hpp:125