autoppl  v0.8
A C++ template library for probabilistic programming
pos_def.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <cstddef>
3 #include <cmath>
4 #include <Eigen/Dense>
5 #include <fastad_bits/util/shape_traits.hpp>
6 #include <fastad_bits/reverse/core/var_view.hpp>
8 #include <autoppl/util/value.hpp>
10 
11 namespace ppl {
12 namespace expr {
13 namespace constraint {
14 
15 struct PosDef {
16 
/// Number of unconstrained parameters for a rows x rows positive-definite
/// matrix: one per entry of its lower triangle (diagonal included), i.e. the
/// triangular number rows * (rows + 1) / 2.
static constexpr size_t size(size_t rows)
{
    const size_t twice_count = rows * (rows + 1);
    return twice_count / 2;
}
23 
27  template <class CType, class UCType>
28  static constexpr void transform(const CType& c,
29  UCType& uc)
30  {
31  using value_t = typename CType::Scalar;
32  using mat_t = Eigen::Matrix<value_t, Eigen::Dynamic, Eigen::Dynamic>;
33  Eigen::LLT<mat_t> llt(c);
34  mat_t lower = llt.matrixL();
35  lower.diagonal().array() = lower.diagonal().array().log();
36 
37  size_t k = 0;
38  for (int j = 0; j < lower.cols(); ++j) {
39  for (int i = j; i < lower.rows(); ++i, ++k) {
40  uc(k) = lower(i,j);
41  }
42  }
43  }
44 
52  template <class LowerType, class UCType, class CType>
53  static constexpr void inv_transform(LowerType& lower,
54  const UCType& uc,
55  CType& c)
57 };
58 
59 // Specialization: Positive-Definite (matrix)
60 template <class ValueType>
61 struct Transformer<ValueType, mat, PosDef>
62 {
63  using value_t = ValueType;
64  using shape_t = mat;
67  using uc_view_t = ad::util::shape_to_raw_view_t<value_t, vec>;
68  using view_t = ad::util::shape_to_raw_view_t<value_t, shape_t>;
69 
70  // only continuous value types can be constrained
71  static_assert(util::is_cont_v<value_t>);
72 
80  Transformer(size_t rows,
81  size_t,
83  : uc_val_(nullptr, constraint_t::size(rows))
84  , c_val_(nullptr, rows, rows)
85  , lower_(nullptr, rows, rows)
86  , v_val_(nullptr)
87  {}
88 
89  void transform() {
90  constraint_t::transform(c_val_, uc_val_);
91  }
92 
101  void inv_transform(size_t refcnt) {
102  ++*v_val_;
103  if (*v_val_ == 1) {
104  constraint_t::inv_transform(lower_, uc_val_, c_val_);
105  }
106  *v_val_ = *v_val_ % refcnt;
107  }
108 
119  //template <class UCVecType, class CVecType>
120  //void inv_transform(const UCVecType& uc,
121  // CVecType& c) const
122  //{
123  // assert(static_cast<size_t>(uc.size()) == size_uc());
124  // assert(static_cast<size_t>(c.size()) == size_c());
125  // assert(uc.rows() == 1 || uc.cols() == 1);
126  // assert(c.rows() == 1 || c.cols() == 1);
127 
128  // using vec_t = Eigen::Matrix<value_t, Eigen::Dynamic, 1>;
129  // using mat_t = Eigen::Matrix<value_t, Eigen::Dynamic, Eigen::Dynamic>;
130  // mat_t uc_vec = uc;
131  // mat_t c_vec = c;
132  // mat_t lower(rows_c(), rows_c());
133  // lower.setZero();
134  // Eigen::Map<vec_t> uc_mp(uc_vec.data(), uc.size());
135  // Eigen::Map<mat_t> c_mp(c_vec.data(), rows_c(), rows_c());
136  // constraint_t::inv_transform(lower, uc_vec, c_mp);
137  // c = c_vec;
138  //}
139 
152  template <class CurrPtrPack, class PtrPack>
153  auto inv_transform_ad(const CurrPtrPack& curr_pack,
154  const PtrPack&,
155  size_t refcnt) const {
156  ad::VarView<value_t, ad::vec> uc_view(curr_pack.uc_val,
157  curr_pack.uc_adj,
158  size_uc());
159  return ad::boost::CovInvTransformNode(uc_view,
160  curr_pack.c_val,
161  curr_pack.c_val + size_c(),
162  rows_c(),
163  curr_pack.v_val,
164  refcnt);
165  }
166 
173  template <class CurrPtrPack, class PtrPack>
174  auto logj_inv_transform_ad(const CurrPtrPack& curr_pack,
175  const PtrPack&) const {
176  ad::VarView<value_t, ad::vec> uc_view(curr_pack.uc_val,
177  curr_pack.uc_adj,
178  size_uc());
179  return ad::boost::LogJCovInvTransformNode(uc_view, rows_c());
180  }
181 
187  template <class GenType, class ContDist>
188  void init(GenType&, ContDist&) {
189  uc_val_.setZero();
190  }
191 
192  void activate_refcnt(size_t) const {}
193 
194  var_t& get_c() { return util::get(c_val_); }
195  const var_t& get_c() const { return util::get(c_val_); }
196 
201  constexpr size_t size_uc() const { return uc_val_.size(); }
202  constexpr size_t rows_uc() const { return uc_val_.rows(); }
203  constexpr size_t cols_uc() const { return 1; }
204  constexpr size_t size_c() const { return c_val_.size(); }
205  constexpr size_t rows_c() const { return c_val_.rows(); }
206  constexpr size_t cols_c() const { return c_val_.cols(); }
207 
212  constexpr size_t bind_size_uc() const { return size_uc(); }
213  constexpr size_t bind_size_c() const { return 2*size_c(); }
214  constexpr size_t bind_size_v() const { return 1; }
215 
221  template <class CurrPtrPack, class PtrPack>
222  void bind(const CurrPtrPack& curr_pack,
223  const PtrPack&)
224  {
225  util::bind(uc_val_, curr_pack.uc_val, rows_uc(), cols_uc());
226  util::bind(lower_, curr_pack.c_val, rows_c(), cols_c());
227  util::bind(c_val_, curr_pack.c_val + size_c(), rows_c(), cols_c());
228  util::bind(v_val_, curr_pack.v_val, 1, 1);
229  }
230 
231 private:
232  uc_view_t uc_val_;
233  view_t c_val_;
234  view_t lower_; // temporary lower-triangular matrix
235  size_t* v_val_;
236 };
237 
238 } // namespace constraint
239 } // namespace expr
240 
241 constexpr inline auto pos_def()
242 {
243  return expr::constraint::PosDef();
244 }
245 
246 } // namespace ppl
ppl::expr::constraint::Transformer
Definition: transformer.hpp:10
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::var_t
util::var_t< value_t, shape_t > var_t
Definition: pos_def.hpp:65
cov_inv_transform.hpp
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::bind_size_c
constexpr size_t bind_size_c() const
Definition: pos_def.hpp:213
ppl::util::get
auto & get(T &&x)
Definition: value.hpp:52
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::bind_size_v
constexpr size_t bind_size_v() const
Definition: pos_def.hpp:214
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::get_c
var_t & get_c()
Definition: pos_def.hpp:194
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::transform
void transform()
Definition: pos_def.hpp:89
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::bind_size_uc
constexpr size_t bind_size_uc() const
Definition: pos_def.hpp:212
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::inv_transform
void inv_transform(size_t refcnt)
Definition: pos_def.hpp:101
ad::boost::CovInvTransformNode
Definition: cov_inv_transform.hpp:31
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::shape_t
mat shape_t
Definition: pos_def.hpp:64
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::logj_inv_transform_ad
auto logj_inv_transform_ad(const CurrPtrPack &curr_pack, const PtrPack &) const
Definition: pos_def.hpp:174
ppl::util::bind
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::Transformer
Transformer(size_t rows, size_t, constraint_t=constraint_t())
Definition: pos_def.hpp:80
ppl::expr::constraint::PosDef::size
static constexpr size_t size(size_t rows)
Definition: pos_def.hpp:21
ppl::lower
constexpr auto lower(const LowerType &expr)
Definition: lower.hpp:256
ppl::util::size
constexpr size_t size(const T &x)
Definition: value.hpp:22
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::value_t
ValueType value_t
Definition: pos_def.hpp:63
ppl::pos_def
constexpr auto pos_def()
Definition: pos_def.hpp:241
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::view_t
ad::util::shape_to_raw_view_t< value_t, shape_t > view_t
Definition: pos_def.hpp:68
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::size_c
constexpr size_t size_c() const
Definition: pos_def.hpp:204
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::inv_transform_ad
auto inv_transform_ad(const CurrPtrPack &curr_pack, const PtrPack &, size_t refcnt) const
Definition: pos_def.hpp:153
ad::boost::LogJCovInvTransformNode
Definition: cov_inv_transform.hpp:126
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::rows_c
constexpr size_t rows_c() const
Definition: pos_def.hpp:205
ppl::util::var_t
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
value.hpp
ppl::expr::constraint::PosDef::inv_transform
static constexpr void inv_transform(LowerType &lower, const UCType &uc, CType &c)
Definition: pos_def.hpp:53
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::cols_c
constexpr size_t cols_c() const
Definition: pos_def.hpp:206
ppl::expr::constraint::PosDef
Definition: pos_def.hpp:15
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::init
void init(GenType &, ContDist &)
Definition: pos_def.hpp:188
ppl
Definition: bounded.hpp:11
ppl::mat
ad::mat mat
Definition: shape_traits.hpp:18
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::uc_view_t
ad::util::shape_to_raw_view_t< value_t, vec > uc_view_t
Definition: pos_def.hpp:67
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::rows_uc
constexpr size_t rows_uc() const
Definition: pos_def.hpp:202
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::size_uc
constexpr size_t size_uc() const
Definition: pos_def.hpp:201
ad::boost::cov_inv_transform
constexpr void cov_inv_transform(LowerType &lower, const UCType &uc, CType &c)
Definition: cov_inv_transform.hpp:12
ppl::util::rows
constexpr size_t rows(const T &x)
Definition: value.hpp:32
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::cols_uc
constexpr size_t cols_uc() const
Definition: pos_def.hpp:203
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::activate_refcnt
void activate_refcnt(size_t) const
Definition: pos_def.hpp:192
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::get_c
const var_t & get_c() const
Definition: pos_def.hpp:195
transformer.hpp
ppl::expr::constraint::Transformer< ValueType, mat, PosDef >::bind
void bind(const CurrPtrPack &curr_pack, const PtrPack &)
Definition: pos_def.hpp:222
ppl::expr::constraint::PosDef::transform
static constexpr void transform(const CType &c, UCType &uc)
Definition: pos_def.hpp:28