2 #include <fastad_bits/reverse/core/expr_base.hpp>
3 #include <fastad_bits/reverse/core/value_adj_view.hpp>
4 #include <fastad_bits/util/type_traits.hpp>
5 #include <fastad_bits/util/size_pack.hpp>
6 #include <fastad_bits/util/value.hpp>
11 template <
class LowerType,
class UCType,
class CType>
17 for (
int j = 0; j <
lower.cols(); ++j) {
18 lower(j,j) = std::exp(uc(k));
20 for (
int i = j+1; i <
lower.rows(); ++i, ++k) {
27 template <
class ExprType>
29 core::ValueAdjView<typename util::expr_traits<ExprType>::value_t, ad::mat>,
30 core::ExprBase<CovInvTransformNode<ExprType>>
33 using expr_t = ExprType;
34 using expr_value_t =
typename util::expr_traits<expr_t>::value_t;
35 using expr_shape_t =
typename util::shape_traits<expr_t>::shape_t;
37 static_assert(util::is_vec_v<expr_t>);
41 using typename value_adj_view_t::value_t;
42 using typename value_adj_view_t::shape_t;
44 using typename value_adj_view_t::ptr_pack_t;
56 , flattened_adj_(expr.size())
65 auto&& uc_val_ = expr_.feval();
68 *v_val_ = *v_val_ % refcnt_;
79 adj_ = (this->get_adj().transpose() + this->get_adj()) * lower_;
80 adj_.diagonal().array() *= lower_.diagonal().array();
83 for (
size_t j = 0; j < this->
cols(); ++j) {
84 for (
size_t i = j; i < this->
rows(); ++i, ++k) {
85 flattened_adj_(k) = adj_(i,j);
89 expr_.beval(a_flattened_adj);
94 begin = expr_.bind_cache(begin);
96 begin.val = this->data();
97 begin = this->
bind(begin);
105 expr_.bind_cache_size();
109 return {0, this->
size()};
113 using mat_view_t = util::shape_to_raw_view_t<value_t, shape_t>;
116 util::constant_var_t<value_t, shape_t> adj_;
117 util::constant_var_t<value_t, expr_shape_t> flattened_adj_;
119 size_t const refcnt_;
122 template <
class ExprType>
124 core::ValueAdjView<typename util::expr_traits<ExprType>::value_t, ad::scl>,
125 core::ExprBase<LogJCovInvTransformNode<ExprType>>
128 using expr_t = ExprType;
129 using expr_value_t =
typename util::expr_traits<expr_t>::value_t;
130 using expr_shape_t =
typename util::shape_traits<expr_t>::shape_t;
132 static_assert(util::is_vec_v<expr_t>);
136 using typename value_adj_view_t::value_t;
137 using typename value_adj_view_t::shape_t;
139 using typename value_adj_view_t::ptr_pack_t;
146 , flattened_adj_(expr.size())
151 auto&& expr = expr_.feval();
152 size_t weight = rows_ + 1;
156 for (
size_t k = 0; k < rows_; ++k, --weight, --incr) {
157 this->
get() += weight * expr(pos);
165 size_t weight = rows_ + 1;
168 flattened_adj_.setZero();
169 for (
size_t k = 0; k < rows_; ++k, --weight, --incr) {
170 flattened_adj_(pos) = weight;
173 expr_.beval(seed * flattened_adj_.array());
178 begin = expr_.bind_cache(begin);
179 auto adj = begin.adj;
181 begin = this->
bind(begin);
189 expr_.bind_cache_size();
193 return {this->
size(), 0};
199 util::constant_var_t<value_t, expr_shape_t> flattened_adj_;