2 #include <fastad_bits/reverse/core/expr_base.hpp> 
    3 #include <fastad_bits/reverse/core/value_adj_view.hpp> 
    4 #include <fastad_bits/util/type_traits.hpp> 
    5 #include <fastad_bits/util/size_pack.hpp> 
    6 #include <fastad_bits/util/value.hpp> 
   12 template <
class UCType
 
   17                                             const LowerType& 
lower,
 
   18                                             const UpperType& upper,
 
   26     c = alower + (aupper - alower) / (1. + exp(-auc)); 
 
   29 template <
class ExprType
 
   33     core::ValueAdjView<typename util::expr_traits<ExprType>::value_t,
 
   34                    typename util::shape_traits<ExprType>::shape_t>,
 
   35     core::ExprBase<BoundedInvTransformNode<ExprType, LowerType, UpperType>>
 
   38     using expr_t = ExprType;
 
   39     using expr_value_t = 
typename util::expr_traits<expr_t>::value_t;
 
   40     using expr_shape_t = 
typename util::expr_traits<expr_t>::shape_t;
 
   41     using lower_t = LowerType;
 
   42     using upper_t = UpperType;
 
   46     using typename value_adj_view_t::value_t;
 
   47     using typename value_adj_view_t::shape_t;
 
   49     using typename value_adj_view_t::ptr_pack_t;
 
   69         auto&& 
lower = lower_.feval();
 
   70         auto&& upper = upper_.feval();
 
   73             auto&& uc_val = expr_.feval();
 
   76         *v_val_  = *v_val_ % refcnt_;
 
   90         scaled_inv_logit_ = (a_val - a_lower);
 
   93         inv_logit_ = a_scaled_inv_logit / (a_upper - a_lower);
 
   96         if constexpr (util::is_scl_v<upper_t>) {
 
   97             upper_.beval(
sum(a_adj * a_inv_logit));
 
   99             upper_.beval(a_adj * a_inv_logit);
 
  102         if constexpr (util::is_scl_v<lower_t>) {
 
  103             lower_.beval(
sum(a_adj * (1. - a_inv_logit)));
 
  105             lower_.beval(a_adj * (1. - a_inv_logit));
 
  108         expr_.beval(a_adj * a_scaled_inv_logit * (1. - a_inv_logit));
 
  113         begin = expr_.bind_cache(begin);
 
  114         begin = lower_.bind_cache(begin);
 
  115         begin = upper_.bind_cache(begin);
 
  116         auto val = begin.val;
 
  117         begin.val = this->data();
 
  118         begin = this->
bind(begin);
 
  126                 expr_.bind_cache_size() +
 
  127                 lower_.bind_cache_size() +
 
  128                 upper_.bind_cache_size();
 
  132         return {0, this->
size()}; 
 
  139     util::constant_var_t<value_t, shape_t> scaled_inv_logit_;
 
  140     util::constant_var_t<value_t, shape_t> inv_logit_;
 
  142     size_t const refcnt_;
 
  145 template <
class ExprType
 
  149     core::ValueAdjView<typename util::expr_traits<ExprType>::value_t, ad::scl>,
 
  150     core::ExprBase<LogJBoundedInvTransformNode<ExprType, LowerType, UpperType>>
 
  153     using expr_t = ExprType;
 
  154     using lower_t = LowerType;
 
  155     using upper_t = UpperType;
 
  156     using expr_value_t = 
typename util::expr_traits<expr_t>::value_t;
 
  157     using expr_shape_t = 
typename util::shape_traits<expr_t>::shape_t;
 
  158     using lower_shape_t = 
typename util::shape_traits<lower_t>::shape_t;
 
  159     using upper_shape_t = 
typename util::shape_traits<upper_t>::shape_t;
 
  163     using typename value_adj_view_t::value_t;
 
  164     using typename value_adj_view_t::shape_t;
 
  166     using typename value_adj_view_t::ptr_pack_t;
 
  169                                 const lower_t& 
lower,
 
  170                                 const upper_t& upper,
 
  176         , c_val_(c_val, expr.rows(), expr.cols())
 
  177         , scaled_inv_logit_()
 
  190         scaled_inv_logit_ = a_c_val - a_lower;
 
  191         inv_range_ = 1. / (a_upper - a_lower);
 
  197         auto&& a_inv_logit = a_scaled_inv_logit * 
util::to_array(inv_range_);
 
  199         return this->
get() = 
sum(log(a_scaled_inv_logit * (1. - a_inv_logit)));
 
  207         if constexpr (util::is_scl_v<upper_t> &&
 
  208                       util::is_scl_v<lower_t>) {
 
  209             upper_.beval(seed * expr_.size() * a_inv_range);
 
  210             lower_.beval(seed * expr_.size() * (-a_inv_range));
 
  212         } 
else if constexpr (util::is_scl_v<upper_t>) {
 
  213             upper_.beval(seed * a_inv_range.sum());
 
  214             lower_.beval(seed * (-a_inv_range));
 
  216         } 
else if constexpr (util::is_scl_v<lower_t>) {
 
  217             upper_.beval(seed * a_inv_range);
 
  218             lower_.beval(seed * (-a_inv_range.sum()));
 
  221             assert(upper_.cols() == lower_.cols());
 
  222             assert(upper_.rows() == lower_.rows());
 
  223             upper_.beval(seed * a_inv_range);
 
  224             lower_.beval(seed * -a_inv_range);
 
  227         expr_.beval(seed * (1. - 2. * a_scaled_inv_logit * a_inv_range));    
 
  232         begin = expr_.bind_cache(begin);
 
  233         begin = lower_.bind_cache(begin);
 
  234         begin = upper_.bind_cache(begin);
 
  235         auto adj = begin.adj;
 
  237         begin = this->
bind(begin);
 
  245                 expr_.bind_cache_size() +
 
  246                 lower_.bind_cache_size() +
 
  247                 upper_.bind_cache_size();
 
  251         return {this->
size(), 0}; 
 
  255     using view_t = core::ValueView<expr_value_t, expr_shape_t>;
 
  260     util::constant_var_t<expr_value_t, expr_shape_t> scaled_inv_logit_;
 
  261     util::constant_var_t<expr_value_t, 
 
  262         util::max_shape_t<lower_shape_t, upper_shape_t> > inv_range_;