autoppl
v0.8
A C++ template library for probabilistic programming
|
Go to the documentation of this file.
5 #include <fastad_bits/util/shape_traits.hpp>
6 #include <fastad_bits/reverse/core/var_view.hpp>
13 namespace constraint {
27 template <
class CType,
class UCType>
31 using value_t =
typename CType::Scalar;
32 using mat_t = Eigen::Matrix<value_t, Eigen::Dynamic, Eigen::Dynamic>;
33 Eigen::LLT<mat_t> llt(c);
34 mat_t
lower = llt.matrixL();
35 lower.diagonal().array() =
lower.diagonal().array().log();
38 for (
int j = 0; j <
lower.cols(); ++j) {
39 for (
int i = j; i <
lower.rows(); ++i, ++k) {
52 template <
class LowerType,
class UCType,
class CType>
60 template <
class ValueType>
67 using uc_view_t = ad::util::shape_to_raw_view_t<value_t, vec>;
68 using view_t = ad::util::shape_to_raw_view_t<value_t, shape_t>;
71 static_assert(util::is_cont_v<value_t>);
90 constraint_t::transform(c_val_, uc_val_);
104 constraint_t::inv_transform(lower_, uc_val_, c_val_);
106 *v_val_ = *v_val_ % refcnt;
152 template <
class CurrPtrPack,
class PtrPack>
155 size_t refcnt)
const {
156 ad::VarView<value_t, ad::vec> uc_view(curr_pack.uc_val,
161 curr_pack.c_val + size_c(),
173 template <
class CurrPtrPack,
class PtrPack>
175 const PtrPack&)
const {
176 ad::VarView<value_t, ad::vec> uc_view(curr_pack.uc_val,
187 template <
class GenType,
class ContDist>
188 void init(GenType&, ContDist&) {
201 constexpr
size_t size_uc()
const {
return uc_val_.size(); }
202 constexpr
size_t rows_uc()
const {
return uc_val_.rows(); }
203 constexpr
size_t cols_uc()
const {
return 1; }
204 constexpr
size_t size_c()
const {
return c_val_.size(); }
205 constexpr
size_t rows_c()
const {
return c_val_.rows(); }
206 constexpr
size_t cols_c()
const {
return c_val_.cols(); }
221 template <
class CurrPtrPack,
class PtrPack>
222 void bind(
const CurrPtrPack& curr_pack,
225 util::bind(uc_val_, curr_pack.uc_val, rows_uc(), cols_uc());
226 util::bind(lower_, curr_pack.c_val, rows_c(), cols_c());
227 util::bind(c_val_, curr_pack.c_val + size_c(), rows_c(), cols_c());
auto & get(T &&x)
Definition: value.hpp:52
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
static constexpr size_t size(size_t rows)
Definition: pos_def.hpp:21
constexpr auto lower(const LowerType &expr)
Definition: lower.hpp:256
constexpr size_t size(const T &x)
Definition: value.hpp:22
constexpr auto pos_def()
Definition: pos_def.hpp:241
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
static constexpr void inv_transform(LowerType &lower, const UCType &uc, CType &c)
Definition: pos_def.hpp:53
Definition: pos_def.hpp:15
Definition: bounded.hpp:11
ad::mat mat
Definition: shape_traits.hpp:18
constexpr void cov_inv_transform(LowerType &lower, const UCType &uc, CType &c)
Definition: cov_inv_transform.hpp:12
constexpr size_t rows(const T &x)
Definition: value.hpp:32
static constexpr void transform(const CType &c, UCType &uc)
Definition: pos_def.hpp:28