autoppl
v0.8
A C++ template library for probabilistic programming
|
Go to the documentation of this file.
4 #include <fastad_bits/util/shape_traits.hpp>
5 #include <fastad_bits/reverse/core/var_view.hpp>
13 namespace constraint {
15 template <
class LowerType,
class UpperType>
19 using lower_t = LowerType;
20 using upper_t = UpperType;
22 static_assert(util::is_var_expr_v<lower_t>);
23 static_assert(util::is_var_expr_v<upper_t>);
24 static_assert(util::is_scl_v<lower_t> ||
25 util::is_scl_v<upper_t> ||
45 if constexpr (std::is_arithmetic_v<T>) {
46 uc = std::log((c - lower_.get()) / (upper_.get() - c));
51 uc = (ca - alower) / (aupper - ca).log().matrix();
63 template <
class UCViewType
64 ,
class CurrPtrPackType
67 const CurrPtrPackType& curr_pack,
68 const PtrPackType& pack,
71 auto&&
lower = lower_.ad(pack);
72 auto&& upper = upper_.ad(pack);
81 template <
class UCViewType
85 const CurrPtrPack& curr_pack,
86 const PtrPack& pack)
const
88 auto lower = lower_.ad(pack);
89 auto upper = upper_.ad(pack);
97 lower_.activate_refcnt();
98 upper_.activate_refcnt();
101 template <
class PtrPack>
102 void bind(
const PtrPack& pack) {
103 if constexpr (lower_t::has_param) {
106 if constexpr (upper_t::has_param) {
117 template <
class ValueType
127 using uc_view_t = ad::util::shape_to_raw_view_t<value_t, shape_t>;
131 static_assert(util::is_cont_v<value_t>);
163 *v_val_ = *v_val_ % refcnt;
189 template <
class CurrPtrPack,
class PtrPack>
192 size_t refcnt)
const {
193 ad::VarView<value_t, shape_t> uc_view(curr_pack.uc_val,
197 return constraint_.inv_transform_ad(uc_view, curr_pack, pack, refcnt);
206 template <
class CurrPtrPack,
class PtrPack>
208 const PtrPack& pack)
const {
209 ad::VarView<value_t, shape_t> uc_view(curr_pack.uc_val,
213 return constraint_.logj_inv_transform_ad(uc_view, curr_pack, pack);
219 template <
class GenType,
class ContDist>
220 void init(GenType& gen, ContDist& dist) {
221 if constexpr (std::is_same_v<shape_t, scl>) {
224 util::get(uc_val_) = var_t::NullaryExpr(rows_uc(), cols_uc(),
225 [&](
size_t,
size_t) {
return dist(gen); });
230 if (curr_refcnt == 1) constraint_.activate_refcnt();
260 template <
class CurrPtrPack,
class PtrPack>
261 void bind(
const CurrPtrPack& curr_pack,
264 util::bind(uc_val_, curr_pack.uc_val, rows_uc(), cols_uc());
265 util::bind(c_val_, curr_pack.c_val, rows_c(), cols_c());
267 constraint_.bind(pack);
274 constraint_t constraint_;
281 template <
class LowerType
284 const UpperType& upper)
288 lower_t wrap_lower =
lower;
289 upper_t wrap_upper = upper;
auto & get(T &&x)
Definition: value.hpp:52
constexpr void bounded_inv_transform(const UCType &uc, const LowerType &lower, const UpperType &upper, CType &c)
Definition: bounded_inv_transform.hpp:16
void activate_refcnt() const
Definition: bounded.hpp:96
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
auto logj_inv_transform_ad(const UCViewType &uc_view, const CurrPtrPack &curr_pack, const PtrPack &pack) const
Definition: bounded.hpp:84
constexpr size_t cols(const T &x)
Definition: value.hpp:42
constexpr auto lower(const LowerType &expr)
Definition: lower.hpp:256
constexpr size_t size(const T &x)
Definition: value.hpp:22
constexpr auto bounded(const LowerType &lower, const UpperType &upper)
Definition: bounded.hpp:283
void bind(const PtrPack &pack)
Definition: bounded.hpp:102
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
constexpr void inv_transform(const T &uc, T &c)
Definition: bounded.hpp:59
Bounded(const lower_t &lower, const upper_t &upper)
Definition: bounded.hpp:32
constexpr auto to_array(const T &x)
Definition: value.hpp:74
Definition: bounded.hpp:11
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
constexpr size_t rows(const T &x)
Definition: value.hpp:32
auto inv_transform_ad(const UCViewType &uc_view, const CurrPtrPackType &curr_pack, const PtrPackType &pack, size_t refcnt) const
Definition: bounded.hpp:66
constexpr auto make_val(size_t rows=1, size_t cols=1)
Definition: value.hpp:9
constexpr void transform(const T &c, T &uc) const
Definition: bounded.hpp:42
Definition: bounded.hpp:17