autoppl
v0.8
A C++ template library for probabilistic programming
Go to the documentation of this file.
4 #include <fastad_bits/util/shape_traits.hpp>
5 #include <fastad_bits/reverse/core/var_view.hpp>
6 #include <fastad_bits/reverse/core/sum.hpp>
15 namespace constraint {
17 template <
class ExprType>
21 using lower_t = ExprType;
22 static_assert(util::is_var_expr_v<lower_t>);
38 if constexpr (std::is_arithmetic_v<T>) {
39 uc = std::log(c - lower_.get());
41 if constexpr (util::is_scl_v<lower_t>) {
42 uc = (c.array() - lower_.get()).log().matrix();
44 uc = (c.array() - lower_.get().array()).log().matrix();
59 template <
class UCViewType
60 ,
class CurrPtrPackType
63 const CurrPtrPackType& curr_pack,
64 const PtrPackType& pack,
67 auto lower = lower_.ad(pack);
75 template <
class UCViewType>
83 template <
class PtrPack>
84 void bind(
const PtrPack& pack) {
85 if constexpr (lower_t::has_param) {
95 template <
class ValueType,
class ShapeType,
class ExprType>
102 using uc_view_t = ad::util::shape_to_raw_view_t<value_t, shape_t>;
106 static_assert(util::is_cont_v<value_t>);
107 static_assert(util::is_scl_v<ExprType> ||
147 *v_val_ = *v_val_ % refcnt;
160 template <
class CurrPtrPack,
class PtrPack>
163 size_t refcnt)
const {
164 ad::VarView<value_t, shape_t> uc_view(curr_pack.uc_val,
168 return constraint_.inv_transform_ad(uc_view, curr_pack, pack, refcnt);
175 template <
class CurrPtrPack,
class PtrPack>
177 const PtrPack&)
const {
178 ad::VarView<value_t, shape_t> uc_view(curr_pack.uc_val,
182 return constraint_.logj_inv_transform_ad(uc_view);
188 template <
class GenType,
class ContDist>
189 void init(GenType& gen, ContDist& dist) {
190 if constexpr (std::is_same_v<shape_t, scl>) {
193 util::get(uc_val_) = var_t::NullaryExpr(rows_uc(), cols_uc(),
194 [&](
size_t,
size_t) {
return dist(gen); });
205 if (curr_refcnt == 1) constraint_.activate_refcnt();
235 template <
class CurrPtrPack,
class PtrPack>
236 void bind(
const CurrPtrPack& curr_pack,
239 util::bind(uc_val_, curr_pack.uc_val, rows_uc(), cols_uc());
240 util::bind(c_val_, curr_pack.c_val, rows_c(), cols_c());
242 constraint_.bind(pack);
249 constraint_t constraint_;
255 template <
class LowerType>
256 constexpr
inline auto lower(
const LowerType& expr)
259 lower_t wrap_expr = expr;
constexpr void inv_transform(const T &uc, T &c)
Definition: lower.hpp:55
auto & get(T &&x)
Definition: value.hpp:52
Lower(const lower_t &lower)
Definition: lower.hpp:28
void bind(const PtrPack &pack)
Definition: lower.hpp:84
auto sum(const T &x)
Definition: value.hpp:8
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
constexpr size_t cols(const T &x)
Definition: value.hpp:42
typename util::var_expr_traits< lower_t >::value_t value_t
Definition: lower.hpp:25
constexpr auto lower(const LowerType &expr)
Definition: lower.hpp:256
constexpr size_t size(const T &x)
Definition: value.hpp:22
constexpr void transform(const T &c, T &uc) const
Definition: lower.hpp:35
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
typename util::shape_traits< lower_t >::shape_t shape_t
Definition: lower.hpp:26
auto logj_inv_transform_ad(const UCViewType &uc_view) const
Definition: lower.hpp:76
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
void activate_refcnt() const
Definition: lower.hpp:81
auto inv_transform_ad(const UCViewType &uc_view, const CurrPtrPackType &curr_pack, const PtrPackType &pack, size_t refcnt) const
Definition: lower.hpp:62
constexpr void lower_inv_transform(const UCType &uc, const LowerType &lower, CType &c)
Definition: lower_inv_transform.hpp:13
Definition: bounded.hpp:11
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
constexpr size_t rows(const T &x)
Definition: value.hpp:32
constexpr auto make_val(size_t rows=1, size_t cols=1)
Definition: value.hpp:9