autoppl v0.8
A C++ template library for probabilistic programming
unconstrained.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/util/shape_traits.hpp>
3 #include <fastad_bits/reverse/core/var_view.hpp>
5 #include <autoppl/util/value.hpp>
6 
7 namespace ppl {
8 namespace expr {
9 namespace constraint {
10 
// Tag type for parameters with no constraint. It carries no state; passing it
// as the third Transformer argument selects the identity specialization below
// (no-op transform, zero log-Jacobian).
11 struct Unconstrained {};
12 
// Transformer specialization for Unconstrained parameters: the identity map.
// The constrained value is the unconstrained storage itself:
//  - get_c() returns util::get(uc_val_);
//  - size_c/rows_c/cols_c forward to the *_uc accessors;
//  - bind_size_c() is 0.
// As a result, transform()/inv_transform() are no-ops and the log-Jacobian is 0.
13 template <class ValueType
14  , class ShapeType>
15 struct Transformer<ValueType, ShapeType, Unconstrained>
16 {
17  using value_t = ValueType;
18  using shape_t = ShapeType;
// NOTE: source lines 19-20 are omitted from this listing. They include the
// var_t alias used below (util::var_t<value_t, shape_t>, per the
// cross-reference entries at the end of this page).
// Storage type of uc_val_. It is presumably a non-owning "raw view"; bind()
// points it at external storage (curr_pack.uc_val).
21  using view_t = ad::util::shape_to_raw_view_t<value_t, shape_t>;
22 
// Constructs the transformer for a value of shape (rows, cols). uc_val_ is
// initialized from util::make_val<value_t, shape_t>(rows, cols).
// NOTE: source line 31 (the trailing constraint_t parameter) is omitted from
// this listing. The full signature, per the cross-reference below, is
// Transformer(size_t rows, size_t cols, constraint_t=constraint_t()).
// Unconstrained is an empty tag and is not stored.
29  Transformer(size_t rows,
30  size_t cols,
32  : uc_val_(util::make_val<value_t, shape_t>(rows, cols))
33  {
// NOTE(review): the commented-out code below would bind uc_val_ to nullptr.
// bind() performs the real binding to curr_pack.uc_val. The disabled guard
// tests is_scl_v on *this rather than on shape_t; confirm the intent before
// reviving or deleting it.
34  // TODO: remove?
35  //if constexpr (!util::is_scl_v<std::decay_t<decltype(*this)>>) {
36  // util::bind(uc_val_, nullptr, rows, cols);
37  //}
38  }
39 
// Identity transform: there is nothing to compute in either direction.
40  void transform() {}
41  void inv_transform(size_t) {}
42 
// AD expression for the inverse transform. It is a VarView built directly
// over the current unconstrained value/adjoint pointers (curr_pack.uc_val,
// curr_pack.uc_adj), with this transformer's shape. The PtrPack and size_t
// arguments are unused.
43  template <class CurrPtrPack, class PtrPack>
44  auto inv_transform_ad(const CurrPtrPack& curr_pack,
45  const PtrPack&,
46  size_t) const {
47  return ad::VarView<value_t, shape_t>(
48  curr_pack.uc_val, curr_pack.uc_adj, rows_uc(), cols_uc());
49  }
50 
// Log-Jacobian term of the inverse transform: the constant 0, since the map is
// the identity. Both pack arguments are unused.
51  template <class CurrPtrPack, class PtrPack>
52  auto logj_inv_transform_ad(const CurrPtrPack&,
53  const PtrPack&) const {
54  return ad::constant(0.);
55  }
56 
// Initializes the unconstrained value in place:
//  - discrete value_t (util::is_disc_v): set to 0 (scalar) or setZero() otherwise;
//  - continuous value_t: a scalar gets one dist(gen) draw, and non-scalar
//    shapes get one dist(gen) draw per element via var_t::NullaryExpr.
// The static_casts presumably silence unused-parameter warnings when the
// discrete branch is taken.
57  template <class GenType, class ContDist>
58  void init(GenType& gen, ContDist& dist) {
59  static_cast<void>(gen);
60  static_cast<void>(dist);
61  if constexpr (util::is_disc_v<value_t>) {
62  if constexpr (std::is_same_v<shape_t, scl>) {
63  *uc_val_ = 0;
64  } else {
65  uc_val_.setZero();
66  }
67  } else {
68  if constexpr (std::is_same_v<shape_t, scl>) {
69  *uc_val_ = dist(gen);
70  } else {
71  uc_val_ = var_t::NullaryExpr(rows_uc(), cols_uc(), [&]() { return dist(gen); });
72  }
73  }
74  }
75 
// No reference counting is needed for the identity transform (no-op).
76  void activate_refcnt(size_t) const {}
77 
// Constrained value accessors: these alias the unconstrained storage.
78  var_t& get_c() { return util::get(uc_val_); }
79  const var_t& get_c() const { return util::get(uc_val_); }
80 
// Shape queries. The constrained shape equals the unconstrained shape.
81  constexpr size_t size_uc() const { return util::size(uc_val_); }
82  constexpr size_t rows_uc() const { return util::rows(uc_val_); }
83  constexpr size_t cols_uc() const { return util::cols(uc_val_); }
84  constexpr size_t size_c() const { return size_uc(); }
85  constexpr size_t rows_c() const { return rows_uc(); }
86  constexpr size_t cols_c() const { return cols_uc(); }
87 
// Element counts requested when binding storage:
//  - only the unconstrained buffer (size_uc() elements) is needed;
//  - no separate constrained buffer is needed, since get_c aliases uc_val_;
//  - the "v" kind of storage is also 0 (its meaning is defined by the
//    caller; TODO confirm).
88  constexpr size_t bind_size_uc() const { return size_uc(); }
89  constexpr size_t bind_size_c() const { return 0; }
90  constexpr size_t bind_size_v() const { return 0; }
91 
// Re-points uc_val_ at curr_pack.uc_val, keeping the current rows/cols. The
// second pack argument is unused.
95  template <class CurrPtrPack, class PtrPack>
96  void bind(const CurrPtrPack& curr_pack,
97  const PtrPack&)
98  { util::bind(uc_val_, curr_pack.uc_val, rows_uc(), cols_uc()); }
99 
100 private:
// The unconstrained (and, for this specialization, also constrained) value.
101  view_t uc_val_;
102 };
103 
104 } // namespace constraint
105 } // namespace expr
106 } // namespace ppl
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::bind_size_c
constexpr size_t bind_size_c() const
Definition: unconstrained.hpp:89
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::view_t
ad::util::shape_to_raw_view_t< value_t, shape_t > view_t
Definition: unconstrained.hpp:21
ppl::expr::constraint::Transformer
Definition: transformer.hpp:10
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::logj_inv_transform_ad
auto logj_inv_transform_ad(const CurrPtrPack &, const PtrPack &) const
Definition: unconstrained.hpp:52
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::bind_size_v
constexpr size_t bind_size_v() const
Definition: unconstrained.hpp:90
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::get_c
const var_t & get_c() const
Definition: unconstrained.hpp:79
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::transform
void transform()
Definition: unconstrained.hpp:40
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::rows_c
constexpr size_t rows_c() const
Definition: unconstrained.hpp:85
ppl::util::get
auto & get(T &&x)
Definition: value.hpp:52
ppl::expr::constraint::Unconstrained
Definition: unconstrained.hpp:11
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::var_t
util::var_t< value_t, shape_t > var_t
Definition: unconstrained.hpp:20
ppl::util::bind
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::cols_uc
constexpr size_t cols_uc() const
Definition: unconstrained.hpp:83
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::bind
void bind(const CurrPtrPack &curr_pack, const PtrPack &)
Definition: unconstrained.hpp:96
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::inv_transform_ad
auto inv_transform_ad(const CurrPtrPack &curr_pack, const PtrPack &, size_t) const
Definition: unconstrained.hpp:44
ppl::util::cols
constexpr size_t cols(const T &x)
Definition: value.hpp:42
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::init
void init(GenType &gen, ContDist &dist)
Definition: unconstrained.hpp:58
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::size_c
constexpr size_t size_c() const
Definition: unconstrained.hpp:84
ppl::util::size
constexpr size_t size(const T &x)
Definition: value.hpp:22
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::bind_size_uc
constexpr size_t bind_size_uc() const
Definition: unconstrained.hpp:88
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::Transformer
Transformer(size_t rows, size_t cols, constraint_t=constraint_t())
Definition: unconstrained.hpp:29
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::size_uc
constexpr size_t size_uc() const
Definition: unconstrained.hpp:81
ppl::util::var_t
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::shape_t
ShapeType shape_t
Definition: unconstrained.hpp:18
value.hpp
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::value_t
ValueType value_t
Definition: unconstrained.hpp:17
ppl
Definition: bounded.hpp:11
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::inv_transform
void inv_transform(size_t)
Definition: unconstrained.hpp:41
ppl::util::rows
constexpr size_t rows(const T &x)
Definition: value.hpp:32
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::get_c
var_t & get_c()
Definition: unconstrained.hpp:78
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::activate_refcnt
void activate_refcnt(size_t) const
Definition: unconstrained.hpp:76
ppl::util::make_val
constexpr auto make_val(size_t rows=1, size_t cols=1)
Definition: value.hpp:9
transformer.hpp
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::rows_uc
constexpr size_t rows_uc() const
Definition: unconstrained.hpp:82
ppl::expr::constraint::Transformer< ValueType, ShapeType, Unconstrained >::cols_c
constexpr size_t cols_c() const
Definition: unconstrained.hpp:86