2 #include <fastad_bits/reverse/core/expr_base.hpp>
3 #include <fastad_bits/reverse/core/value_adj_view.hpp>
4 #include <fastad_bits/util/type_traits.hpp>
5 #include <fastad_bits/util/size_pack.hpp>
6 #include <fastad_bits/util/value.hpp>
12 template <
class UCType
17 const LowerType&
lower,
18 const UpperType& upper,
26 c = alower + (aupper - alower) / (1. + exp(-auc));
29 template <
class ExprType
33 core::ValueAdjView<typename util::expr_traits<ExprType>::value_t,
34 typename util::shape_traits<ExprType>::shape_t>,
35 core::ExprBase<BoundedInvTransformNode<ExprType, LowerType, UpperType>>
38 using expr_t = ExprType;
39 using expr_value_t =
typename util::expr_traits<expr_t>::value_t;
40 using expr_shape_t =
typename util::expr_traits<expr_t>::shape_t;
41 using lower_t = LowerType;
42 using upper_t = UpperType;
46 using typename value_adj_view_t::value_t;
47 using typename value_adj_view_t::shape_t;
49 using typename value_adj_view_t::ptr_pack_t;
69 auto&&
lower = lower_.feval();
70 auto&& upper = upper_.feval();
73 auto&& uc_val = expr_.feval();
76 *v_val_ = *v_val_ % refcnt_;
90 scaled_inv_logit_ = (a_val - a_lower);
93 inv_logit_ = a_scaled_inv_logit / (a_upper - a_lower);
96 if constexpr (util::is_scl_v<upper_t>) {
97 upper_.beval(
sum(a_adj * a_inv_logit));
99 upper_.beval(a_adj * a_inv_logit);
102 if constexpr (util::is_scl_v<lower_t>) {
103 lower_.beval(
sum(a_adj * (1. - a_inv_logit)));
105 lower_.beval(a_adj * (1. - a_inv_logit));
108 expr_.beval(a_adj * a_scaled_inv_logit * (1. - a_inv_logit));
113 begin = expr_.bind_cache(begin);
114 begin = lower_.bind_cache(begin);
115 begin = upper_.bind_cache(begin);
116 auto val = begin.val;
117 begin.val = this->data();
118 begin = this->
bind(begin);
126 expr_.bind_cache_size() +
127 lower_.bind_cache_size() +
128 upper_.bind_cache_size();
132 return {0, this->
size()};
139 util::constant_var_t<value_t, shape_t> scaled_inv_logit_;
140 util::constant_var_t<value_t, shape_t> inv_logit_;
142 size_t const refcnt_;
145 template <
class ExprType
149 core::ValueAdjView<typename util::expr_traits<ExprType>::value_t, ad::scl>,
150 core::ExprBase<LogJBoundedInvTransformNode<ExprType, LowerType, UpperType>>
153 using expr_t = ExprType;
154 using lower_t = LowerType;
155 using upper_t = UpperType;
156 using expr_value_t =
typename util::expr_traits<expr_t>::value_t;
157 using expr_shape_t =
typename util::shape_traits<expr_t>::shape_t;
158 using lower_shape_t =
typename util::shape_traits<lower_t>::shape_t;
159 using upper_shape_t =
typename util::shape_traits<upper_t>::shape_t;
163 using typename value_adj_view_t::value_t;
164 using typename value_adj_view_t::shape_t;
166 using typename value_adj_view_t::ptr_pack_t;
169 const lower_t&
lower,
170 const upper_t& upper,
176 , c_val_(c_val, expr.rows(), expr.cols())
177 , scaled_inv_logit_()
190 scaled_inv_logit_ = a_c_val - a_lower;
191 inv_range_ = 1. / (a_upper - a_lower);
197 auto&& a_inv_logit = a_scaled_inv_logit *
util::to_array(inv_range_);
199 return this->
get() =
sum(log(a_scaled_inv_logit * (1. - a_inv_logit)));
207 if constexpr (util::is_scl_v<upper_t> &&
208 util::is_scl_v<lower_t>) {
209 upper_.beval(seed * expr_.size() * a_inv_range);
210 lower_.beval(seed * expr_.size() * (-a_inv_range));
212 }
else if constexpr (util::is_scl_v<upper_t>) {
213 upper_.beval(seed * a_inv_range.sum());
214 lower_.beval(seed * (-a_inv_range));
216 }
else if constexpr (util::is_scl_v<lower_t>) {
217 upper_.beval(seed * a_inv_range);
218 lower_.beval(seed * (-a_inv_range.sum()));
221 assert(upper_.cols() == lower_.cols());
222 assert(upper_.rows() == lower_.rows());
223 upper_.beval(seed * a_inv_range);
224 lower_.beval(seed * -a_inv_range);
227 expr_.beval(seed * (1. - 2. * a_scaled_inv_logit * a_inv_range));
232 begin = expr_.bind_cache(begin);
233 begin = lower_.bind_cache(begin);
234 begin = upper_.bind_cache(begin);
235 auto adj = begin.adj;
237 begin = this->
bind(begin);
245 expr_.bind_cache_size() +
246 lower_.bind_cache_size() +
247 upper_.bind_cache_size();
251 return {this->
size(), 0};
255 using view_t = core::ValueView<expr_value_t, expr_shape_t>;
260 util::constant_var_t<expr_value_t, expr_shape_t> scaled_inv_logit_;
261 util::constant_var_t<expr_value_t,
262 util::max_shape_t<lower_shape_t, upper_shape_t> > inv_range_;