| 
    autoppl
    v0.8
    
   A C++ template library for probabilistic programming 
   | 
 
 
 
 
Go to the documentation of this file.
    4 #include <fastad_bits/util/shape_traits.hpp> 
    5 #include <fastad_bits/reverse/core/var_view.hpp> 
    6 #include <fastad_bits/reverse/core/sum.hpp> 
   15 namespace constraint {
 
   17 template <
class ExprType>
 
   21     using lower_t = ExprType;
 
   22     static_assert(util::is_var_expr_v<lower_t>);
 
   38         if constexpr (std::is_arithmetic_v<T>) {
 
   39             uc = std::log(c - lower_.get()); 
 
   41             if constexpr (util::is_scl_v<lower_t>) {
 
   42                 uc = (c.array() - lower_.get()).log().matrix();
 
   44                 uc = (c.array() - lower_.get().array()).log().matrix();            
 
   59     template <
class UCViewType
 
   60             , 
class CurrPtrPackType
 
   63                           const CurrPtrPackType& curr_pack,
 
   64                           const PtrPackType& pack,
 
   67         auto lower = lower_.ad(pack);
 
   75     template <
class UCViewType>
 
   83     template <
class PtrPack>
 
   84     void bind(
const PtrPack& pack) { 
 
   85         if constexpr (lower_t::has_param) {
 
   95 template <
class ValueType, 
class ShapeType, 
class ExprType>
 
  102     using uc_view_t = ad::util::shape_to_raw_view_t<value_t, shape_t>;
 
  106     static_assert(util::is_cont_v<value_t>);
 
  107     static_assert(util::is_scl_v<ExprType> ||
 
  147         *v_val_ = *v_val_ % refcnt;
 
  160     template <
class CurrPtrPack, 
class PtrPack>
 
  163                           size_t refcnt)
 const {
 
  164         ad::VarView<value_t, shape_t> uc_view(curr_pack.uc_val, 
 
  168         return constraint_.inv_transform_ad(uc_view, curr_pack, pack, refcnt);
 
  175     template <
class CurrPtrPack, 
class PtrPack>
 
  177                                const PtrPack&)
 const {
 
  178         ad::VarView<value_t, shape_t> uc_view(curr_pack.uc_val, 
 
  182         return constraint_.logj_inv_transform_ad(uc_view);
 
  188     template <
class GenType, 
class ContDist>
 
  189     void init(GenType& gen, ContDist& dist) {
 
  190         if constexpr (std::is_same_v<shape_t, scl>) {
 
  193             util::get(uc_val_) = var_t::NullaryExpr(rows_uc(), cols_uc(),
 
  194                                     [&](
size_t, 
size_t) { 
return dist(gen); });
 
  205         if (curr_refcnt == 1) constraint_.activate_refcnt();
 
  235     template <
class CurrPtrPack, 
class PtrPack>
 
  236     void bind(
const CurrPtrPack& curr_pack,
 
  239         util::bind(uc_val_, curr_pack.uc_val, rows_uc(), cols_uc());
 
  240         util::bind(c_val_, curr_pack.c_val, rows_c(), cols_c());
 
  242         constraint_.bind(pack);
 
  249     constraint_t constraint_;
 
  255 template <
class LowerType>
 
  256 constexpr 
inline auto lower(
const LowerType& expr)
 
  259     lower_t wrap_expr = expr;
 
  
 
constexpr void inv_transform(const T &uc, T &c)
Definition: lower.hpp:55
 
auto & get(T &&x)
Definition: value.hpp:52
 
Lower(const lower_t &lower)
Definition: lower.hpp:28
 
void bind(const PtrPack &pack)
Definition: lower.hpp:84
 
auto sum(const T &x)
Definition: value.hpp:8
 
void bind(T &x, ValPtrType begin, size_t rows=1, size_t cols=1)
Definition: value.hpp:61
 
constexpr size_t cols(const T &x)
Definition: value.hpp:42
 
typename util::var_expr_traits< lower_t >::value_t value_t
Definition: lower.hpp:25
 
constexpr auto lower(const LowerType &expr)
Definition: lower.hpp:256
 
constexpr size_t size(const T &x)
Definition: value.hpp:22
 
constexpr void transform(const T &c, T &uc) const
Definition: lower.hpp:35
 
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
 
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
 
typename util::shape_traits< lower_t >::shape_t shape_t
Definition: lower.hpp:26
 
auto logj_inv_transform_ad(const UCViewType &uc_view) const
Definition: lower.hpp:76
 
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
 
void activate_refcnt() const
Definition: lower.hpp:81
 
auto inv_transform_ad(const UCViewType &uc_view, const CurrPtrPackType &curr_pack, const PtrPackType &pack, size_t refcnt) const
Definition: lower.hpp:62
 
constexpr void lower_inv_transform(const UCType &uc, const LowerType &lower, CType &c)
Definition: lower_inv_transform.hpp:13
 
Definition: bounded.hpp:11
 
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
 
constexpr size_t rows(const T &x)
Definition: value.hpp:32
 
constexpr auto make_val(size_t rows=1, size_t cols=1)
Definition: value.hpp:9