autoppl v0.8
A C++ template library for probabilistic programming
lower_inv_transform.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/reverse/core/expr_base.hpp>
3 #include <fastad_bits/reverse/core/value_adj_view.hpp>
4 #include <fastad_bits/util/type_traits.hpp>
5 #include <fastad_bits/util/size_pack.hpp>
6 #include <fastad_bits/util/value.hpp>
8 
9 namespace ad {
10 namespace boost {
11 
/**
 * Inverse of the lower-bound transform: maps an unconstrained value `uc`
 * to a constrained value `c` via c = exp(uc) + lower.
 *
 * When both `uc` and `c` are arithmetic scalars, the scalar formula is used
 * directly. Otherwise `uc` is treated as an Eigen-like object: the
 * exponential is taken coefficient-wise through `.array()`. The bound is
 * either added as a scalar or added coefficient-wise when it is itself a
 * matrix-like object. The result is converted back with `.matrix()`.
 *
 * @param uc     unconstrained input (scalar or Eigen-like matrix/vector)
 * @param lower  lower bound (scalar, or same shape as uc)
 * @param c      output; overwritten with the constrained value
 */
template <class UCType, class LowerType, class CType>
inline constexpr void lower_inv_transform(const UCType& uc,
                                          const LowerType& lower,
                                          CType& c)
{
    constexpr bool scalar_in_and_out =
        std::is_arithmetic_v<UCType> &&
        std::is_arithmetic_v<std::decay_t<CType>>;

    if constexpr (scalar_in_and_out) {
        c = std::exp(uc) + lower;
    } else if constexpr (std::is_arithmetic_v<LowerType>) {
        // scalar bound broadcast over every coefficient
        c = (uc.array().exp() + lower).matrix();
    } else {
        // element-wise bound of the same shape as uc
        c = (uc.array().exp() + lower.array()).matrix();
    }
}
30 
31 template <class ExprType
32  , class LowerType>
34  core::ValueAdjView<typename util::expr_traits<ExprType>::value_t,
35  typename util::expr_traits<ExprType>::shape_t>,
36  core::ExprBase<LowerInvTransformNode<ExprType, LowerType>>
37 {
38 private:
39  using expr_t = ExprType;
40  using expr_value_t = typename util::expr_traits<expr_t>::value_t;
41  using expr_shape_t = typename util::shape_traits<expr_t>::shape_t;
42  using lower_t = LowerType;
43 
44  static_assert(util::is_scl_v<lower_t> ||
45  std::is_same_v<
46  typename util::shape_traits<expr_t>::shape_t,
47  typename util::shape_traits<lower_t>::shape_t
48  >);
49 
50 public:
51  using value_adj_view_t = core::ValueAdjView<expr_value_t, expr_shape_t>;
52  using typename value_adj_view_t::value_t;
53  using typename value_adj_view_t::shape_t;
54  using typename value_adj_view_t::var_t;
55  using typename value_adj_view_t::ptr_pack_t;
56 
57  LowerInvTransformNode(const expr_t& expr,
58  const lower_t& lower,
59  value_t* c_val,
60  size_t* visit_cnt,
61  size_t refcnt)
62  : value_adj_view_t(c_val, nullptr, expr.rows(), expr.cols())
63  , expr_{expr}
64  , lower_{lower}
65  , v_val_{visit_cnt}
66  , refcnt_{refcnt}
67  {}
68 
69  const var_t& feval()
70  {
71  auto&& lower = lower_.feval();
72  ++*v_val_;
73  if (*v_val_ == 1) {
74  auto&& uc_val = expr_.feval();
75  lower_inv_transform(uc_val, lower, this->get());
76  }
77  *v_val_ = *v_val_ % refcnt_;
78  return this->get();
79  }
80 
81  template <class T>
82  void beval(const T& seed)
83  {
84  auto&& a_val = util::to_array(this->get());
85  auto&& a_adj = util::to_array(this->get_adj());
86  auto&& a_lower = util::to_array(lower_.get());
87  a_adj = seed;
88  if constexpr (util::is_scl_v<lower_t>) {
89  lower_.beval(sum(a_adj));
90  } else {
91  lower_.beval(a_adj);
92  }
93  expr_.beval(a_adj * (a_val - a_lower));
94  }
95 
96  ptr_pack_t bind_cache(ptr_pack_t begin)
97  {
98  begin = expr_.bind_cache(begin);
99  begin = lower_.bind_cache(begin);
100  auto val = begin.val;
101  begin.val = this->data();
102  begin = this->bind(begin);
103  begin.val = val;
104  return begin;
105  }
106 
107  util::SizePack bind_cache_size() const
108  {
109  return single_bind_cache_size() +
110  expr_.bind_cache_size() +
111  lower_.bind_cache_size();
112  }
113 
114  util::SizePack single_bind_cache_size() const {
115  return {0,this->size()};
116  }
117 
118 private:
119  expr_t expr_;
120  lower_t lower_;
121  size_t* v_val_;
122  size_t const refcnt_;
123 };
124 
125 } // namespace boost
126 } // namespace ad
ad
Definition: bounded_inv_transform.hpp:9
ppl::util::get
auto & get(T &&x)
Definition: value.hpp:52
value.hpp
ad::boost::sum
auto sum(const T &x)
Definition: value.hpp:8
ad::boost::LowerInvTransformNode::feval
const var_t & feval()
Definition: lower_inv_transform.hpp:69
ad::boost::LowerInvTransformNode::value_adj_view_t
core::ValueAdjView< expr_value_t, expr_shape_t > value_adj_view_t
Definition: lower_inv_transform.hpp:51
ppl::util::bind
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
ppl::util::cols
constexpr size_t cols(const T &x)
Definition: value.hpp:42
ad::boost::LowerInvTransformNode::bind_cache_size
util::SizePack bind_cache_size() const
Definition: lower_inv_transform.hpp:107
ppl::lower
constexpr auto lower(const LowerType &expr)
Definition: lower.hpp:256
ppl::util::size
constexpr size_t size(const T &x)
Definition: value.hpp:22
ad::boost::LowerInvTransformNode::beval
void beval(const T &seed)
Definition: lower_inv_transform.hpp:82
ppl::util::var_t
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
ad::boost::LowerInvTransformNode::bind_cache
ptr_pack_t bind_cache(ptr_pack_t begin)
Definition: lower_inv_transform.hpp:96
ad::boost::LowerInvTransformNode::single_bind_cache_size
util::SizePack single_bind_cache_size() const
Definition: lower_inv_transform.hpp:114
ad::boost::LowerInvTransformNode
Definition: lower_inv_transform.hpp:37
ppl::util::to_array
constexpr auto to_array(const T &x)
Definition: value.hpp:74
ad::boost::lower_inv_transform
constexpr void lower_inv_transform(const UCType &uc, const LowerType &lower, CType &c)
Definition: lower_inv_transform.hpp:13
ad::boost::LowerInvTransformNode::LowerInvTransformNode
LowerInvTransformNode(const expr_t &expr, const lower_t &lower, value_t *c_val, size_t *visit_cnt, size_t refcnt)
Definition: lower_inv_transform.hpp:57
ppl::util::rows
constexpr size_t rows(const T &x)
Definition: value.hpp:32