autoppl  v0.8
A C++ template library for probabilistic programming
bernoulli.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <cassert>
3 #include <fastad_bits/reverse/stat/bernoulli.hpp>
7 
8 #define PPL_BERNOULLI_PARAM_SHAPE \
9  "Bernoulli distribution probability must either be a scalar or vector. " \
10 
11 namespace ppl {
12 namespace expr {
13 namespace dist {
14 namespace details {
15 
20 template <class PType>
22 {
23  static constexpr bool value =
24  util::is_shape_v<PType> &&
25  !util::is_mat_v<PType>;
26 };
27 
32 template <class VarType
33  , class PType>
35 {
36  static constexpr bool value =
37  util::is_shape_v<VarType> &&
38  (
39  (util::is_scl_v<VarType> &&
40  util::is_scl_v<PType>) ||
41  (util::is_vec_v<VarType> &&
43  );
44 };
45 
46 template <class PType>
47 inline constexpr bool bern_valid_param_dim_v =
49 
50 template <class VarType
51  , class PType>
52 inline constexpr bool bern_valid_dim_v =
54 
55 } // namespace details
56 
71 template <class PType>
72 struct Bernoulli : util::DistExprBase<Bernoulli<PType>>
73 {
74 private:
75  using p_t = PType;
76 
77  static_assert(util::is_var_expr_v<p_t>);
78  static_assert(details::bern_valid_param_dim_v<p_t>,
81  );
82 
83 public:
87  using typename base_t::dist_value_t;
88 
89  Bernoulli(const p_t& p)
90  : p_{p} {}
91 
92  template <class XType>
93  dist_value_t pdf(const XType& x)
94  {
95  static_assert(util::is_dist_assignable_v<XType>);
96  static_assert(details::bern_valid_dim_v<XType, p_t>,
98  return math::bernoulli_pdf(x.get(), p_.eval());
99  }
100 
101  template <class XType>
102  dist_value_t log_pdf(const XType& x)
103  {
104  static_assert(util::is_dist_assignable_v<XType>);
105  static_assert(details::bern_valid_dim_v<XType, p_t>,
107  return math::bernoulli_log_pdf(x.get(), p_.eval());
108  }
109 
110  template <class XType
111  , class PtrPackType>
112  auto ad_log_pdf(const XType& x,
113  const PtrPackType& pack) const
114  {
115  return ad::bernoulli_adj_log_pdf(x.ad(pack),
116  p_.ad(pack));
117  }
118 
119  template <class PtrPackType>
120  void bind(const PtrPackType& pack)
121  {
122  static_cast<void>(pack);
123  if constexpr (p_t::has_param) {
124  p_.bind(pack);
125  }
126  }
127 
128  void activate_refcnt() const
129  { p_.activate_refcnt(); }
130 
131  template <class XType, class GenType>
132  bool prune(XType& x, GenType&) const {
133  using x_t = std::decay_t<XType>;
134  static_assert(util::is_param_v<x_t>);
135  if constexpr (util::is_scl_v<x_t>) {
136  bool needs_prune = (x.get() != 0) && (x.get() != 1);
137  if (needs_prune) x.get() = 0;
138  return needs_prune;
139  } else if constexpr (util::is_vec_v<x_t>){
140  auto xa = x.get().array();
141  bool needs_prune = ((xa != 0).min(xa != 1)).any();
142  if (needs_prune) x.get().setZero();
143  return needs_prune;
144  }
145  }
146 
147 private:
148  p_t p_;
149 };
150 
151 } // namespace dist
152 } // namespace expr
153 
159 template <class ProbType
160  , class = std::enable_if_t<
161  util::is_valid_dist_param_v<ProbType>
162  > >
163 inline constexpr auto bernoulli(const ProbType& p_expr)
164 {
166  p_t wrap_p_expr = p_expr;
167  return expr::dist::Bernoulli(wrap_p_expr);
168 }
169 
170 } // namespace ppl
171 
172 #undef PPL_BERNOULLI_PARAM_SHAPE
ppl::expr::dist::Bernoulli::pdf
dist_value_t pdf(const XType &x)
Definition: bernoulli.hpp:93
ppl::expr::dist::Bernoulli::dist_value_t
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
ppl::expr::dist::details::bern_valid_param_dim::value
static constexpr bool value
Definition: bernoulli.hpp:23
ppl::util::DistExprBase
Definition: dist_expr_traits.hpp:24
ppl::math::bernoulli_pdf
dist_value_t bernoulli_pdf(const XType &x, const PType &p)
Definition: density.hpp:506
ppl::expr::dist::Bernoulli::ad_log_pdf
auto ad_log_pdf(const XType &x, const PtrPackType &pack) const
Definition: bernoulli.hpp:112
ppl::expr::dist::details::bern_valid_dim_v
constexpr bool bern_valid_dim_v
Definition: bernoulli.hpp:52
ppl::util::disc_param_t
int32_t disc_param_t
Definition: dist_expr_traits.hpp:15
PPL_BERNOULLI_PARAM_SHAPE
#define PPL_BERNOULLI_PARAM_SHAPE
Definition: bernoulli.hpp:8
ppl::expr::dist::details::bern_valid_dim::value
static constexpr bool value
Definition: bernoulli.hpp:36
ppl::util::var_expr_traits::value_t
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
ppl::expr::dist::Bernoulli::activate_refcnt
void activate_refcnt() const
Definition: bernoulli.hpp:128
ppl::math::bernoulli_log_pdf
dist_value_t bernoulli_log_pdf(const XType &x, const PType &p)
Definition: density.hpp:554
ppl::expr::dist::Bernoulli::bind
void bind(const PtrPackType &pack)
Definition: bernoulli.hpp:120
ppl::util::DistExprBase::dist_value_t
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
dist_utils.hpp
ppl::expr::dist::details::bern_valid_dim
Definition: bernoulli.hpp:35
ppl
Definition: bounded.hpp:11
ppl::expr::dist::Bernoulli::Bernoulli
Bernoulli(const p_t &p)
Definition: bernoulli.hpp:89
ppl::expr::dist::Bernoulli::log_pdf
dist_value_t log_pdf(const XType &x)
Definition: bernoulli.hpp:102
ppl::expr::dist::details::bern_valid_param_dim_v
constexpr bool bern_valid_param_dim_v
Definition: bernoulli.hpp:47
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::expr::dist::Bernoulli
Definition: bernoulli.hpp:73
density.hpp
traits.hpp
ppl::bernoulli
constexpr auto bernoulli(const ProbType &p_expr)
Definition: bernoulli.hpp:163
PPL_DIST_SHAPE_MISMATCH
#define PPL_DIST_SHAPE_MISMATCH
Definition: dist_utils.hpp:2
ppl::expr::dist::Bernoulli::prune
bool prune(XType &x, GenType &) const
Definition: bernoulli.hpp:132
ppl::expr::dist::Bernoulli::value_t
util::disc_param_t value_t
Definition: bernoulli.hpp:84
ppl::expr::dist::details::bern_valid_param_dim
Definition: bernoulli.hpp:22
ppl::expr::dist::Bernoulli::param_value_t
typename util::var_expr_traits< p_t >::value_t param_value_t
Definition: bernoulli.hpp:85