autoppl
v0.8
A C++ template library for probabilistic programming
|
Go to the documentation of this file.
3 #include <fastad_bits/reverse/stat/bernoulli.hpp>
8 #define PPL_BERNOULLI_PARAM_SHAPE \
9 "Bernoulli distribution probability must either be a scalar or vector. " \
20 template <
class PType>
24 util::is_shape_v<PType> &&
25 !util::is_mat_v<PType>;
32 template <
class VarType
37 util::is_shape_v<VarType> &&
39 (util::is_scl_v<VarType> &&
40 util::is_scl_v<PType>) ||
41 (util::is_vec_v<VarType> &&
46 template <
class PType>
50 template <
class VarType
71 template <
class PType>
77 static_assert(util::is_var_expr_v<p_t>);
78 static_assert(details::bern_valid_param_dim_v<p_t>,
92 template <
class XType>
95 static_assert(util::is_dist_assignable_v<XType>);
96 static_assert(details::bern_valid_dim_v<XType, p_t>,
101 template <
class XType>
104 static_assert(util::is_dist_assignable_v<XType>);
105 static_assert(details::bern_valid_dim_v<XType, p_t>,
110 template <
class XType
113 const PtrPackType& pack)
const
115 return ad::bernoulli_adj_log_pdf(x.ad(pack),
119 template <
class PtrPackType>
120 void bind(
const PtrPackType& pack)
122 static_cast<void>(pack);
123 if constexpr (p_t::has_param) {
129 { p_.activate_refcnt(); }
131 template <
class XType,
class GenType>
132 bool prune(XType& x, GenType&)
const {
133 using x_t = std::decay_t<XType>;
134 static_assert(util::is_param_v<x_t>);
135 if constexpr (util::is_scl_v<x_t>) {
136 bool needs_prune = (x.get() != 0) && (x.get() != 1);
137 if (needs_prune) x.get() = 0;
139 }
else if constexpr (util::is_vec_v<x_t>){
140 auto xa = x.get().array();
141 bool needs_prune = ((xa != 0).min(xa != 1)).any();
142 if (needs_prune) x.get().setZero();
159 template <
class ProbType
160 ,
class = std::enable_if_t<
161 util::is_valid_dist_param_v<ProbType>
166 p_t wrap_p_expr = p_expr;
172 #undef PPL_BERNOULLI_PARAM_SHAPE
dist_value_t pdf(const XType &x)
Definition: bernoulli.hpp:93
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
static constexpr bool value
Definition: bernoulli.hpp:23
Definition: dist_expr_traits.hpp:24
dist_value_t bernoulli_pdf(const XType &x, const PType &p)
Definition: density.hpp:506
auto ad_log_pdf(const XType &x, const PtrPackType &pack) const
Definition: bernoulli.hpp:112
constexpr bool bern_valid_dim_v
Definition: bernoulli.hpp:52
int32_t disc_param_t
Definition: dist_expr_traits.hpp:15
#define PPL_BERNOULLI_PARAM_SHAPE
Definition: bernoulli.hpp:8
static constexpr bool value
Definition: bernoulli.hpp:36
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
void activate_refcnt() const
Definition: bernoulli.hpp:128
dist_value_t bernoulli_log_pdf(const XType &x, const PType &p)
Definition: density.hpp:554
void bind(const PtrPackType &pack)
Definition: bernoulli.hpp:120
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
Definition: bernoulli.hpp:35
Definition: bounded.hpp:11
Bernoulli(const p_t &p)
Definition: bernoulli.hpp:89
dist_value_t log_pdf(const XType &x)
Definition: bernoulli.hpp:102
constexpr bool bern_valid_param_dim_v
Definition: bernoulli.hpp:47
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
Definition: bernoulli.hpp:73
constexpr auto bernoulli(const ProbType &p_expr)
Definition: bernoulli.hpp:163
#define PPL_DIST_SHAPE_MISMATCH
Definition: dist_utils.hpp:2
bool prune(XType &x, GenType &) const
Definition: bernoulli.hpp:132
util::disc_param_t value_t
Definition: bernoulli.hpp:84
Definition: bernoulli.hpp:22
typename util::var_expr_traits< p_t >::value_t param_value_t
Definition: bernoulli.hpp:85