autoppl v0.8
A C++ template library for probabilistic programming
bounded_inv_transform.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/reverse/core/expr_base.hpp>
3 #include <fastad_bits/reverse/core/value_adj_view.hpp>
4 #include <fastad_bits/util/type_traits.hpp>
5 #include <fastad_bits/util/size_pack.hpp>
6 #include <fastad_bits/util/value.hpp>
8 
9 namespace ad {
10 namespace boost {
11 
12 template <class UCType
13  , class LowerType
14  , class UpperType
15  , class CType>
16 inline constexpr void bounded_inv_transform(const UCType& uc,
17  const LowerType& lower,
18  const UpperType& upper,
19  CType& c)
20 {
21  using std::exp;
22  using Eigen::exp;
23  auto auc = util::to_array(uc);
24  auto alower = util::to_array(lower);
25  auto aupper = util::to_array(upper);
26  c = alower + (aupper - alower) / (1. + exp(-auc));
27 }
28 
29 template <class ExprType
30  , class LowerType
31  , class UpperType>
33  core::ValueAdjView<typename util::expr_traits<ExprType>::value_t,
34  typename util::shape_traits<ExprType>::shape_t>,
35  core::ExprBase<BoundedInvTransformNode<ExprType, LowerType, UpperType>>
36 {
37 private:
38  using expr_t = ExprType;
39  using expr_value_t = typename util::expr_traits<expr_t>::value_t;
40  using expr_shape_t = typename util::expr_traits<expr_t>::shape_t;
41  using lower_t = LowerType;
42  using upper_t = UpperType;
43 
44 public:
45  using value_adj_view_t = core::ValueAdjView<expr_value_t, expr_shape_t>;
46  using typename value_adj_view_t::value_t;
47  using typename value_adj_view_t::shape_t;
48  using typename value_adj_view_t::var_t;
49  using typename value_adj_view_t::ptr_pack_t;
50 
51  BoundedInvTransformNode(const expr_t& expr,
52  const lower_t& lower,
53  const upper_t& upper,
54  value_t* c_val,
55  size_t* visit_cnt,
56  size_t refcnt)
57  : value_adj_view_t(c_val, nullptr, expr.rows(), expr.cols())
58  , expr_{expr}
59  , lower_{lower}
60  , upper_{upper}
61  , scaled_inv_logit_()
62  , inv_logit_()
63  , v_val_{visit_cnt}
64  , refcnt_{refcnt}
65  {}
66 
67  const var_t& feval()
68  {
69  auto&& lower = lower_.feval();
70  auto&& upper = upper_.feval();
71  ++*v_val_;
72  if (*v_val_ == 1) {
73  auto&& uc_val = expr_.feval();
74  bounded_inv_transform(uc_val, lower, upper, this->get());
75  }
76  *v_val_ = *v_val_ % refcnt_;
77  return this->get();
78  }
79 
80  template <class T>
81  void beval(const T& seed)
82  {
83  auto&& a_val = util::to_array(this->get());
84  auto&& a_adj = util::to_array(this->get_adj());
85  auto&& a_lower = util::to_array(lower_.get());
86  auto&& a_upper = util::to_array(upper_.get());
87 
88  a_adj = seed;
89 
90  scaled_inv_logit_ = (a_val - a_lower);
91  auto&& a_scaled_inv_logit = util::to_array(scaled_inv_logit_);
92 
93  inv_logit_ = a_scaled_inv_logit / (a_upper - a_lower);
94  auto&& a_inv_logit = util::to_array(inv_logit_);
95 
96  if constexpr (util::is_scl_v<upper_t>) {
97  upper_.beval(sum(a_adj * a_inv_logit));
98  } else {
99  upper_.beval(a_adj * a_inv_logit);
100  }
101 
102  if constexpr (util::is_scl_v<lower_t>) {
103  lower_.beval(sum(a_adj * (1. - a_inv_logit)));
104  } else {
105  lower_.beval(a_adj * (1. - a_inv_logit));
106  }
107 
108  expr_.beval(a_adj * a_scaled_inv_logit * (1. - a_inv_logit));
109  }
110 
111  ptr_pack_t bind_cache(ptr_pack_t begin)
112  {
113  begin = expr_.bind_cache(begin);
114  begin = lower_.bind_cache(begin);
115  begin = upper_.bind_cache(begin);
116  auto val = begin.val;
117  begin.val = this->data();
118  begin = this->bind(begin);
119  begin.val = val;
120  return begin;
121  }
122 
123  util::SizePack bind_cache_size() const
124  {
125  return single_bind_cache_size() +
126  expr_.bind_cache_size() +
127  lower_.bind_cache_size() +
128  upper_.bind_cache_size();
129  }
130 
131  util::SizePack single_bind_cache_size() const {
132  return {0, this->size()};
133  }
134 
135 private:
136  expr_t expr_;
137  lower_t lower_;
138  upper_t upper_;
139  util::constant_var_t<value_t, shape_t> scaled_inv_logit_;
140  util::constant_var_t<value_t, shape_t> inv_logit_;
141  size_t* v_val_;
142  size_t const refcnt_;
143 };
144 
145 template <class ExprType
146  , class LowerType
147  , class UpperType>
149  core::ValueAdjView<typename util::expr_traits<ExprType>::value_t, ad::scl>,
150  core::ExprBase<LogJBoundedInvTransformNode<ExprType, LowerType, UpperType>>
151 {
152 private:
153  using expr_t = ExprType;
154  using lower_t = LowerType;
155  using upper_t = UpperType;
156  using expr_value_t = typename util::expr_traits<expr_t>::value_t;
157  using expr_shape_t = typename util::shape_traits<expr_t>::shape_t;
158  using lower_shape_t = typename util::shape_traits<lower_t>::shape_t;
159  using upper_shape_t = typename util::shape_traits<upper_t>::shape_t;
160 
161 public:
162  using value_adj_view_t = core::ValueAdjView<expr_value_t, ad::scl>;
163  using typename value_adj_view_t::value_t;
164  using typename value_adj_view_t::shape_t;
165  using typename value_adj_view_t::var_t;
166  using typename value_adj_view_t::ptr_pack_t;
167 
168  LogJBoundedInvTransformNode(const expr_t& expr,
169  const lower_t& lower,
170  const upper_t& upper,
171  value_t* c_val)
172  : value_adj_view_t(nullptr, nullptr, 1, 1)
173  , expr_{expr}
174  , lower_{lower}
175  , upper_{upper}
176  , c_val_(c_val, expr.rows(), expr.cols())
177  , scaled_inv_logit_()
178  {}
179 
180  const var_t& feval()
181  {
182  using std::log;
183  using Eigen::log;
184 
185  expr_.feval();
186  auto&& a_c_val = util::to_array(c_val_.get());
187  auto&& a_lower = util::to_array(lower_.feval());
188  auto&& a_upper = util::to_array(upper_.feval());
189 
190  scaled_inv_logit_ = a_c_val - a_lower;
191  inv_range_ = 1. / (a_upper - a_lower);
192 
193  auto&& a_scaled_inv_logit = util::to_array(scaled_inv_logit_);
194 
195  // this may be slightly inefficient since this requires extra multiplication,
196  // but the benefit is that we can save inverse range, which gets reused a lot in beval
197  auto&& a_inv_logit = a_scaled_inv_logit * util::to_array(inv_range_);
198 
199  return this->get() = sum(log(a_scaled_inv_logit * (1. - a_inv_logit)));
200  }
201 
202  void beval(value_t seed)
203  {
204  auto a_inv_range = util::to_array(inv_range_);
205  auto a_scaled_inv_logit = util::to_array(scaled_inv_logit_);
206 
207  if constexpr (util::is_scl_v<upper_t> &&
208  util::is_scl_v<lower_t>) {
209  upper_.beval(seed * expr_.size() * a_inv_range);
210  lower_.beval(seed * expr_.size() * (-a_inv_range));
211 
212  } else if constexpr (util::is_scl_v<upper_t>) {
213  upper_.beval(seed * a_inv_range.sum());
214  lower_.beval(seed * (-a_inv_range));
215 
216  } else if constexpr (util::is_scl_v<lower_t>) {
217  upper_.beval(seed * a_inv_range);
218  lower_.beval(seed * (-a_inv_range.sum()));
219 
220  } else {
221  assert(upper_.cols() == lower_.cols());
222  assert(upper_.rows() == lower_.rows());
223  upper_.beval(seed * a_inv_range);
224  lower_.beval(seed * -a_inv_range);
225  }
226 
227  expr_.beval(seed * (1. - 2. * a_scaled_inv_logit * a_inv_range));
228  }
229 
230  ptr_pack_t bind_cache(ptr_pack_t begin)
231  {
232  begin = expr_.bind_cache(begin);
233  begin = lower_.bind_cache(begin);
234  begin = upper_.bind_cache(begin);
235  auto adj = begin.adj;
236  begin.adj = nullptr;
237  begin = this->bind(begin);
238  begin.adj = adj;
239  return begin;
240  }
241 
242  util::SizePack bind_cache_size() const
243  {
244  return single_bind_cache_size() +
245  expr_.bind_cache_size() +
246  lower_.bind_cache_size() +
247  upper_.bind_cache_size();
248  }
249 
250  util::SizePack single_bind_cache_size() const {
251  return {this->size(), 0};
252  }
253 
254 private:
255  using view_t = core::ValueView<expr_value_t, expr_shape_t>;
256  expr_t expr_;
257  lower_t lower_;
258  upper_t upper_;
259  view_t c_val_;
260  util::constant_var_t<expr_value_t, expr_shape_t> scaled_inv_logit_;
261  util::constant_var_t<expr_value_t,
262  util::max_shape_t<lower_shape_t, upper_shape_t> > inv_range_;
263 };
264 
265 } // namespace boost
266 } // namespace ad
ad::boost::BoundedInvTransformNode::beval
void beval(const T &seed)
Definition: bounded_inv_transform.hpp:81
ad
Definition: bounded_inv_transform.hpp:9
ad::boost::BoundedInvTransformNode::single_bind_cache_size
util::SizePack single_bind_cache_size() const
Definition: bounded_inv_transform.hpp:131
ppl::util::get
auto & get(T &&x)
Definition: value.hpp:52
value.hpp
ad::boost::bounded_inv_transform
constexpr void bounded_inv_transform(const UCType &uc, const LowerType &lower, const UpperType &upper, CType &c)
Definition: bounded_inv_transform.hpp:16
ad::boost::sum
auto sum(const T &x)
Definition: value.hpp:8
ad::boost::BoundedInvTransformNode::BoundedInvTransformNode
BoundedInvTransformNode(const expr_t &expr, const lower_t &lower, const upper_t &upper, value_t *c_val, size_t *visit_cnt, size_t refcnt)
Definition: bounded_inv_transform.hpp:51
ad::boost::BoundedInvTransformNode::bind_cache_size
util::SizePack bind_cache_size() const
Definition: bounded_inv_transform.hpp:123
ad::boost::LogJBoundedInvTransformNode::beval
void beval(value_t seed)
Definition: bounded_inv_transform.hpp:202
ppl::util::bind
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
ad::boost::BoundedInvTransformNode::value_adj_view_t
core::ValueAdjView< expr_value_t, expr_shape_t > value_adj_view_t
Definition: bounded_inv_transform.hpp:45
ppl::util::cols
constexpr size_t cols(const T &x)
Definition: value.hpp:42
ad::boost::LogJBoundedInvTransformNode::bind_cache_size
util::SizePack bind_cache_size() const
Definition: bounded_inv_transform.hpp:242
ppl::lower
constexpr auto lower(const LowerType &expr)
Definition: lower.hpp:256
ppl::util::size
constexpr size_t size(const T &x)
Definition: value.hpp:22
ad::boost::BoundedInvTransformNode
Definition: bounded_inv_transform.hpp:36
ppl::util::var_t
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
ad::boost::LogJBoundedInvTransformNode::feval
const var_t & feval()
Definition: bounded_inv_transform.hpp:180
ad::boost::BoundedInvTransformNode::feval
const var_t & feval()
Definition: bounded_inv_transform.hpp:67
ad::boost::LogJBoundedInvTransformNode::bind_cache
ptr_pack_t bind_cache(ptr_pack_t begin)
Definition: bounded_inv_transform.hpp:230
ad::boost::BoundedInvTransformNode::bind_cache
ptr_pack_t bind_cache(ptr_pack_t begin)
Definition: bounded_inv_transform.hpp:111
ad::boost::LogJBoundedInvTransformNode::LogJBoundedInvTransformNode
LogJBoundedInvTransformNode(const expr_t &expr, const lower_t &lower, const upper_t &upper, value_t *c_val)
Definition: bounded_inv_transform.hpp:168
ad::boost::LogJBoundedInvTransformNode::single_bind_cache_size
util::SizePack single_bind_cache_size() const
Definition: bounded_inv_transform.hpp:250
ppl::util::to_array
constexpr auto to_array(const T &x)
Definition: value.hpp:74
ad::boost::LogJBoundedInvTransformNode::value_adj_view_t
core::ValueAdjView< expr_value_t, ad::scl > value_adj_view_t
Definition: bounded_inv_transform.hpp:162
ppl::util::rows
constexpr size_t rows(const T &x)
Definition: value.hpp:32
ad::boost::LogJBoundedInvTransformNode
Definition: bounded_inv_transform.hpp:151