autoppl  v0.8
A C++ template library for probabilistic programming
lower.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <cstddef>
3 #include <cmath>
4 #include <fastad_bits/util/shape_traits.hpp>
5 #include <fastad_bits/reverse/core/var_view.hpp>
6 #include <fastad_bits/reverse/core/sum.hpp>
11 #include <autoppl/util/value.hpp>
12 
13 namespace ppl {
14 namespace expr {
15 namespace constraint {
16 
17 template <class ExprType>
18 struct Lower
19 {
20 private:
21  using lower_t = ExprType;
22  static_assert(util::is_var_expr_v<lower_t>);
23 
24 public:
27 
28  Lower(const lower_t& lower): lower_{lower} {}
29 
34  template <class T>
35  constexpr void transform(const T& c,
36  T& uc) const
37  {
38  if constexpr (std::is_arithmetic_v<T>) {
39  uc = std::log(c - lower_.get());
40  } else {
41  if constexpr (util::is_scl_v<lower_t>) {
42  uc = (c.array() - lower_.get()).log().matrix();
43  } else {
44  uc = (c.array() - lower_.get().array()).log().matrix();
45  }
46  }
47  }
48 
54  template <class T>
55  constexpr void inv_transform(const T& uc,
56  T& c)
57  { ad::boost::lower_inv_transform(uc, lower_.eval(), c); }
58 
59  template <class UCViewType
60  , class CurrPtrPackType
61  , class PtrPackType>
62  auto inv_transform_ad(const UCViewType& uc_view,
63  const CurrPtrPackType& curr_pack,
64  const PtrPackType& pack,
65  size_t refcnt) const
66  {
67  auto lower = lower_.ad(pack);
68  return ad::boost::LowerInvTransformNode(uc_view,
69  lower,
70  curr_pack.c_val,
71  curr_pack.v_val,
72  refcnt);
73  }
74 
75  template <class UCViewType>
76  auto logj_inv_transform_ad(const UCViewType& uc_view) const
77  {
78  return ad::sum(uc_view);
79  }
80 
81  void activate_refcnt() const { lower_.activate_refcnt(); }
82 
83  template <class PtrPack>
84  void bind(const PtrPack& pack) {
85  if constexpr (lower_t::has_param) {
86  lower_.bind(pack);
87  }
88  }
89 
90 private:
91  lower_t lower_;
92 };
93 
94 // Specialization: Lower
95 template <class ValueType, class ShapeType, class ExprType>
96 struct Transformer<ValueType, ShapeType, Lower<ExprType>>
97 {
98  using value_t = ValueType;
99  using shape_t = ShapeType;
102  using uc_view_t = ad::util::shape_to_raw_view_t<value_t, shape_t>;
103  using view_t = uc_view_t;
104 
105  // only continuous value types can be constrained
106  static_assert(util::is_cont_v<value_t>);
107  static_assert(util::is_scl_v<ExprType> ||
108  std::is_same_v<shape_t,
110 
117  Transformer(size_t rows,
118  size_t cols,
119  const constraint_t& c)
120  : uc_val_(util::make_val<value_t, shape_t>(rows, cols))
121  , c_val_(util::make_val<value_t, shape_t>(rows, cols))
122  , v_val_(nullptr)
123  , constraint_(c)
124  {}
125 
132  void transform() {
133  constraint_.transform(util::get(c_val_), util::get(uc_val_));
134  }
135 
142  void inv_transform(size_t refcnt) {
143  ++*v_val_;
144  if (*v_val_ == 1) {
145  constraint_.inv_transform(util::get(uc_val_), util::get(c_val_));
146  }
147  *v_val_ = *v_val_ % refcnt;
148  }
149 
160  template <class CurrPtrPack, class PtrPack>
161  auto inv_transform_ad(const CurrPtrPack& curr_pack,
162  const PtrPack& pack,
163  size_t refcnt) const {
164  ad::VarView<value_t, shape_t> uc_view(curr_pack.uc_val,
165  curr_pack.uc_adj,
166  rows_uc(),
167  cols_uc());
168  return constraint_.inv_transform_ad(uc_view, curr_pack, pack, refcnt);
169  }
170 
175  template <class CurrPtrPack, class PtrPack>
176  auto logj_inv_transform_ad(const CurrPtrPack& curr_pack,
177  const PtrPack&) const {
178  ad::VarView<value_t, shape_t> uc_view(curr_pack.uc_val,
179  curr_pack.uc_adj,
180  rows_uc(),
181  cols_uc());
182  return constraint_.logj_inv_transform_ad(uc_view);
183  }
184 
188  template <class GenType, class ContDist>
189  void init(GenType& gen, ContDist& dist) {
190  if constexpr (std::is_same_v<shape_t, scl>) {
191  util::get(uc_val_) = dist(gen);
192  } else {
193  util::get(uc_val_) = var_t::NullaryExpr(rows_uc(), cols_uc(),
194  [&](size_t, size_t) { return dist(gen); });
195  }
196  }
197 
204  void activate_refcnt(size_t curr_refcnt) const {
205  if (curr_refcnt == 1) constraint_.activate_refcnt();
206  }
207 
208  var_t& get_c() { return util::get(c_val_); }
209  const var_t& get_c() const { return util::get(c_val_); }
210 
215  constexpr size_t size_uc() const { return util::size(uc_val_); }
216  constexpr size_t rows_uc() const { return util::rows(uc_val_); }
217  constexpr size_t cols_uc() const { return util::cols(uc_val_); }
218  constexpr size_t size_c() const { return util::size(c_val_); }
219  constexpr size_t rows_c() const { return util::rows(c_val_); }
220  constexpr size_t cols_c() const { return util::cols(c_val_); }
221 
226  constexpr size_t bind_size_uc() const { return size_uc(); }
227  constexpr size_t bind_size_c() const { return size_c(); }
228  constexpr size_t bind_size_v() const { return 1; }
229 
235  template <class CurrPtrPack, class PtrPack>
236  void bind(const CurrPtrPack& curr_pack,
237  const PtrPack& pack)
238  {
239  util::bind(uc_val_, curr_pack.uc_val, rows_uc(), cols_uc());
240  util::bind(c_val_, curr_pack.c_val, rows_c(), cols_c());
241  util::bind(v_val_, curr_pack.v_val, 1, 1);
242  constraint_.bind(pack);
243  }
244 
245 private:
246  uc_view_t uc_val_;
247  view_t c_val_;
248  size_t* v_val_;
249  constraint_t constraint_;
250 };
251 
252 } // namespace constraint
253 } // namespace expr
254 
255 template <class LowerType>
256 constexpr inline auto lower(const LowerType& expr)
257 {
258  using lower_t = util::convert_to_param_t<LowerType>;
259  lower_t wrap_expr = expr;
260  return expr::constraint::Lower(wrap_expr);
261 }
262 
263 } // namespace ppl
264 
lower_inv_transform.hpp
ppl::expr::constraint::Transformer
Definition: transformer.hpp:10
ppl::expr::constraint::Lower::inv_transform
constexpr void inv_transform(const T &uc, T &c)
Definition: lower.hpp:55
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::size_uc
constexpr size_t size_uc() const
Definition: lower.hpp:215
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::var_t
util::var_t< value_t, shape_t > var_t
Definition: lower.hpp:100
ppl::util::get
auto & get(T &&x)
Definition: value.hpp:52
ppl::expr::constraint::Lower::Lower
Lower(const lower_t &lower)
Definition: lower.hpp:28
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::Transformer
Transformer(size_t rows, size_t cols, const constraint_t &c)
Definition: lower.hpp:117
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::inv_transform_ad
auto inv_transform_ad(const CurrPtrPack &curr_pack, const PtrPack &pack, size_t refcnt) const
Definition: lower.hpp:161
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::get_c
const var_t & get_c() const
Definition: lower.hpp:209
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::value_t
ValueType value_t
Definition: lower.hpp:98
ppl::expr::constraint::Lower::bind
void bind(const PtrPack &pack)
Definition: lower.hpp:84
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::logj_inv_transform_ad
auto logj_inv_transform_ad(const CurrPtrPack &curr_pack, const PtrPack &) const
Definition: lower.hpp:176
ad::boost::sum
auto sum(const T &x)
Definition: value.hpp:8
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::cols_uc
constexpr size_t cols_uc() const
Definition: lower.hpp:217
dist_expr_traits.hpp
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::view_t
uc_view_t view_t
Definition: lower.hpp:103
ppl::util::bind
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
ppl::util::cols
constexpr size_t cols(const T &x)
Definition: value.hpp:42
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::bind
void bind(const CurrPtrPack &curr_pack, const PtrPack &pack)
Definition: lower.hpp:236
ppl::expr::constraint::Lower::value_t
typename util::var_expr_traits< lower_t >::value_t value_t
Definition: lower.hpp:25
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::rows_c
constexpr size_t rows_c() const
Definition: lower.hpp:219
ppl::lower
constexpr auto lower(const LowerType &expr)
Definition: lower.hpp:256
ppl::util::size
constexpr size_t size(const T &x)
Definition: value.hpp:22
ppl::expr::constraint::Lower::transform
constexpr void transform(const T &c, T &uc) const
Definition: lower.hpp:35
ppl::util::var_expr_traits::value_t
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::activate_refcnt
void activate_refcnt(size_t curr_refcnt) const
Definition: lower.hpp:204
ppl::util::var_t
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::rows_uc
constexpr size_t rows_uc() const
Definition: lower.hpp:216
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::init
void init(GenType &gen, ContDist &dist)
Definition: lower.hpp:189
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::transform
void transform()
Definition: lower.hpp:132
ppl::expr::constraint::Lower::shape_t
typename util::shape_traits< lower_t >::shape_t shape_t
Definition: lower.hpp:26
ppl::expr::constraint::Lower::logj_inv_transform_ad
auto logj_inv_transform_ad(const UCViewType &uc_view) const
Definition: lower.hpp:76
value.hpp
ppl::util::shape_traits
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
ppl::expr::constraint::Lower::activate_refcnt
void activate_refcnt() const
Definition: lower.hpp:81
ad::boost::LowerInvTransformNode
Definition: lower_inv_transform.hpp:37
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::cols_c
constexpr size_t cols_c() const
Definition: lower.hpp:220
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::get_c
var_t & get_c()
Definition: lower.hpp:208
ppl::expr::constraint::Lower::inv_transform_ad
auto inv_transform_ad(const UCViewType &uc_view, const CurrPtrPackType &curr_pack, const PtrPackType &pack, size_t refcnt) const
Definition: lower.hpp:62
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::bind_size_uc
constexpr size_t bind_size_uc() const
Definition: lower.hpp:226
ad::boost::lower_inv_transform
constexpr void lower_inv_transform(const UCType &uc, const LowerType &lower, CType &c)
Definition: lower_inv_transform.hpp:13
ppl
Definition: bounded.hpp:11
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::uc_view_t
ad::util::shape_to_raw_view_t< value_t, shape_t > uc_view_t
Definition: lower.hpp:102
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::bind_size_v
constexpr size_t bind_size_v() const
Definition: lower.hpp:228
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::util::rows
constexpr size_t rows(const T &x)
Definition: value.hpp:32
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::shape_t
ShapeType shape_t
Definition: lower.hpp:99
traits.hpp
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::inv_transform
void inv_transform(size_t refcnt)
Definition: lower.hpp:142
ppl::util::make_val
constexpr auto make_val(size_t rows=1, size_t cols=1)
Definition: value.hpp:9
transformer.hpp
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::bind_size_c
constexpr size_t bind_size_c() const
Definition: lower.hpp:227
ppl::expr::constraint::Transformer< ValueType, ShapeType, Lower< ExprType > >::size_c
constexpr size_t size_c() const
Definition: lower.hpp:218
ppl::expr::constraint::Lower
Definition: lower.hpp:19