autoppl  v0.8
A C++ template library for probabilistic programming
cov_inv_transform.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/reverse/core/expr_base.hpp>
3 #include <fastad_bits/reverse/core/value_adj_view.hpp>
4 #include <fastad_bits/util/type_traits.hpp>
5 #include <fastad_bits/util/size_pack.hpp>
6 #include <fastad_bits/util/value.hpp>
7 
8 namespace ad {
9 namespace boost {
10 
/**
 * Maps an unconstrained vector to a covariance matrix through its
 * lower-triangular (Cholesky-style) factor.
 *
 * The unconstrained vector `uc` holds the lower triangle of an n x n factor
 * in column-major order: column j contributes n - j entries, starting with
 * the diagonal. Diagonal entries are passed through exp so that the factor
 * has a strictly positive diagonal; off-diagonal entries are copied as-is.
 * The strictly upper triangle of `lower` is explicitly zeroed, so the result
 * does not depend on whatever the buffer held before the call.
 * Finally `c = lower * lower^T`.
 *
 * Not constexpr: std::exp and the matrix product cannot be evaluated in a
 * constant expression.
 *
 * @param lower  square n x n buffer that receives the factor (overwritten).
 * @param uc     unconstrained vector with at least n(n+1)/2 entries.
 * @param c      output; assigned lower * lower^T.
 */
template <class LowerType, class UCType, class CType>
inline void cov_inv_transform(LowerType& lower,
                              const UCType& uc,
                              CType& c)
{
    // use the matrix's own index type to avoid signed/size mismatches
    using index_t = decltype(lower.rows());
    const index_t n_rows = lower.rows();
    const index_t n_cols = lower.cols();

    size_t k = 0;  // read position in uc
    for (index_t j = 0; j < n_cols; ++j) {
        // strictly upper part of column j is not a free parameter
        for (index_t i = 0; i < j; ++i) {
            lower(i,j) = 0;
        }
        lower(j,j) = std::exp(uc(k));
        ++k;
        for (index_t i = j+1; i < n_rows; ++i, ++k) {
            lower(i,j) = uc(k);
        }
    }
    c = lower * lower.transpose();
}
26 
27 template <class ExprType>
29  core::ValueAdjView<typename util::expr_traits<ExprType>::value_t, ad::mat>,
30  core::ExprBase<CovInvTransformNode<ExprType>>
31 {
32 private:
33  using expr_t = ExprType;
34  using expr_value_t = typename util::expr_traits<expr_t>::value_t;
35  using expr_shape_t = typename util::shape_traits<expr_t>::shape_t;
36 
37  static_assert(util::is_vec_v<expr_t>);
38 
39 public:
40  using value_adj_view_t = core::ValueAdjView<expr_value_t, ad::mat>;
41  using typename value_adj_view_t::value_t;
42  using typename value_adj_view_t::shape_t;
43  using typename value_adj_view_t::var_t;
44  using typename value_adj_view_t::ptr_pack_t;
45 
46  CovInvTransformNode(const expr_t& expr,
47  value_t* lower,
48  value_t* val,
49  size_t rows,
50  size_t* visit_cnt,
51  size_t refcnt)
52  : value_adj_view_t(val, nullptr, rows, rows)
53  , expr_{expr}
54  , lower_(lower, rows, rows)
55  , adj_()
56  , flattened_adj_(expr.size())
57  , v_val_{visit_cnt}
58  , refcnt_{refcnt}
59  {}
60 
61  const var_t& feval()
62  {
63  ++*v_val_;
64  if (*v_val_ == 1) {
65  auto&& uc_val_ = expr_.feval();
66  cov_inv_transform(lower_, uc_val_, this->get());
67  }
68  *v_val_ = *v_val_ % refcnt_;
69  return this->get();
70  }
71 
72  template <class T>
73  void beval(const T& seed)
74  {
75  auto&& a_adj = util::to_array(this->get_adj());
76  auto&& a_flattened_adj = util::to_array(flattened_adj_);
77 
78  a_adj = seed;
79  adj_ = (this->get_adj().transpose() + this->get_adj()) * lower_;
80  adj_.diagonal().array() *= lower_.diagonal().array();
81 
82  size_t k = 0;
83  for (size_t j = 0; j < this->cols(); ++j) {
84  for (size_t i = j; i < this->rows(); ++i, ++k) {
85  flattened_adj_(k) = adj_(i,j);
86  }
87  }
88 
89  expr_.beval(a_flattened_adj);
90  }
91 
92  ptr_pack_t bind_cache(ptr_pack_t begin)
93  {
94  begin = expr_.bind_cache(begin);
95  auto val = begin.val;
96  begin.val = this->data();
97  begin = this->bind(begin);
98  begin.val = val;
99  return begin;
100  }
101 
102  util::SizePack bind_cache_size() const
103  {
104  return single_bind_cache_size() +
105  expr_.bind_cache_size();
106  }
107 
108  util::SizePack single_bind_cache_size() const {
109  return {0, this->size()};
110  }
111 
112 private:
113  using mat_view_t = util::shape_to_raw_view_t<value_t, shape_t>;
114  expr_t expr_;
115  mat_view_t lower_;
116  util::constant_var_t<value_t, shape_t> adj_;
117  util::constant_var_t<value_t, expr_shape_t> flattened_adj_;
118  size_t* v_val_;
119  size_t const refcnt_;
120 };
121 
122 template <class ExprType>
124  core::ValueAdjView<typename util::expr_traits<ExprType>::value_t, ad::scl>,
125  core::ExprBase<LogJCovInvTransformNode<ExprType>>
126 {
127 private:
128  using expr_t = ExprType;
129  using expr_value_t = typename util::expr_traits<expr_t>::value_t;
130  using expr_shape_t = typename util::shape_traits<expr_t>::shape_t;
131 
132  static_assert(util::is_vec_v<expr_t>);
133 
134 public:
135  using value_adj_view_t = core::ValueAdjView<expr_value_t, ad::scl>;
136  using typename value_adj_view_t::value_t;
137  using typename value_adj_view_t::shape_t;
138  using typename value_adj_view_t::var_t;
139  using typename value_adj_view_t::ptr_pack_t;
140 
141  LogJCovInvTransformNode(const expr_t& expr,
142  size_t rows)
143  : value_adj_view_t(nullptr, nullptr, 1, 1)
144  , expr_{expr}
145  , rows_{rows}
146  , flattened_adj_(expr.size())
147  {}
148 
149  const var_t& feval()
150  {
151  auto&& expr = expr_.feval();
152  size_t weight = rows_ + 1;
153  size_t incr = rows_;
154  size_t pos = 0;
155  this->zero(); // REALLY important
156  for (size_t k = 0; k < rows_; ++k, --weight, --incr) {
157  this->get() += weight * expr(pos);
158  pos += incr;
159  }
160  return this->get();
161  }
162 
163  void beval(value_t seed)
164  {
165  size_t weight = rows_ + 1;
166  size_t incr = rows_;
167  size_t pos = 0;
168  flattened_adj_.setZero();
169  for (size_t k = 0; k < rows_; ++k, --weight, --incr) {
170  flattened_adj_(pos) = weight;
171  pos += incr;
172  }
173  expr_.beval(seed * flattened_adj_.array());
174  }
175 
176  ptr_pack_t bind_cache(ptr_pack_t begin)
177  {
178  begin = expr_.bind_cache(begin);
179  auto adj = begin.adj;
180  begin.adj = nullptr;
181  begin = this->bind(begin);
182  begin.adj = adj;
183  return begin;
184  }
185 
186  util::SizePack bind_cache_size() const
187  {
188  return single_bind_cache_size() +
189  expr_.bind_cache_size();
190  }
191 
192  util::SizePack single_bind_cache_size() const {
193  return {this->size(), 0};
194  }
195 
196 private:
197  expr_t expr_;
198  size_t const rows_;
199  util::constant_var_t<value_t, expr_shape_t> flattened_adj_;
200 };
201 
202 } // namespace boost
203 } // namespace ad
ad::boost::CovInvTransformNode::beval
void beval(const T &seed)
Definition: cov_inv_transform.hpp:73
ad
Definition: bounded_inv_transform.hpp:9
ad::boost::LogJCovInvTransformNode::beval
void beval(value_t seed)
Definition: cov_inv_transform.hpp:163
ad::boost::LogJCovInvTransformNode::feval
const var_t & feval()
Definition: cov_inv_transform.hpp:149
ppl::util::get
auto & get(T &&x)
Definition: value.hpp:52
ad::boost::CovInvTransformNode::value_adj_view_t
core::ValueAdjView< expr_value_t, ad::mat > value_adj_view_t
Definition: cov_inv_transform.hpp:40
ad::boost::CovInvTransformNode::single_bind_cache_size
util::SizePack single_bind_cache_size() const
Definition: cov_inv_transform.hpp:108
ad::boost::CovInvTransformNode
Definition: cov_inv_transform.hpp:31
ad::boost::LogJCovInvTransformNode::single_bind_cache_size
util::SizePack single_bind_cache_size() const
Definition: cov_inv_transform.hpp:192
ppl::util::bind
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
ad::boost::CovInvTransformNode::bind_cache_size
util::SizePack bind_cache_size() const
Definition: cov_inv_transform.hpp:102
ppl::util::cols
constexpr size_t cols(const T &x)
Definition: value.hpp:42
ppl::lower
constexpr auto lower(const LowerType &expr)
Definition: lower.hpp:256
ppl::util::size
constexpr size_t size(const T &x)
Definition: value.hpp:22
ad::boost::CovInvTransformNode::feval
const var_t & feval()
Definition: cov_inv_transform.hpp:61
ad::boost::LogJCovInvTransformNode
Definition: cov_inv_transform.hpp:126
ppl::util::var_t
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
ad::boost::LogJCovInvTransformNode::LogJCovInvTransformNode
LogJCovInvTransformNode(const expr_t &expr, size_t rows)
Definition: cov_inv_transform.hpp:141
ad::boost::CovInvTransformNode::bind_cache
ptr_pack_t bind_cache(ptr_pack_t begin)
Definition: cov_inv_transform.hpp:92
ad::boost::LogJCovInvTransformNode::bind_cache
ptr_pack_t bind_cache(ptr_pack_t begin)
Definition: cov_inv_transform.hpp:176
ppl::util::to_array
constexpr auto to_array(const T &x)
Definition: value.hpp:74
ad::boost::LogJCovInvTransformNode::value_adj_view_t
core::ValueAdjView< expr_value_t, ad::scl > value_adj_view_t
Definition: cov_inv_transform.hpp:135
ad::boost::LogJCovInvTransformNode::bind_cache_size
util::SizePack bind_cache_size() const
Definition: cov_inv_transform.hpp:186
ad::boost::cov_inv_transform
constexpr void cov_inv_transform(LowerType &lower, const UCType &uc, CType &c)
Definition: cov_inv_transform.hpp:12
ppl::util::rows
constexpr size_t rows(const T &x)
Definition: value.hpp:32
ad::boost::CovInvTransformNode::CovInvTransformNode
CovInvTransformNode(const expr_t &expr, value_t *lower, value_t *val, size_t rows, size_t *visit_cnt, size_t refcnt)
Definition: cov_inv_transform.hpp:46