autoppl  v0.8
A C++ template library for probabilistic programming
normal.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/reverse/stat/normal.hpp>
6 #include <autoppl/math/math.hpp>
7 
8 #define PPL_NORMAL_PARAM_SHAPE \
9  "Normal distribution mean must either be a scalar or vector. "
10 
11 namespace ppl {
12 namespace expr {
13 namespace dist {
14 namespace details {
15 
20 template <class MeanType
21  , class SigmaType>
23 {
24  static constexpr bool value =
25  util::is_shape_v<MeanType> &&
26  util::is_shape_v<SigmaType> &&
27  util::is_scl_v<MeanType> &&
28  util::is_scl_v<SigmaType>;
29 };
30 
35 template <class MeanType
36  , class SigmaType>
38 {
39  static constexpr bool value =
40  util::is_shape_v<MeanType> &&
41  util::is_shape_v<SigmaType> &&
42  !util::is_mat_v<MeanType>;
43 };
44 
48 template <class VarType
49  , class MeanType
50  , class SigmaType>
52 {
53  static constexpr bool value =
54  util::is_shape_v<VarType> &&
55  (
56  (util::is_scl_v<VarType> &&
58  (util::is_vec_v<VarType> &&
60  );
61 };
62 
63 template <class MeanType
64  , class SigmaType>
65 inline constexpr bool normal_valid_param_dim_case_1_v =
67 
68 template <class MeanType
69  , class SigmaType>
70 inline constexpr bool normal_valid_param_dim_case_2_v =
72 
73 template <class VarType
74  , class MeanType
75  , class SigmaType>
76 inline constexpr bool normal_valid_dim_v =
78 
79 } // namespace details
80 
93 template <class MeanType
94  , class SigmaType>
95 struct Normal:
96  util::DistExprBase<Normal<MeanType, SigmaType>>
97 {
98 private:
99  using mean_t = MeanType;
100  using sigma_t = SigmaType;
101 
102  static_assert(util::is_var_expr_v<mean_t>);
103  static_assert(util::is_var_expr_v<sigma_t>);
104  static_assert(details::normal_valid_param_dim_case_2_v<mean_t, sigma_t>,
107  );
108 
109 public:
112  using typename base_t::dist_value_t;
113 
114  Normal(const mean_t& mean,
115  const sigma_t& sigma)
116  : mean_{mean}, sigma_{sigma}
117  {}
118 
119  template <class XType>
120  dist_value_t pdf(const XType& x)
121  {
122  static_assert(util::is_dist_assignable_v<XType>);
123  static_assert(details::normal_valid_dim_v<XType, mean_t, sigma_t>,
125  return math::normal_pdf(x.get(), mean_.eval(), sigma_.eval());
126  }
127 
128  template <class XType>
129  dist_value_t log_pdf(const XType& x)
130  {
131  static_assert(util::is_dist_assignable_v<XType>);
132  static_assert(details::normal_valid_dim_v<XType, mean_t, sigma_t>,
134  return math::normal_log_pdf(x.get(), mean_.eval(), sigma_.eval());
135  }
136 
137  template <class XType
138  , class PtrPackType>
139  auto ad_log_pdf(const XType& x,
140  const PtrPackType& pack) const
141  {
142  static_assert(util::is_dist_assignable_v<XType>);
143  static_assert(details::normal_valid_dim_v<XType, mean_t, sigma_t>,
145  return ad::normal_adj_log_pdf(x.ad(pack),
146  mean_.ad(pack),
147  sigma_.ad(pack));
148  }
149 
150  template <class PtrPackType>
151  void bind(const PtrPackType& pack)
152  {
153  static_cast<void>(pack);
154  if constexpr (mean_t::has_param) {
155  mean_.bind(pack);
156  }
157  if constexpr (sigma_t::has_param) {
158  sigma_.bind(pack);
159  }
160  }
161 
162  void activate_refcnt() const
163  {
164  mean_.activate_refcnt();
165  sigma_.activate_refcnt();
166  }
167 
168  template <class XType, class GenType>
169  bool prune(XType&, GenType&) const { return false; }
170 
171 private:
172  mean_t mean_;
173  sigma_t sigma_;
174 };
175 
176 } // namespace dist
177 } // namespace expr
178 
184 template <class MeanType, class SDType
185  , class = std::enable_if_t<
186  util::is_valid_dist_param_v<MeanType> &&
187  util::is_valid_dist_param_v<SDType>
188  > >
189 inline constexpr auto normal(const MeanType& mean_expr,
190  const SDType& sd_expr)
191 {
192  using mean_t = util::convert_to_param_t<MeanType>;
194 
195  mean_t wrap_mean_expr = mean_expr;
196  sd_t wrap_sd_expr = sd_expr;
197 
198  return expr::dist::Normal(wrap_mean_expr, wrap_sd_expr);
199 }
200 
201 } // namespace ppl
202 
203 #undef PPL_NORMAL_PARAM_SHAPE
ppl::expr::dist::Normal::activate_refcnt
void activate_refcnt() const
Definition: normal.hpp:162
ppl::expr::dist::Normal::Normal
Normal(const mean_t &mean, const sigma_t &sigma)
Definition: normal.hpp:114
PPL_NORMAL_PARAM_SHAPE
#define PPL_NORMAL_PARAM_SHAPE
Definition: normal.hpp:8
ppl::expr::dist::details::normal_valid_param_dim_case_1::value
static constexpr bool value
Definition: normal.hpp:24
ppl::expr::dist::details::normal_valid_dim
Definition: normal.hpp:52
ppl::expr::dist::Normal::bind
void bind(const PtrPackType &pack)
Definition: normal.hpp:151
ppl::math::normal_pdf
dist_value_t normal_pdf(const XType &x, const MeanType &mean, const SigmaType &sigma)
Definition: density.hpp:40
ppl::util::DistExprBase
Definition: dist_expr_traits.hpp:24
ppl::normal
constexpr auto normal(const MeanType &mean_expr, const SDType &sd_expr)
Definition: normal.hpp:189
ppl::expr::dist::details::normal_valid_dim::value
static constexpr bool value
Definition: normal.hpp:53
ppl::expr::dist::Normal::value_t
util::cont_param_t value_t
Definition: normal.hpp:110
ppl::expr::dist::Normal
Definition: normal.hpp:97
ppl::expr::dist::details::normal_valid_param_dim_case_2
Definition: normal.hpp:38
ppl::expr::dist::Normal::dist_value_t
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
ppl::expr::dist::Normal::log_pdf
dist_value_t log_pdf(const XType &x)
Definition: normal.hpp:129
ppl::expr::dist::details::normal_valid_param_dim_case_1_v
constexpr bool normal_valid_param_dim_case_1_v
Definition: normal.hpp:65
ppl::expr::dist::details::normal_valid_param_dim_case_1
Definition: normal.hpp:23
ppl::expr::dist::details::normal_valid_param_dim_case_2_v
constexpr bool normal_valid_param_dim_case_2_v
Definition: normal.hpp:70
ppl::expr::dist::details::normal_valid_dim_v
constexpr bool normal_valid_dim_v
Definition: normal.hpp:76
ppl::math::normal_log_pdf
dist_value_t normal_log_pdf(const XType &x, const MeanType &mean, const SigmaType &sigma)
Definition: density.hpp:149
ppl::util::DistExprBase::dist_value_t
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
ppl::util::cont_param_t
double cont_param_t
Definition: dist_expr_traits.hpp:14
dist_utils.hpp
ppl::expr::dist::details::normal_valid_param_dim_case_2::value
static constexpr bool value
Definition: normal.hpp:39
ppl
Definition: bounded.hpp:11
ppl::expr::dist::Normal::prune
bool prune(XType &, GenType &) const
Definition: normal.hpp:169
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
density.hpp
traits.hpp
math.hpp
ppl::expr::dist::Normal::ad_log_pdf
auto ad_log_pdf(const XType &x, const PtrPackType &pack) const
Definition: normal.hpp:139
PPL_DIST_SHAPE_MISMATCH
#define PPL_DIST_SHAPE_MISMATCH
Definition: dist_utils.hpp:2
ppl::expr::dist::Normal::pdf
dist_value_t pdf(const XType &x)
Definition: normal.hpp:120