autoppl
v0.8
A C++ template library for probabilistic programming
|
Go to the documentation of this file.
2 #include <fastad_bits/reverse/stat/normal.hpp>
8 #define PPL_NORMAL_PARAM_SHAPE \
9 "Normal distribution mean must either be a scalar or vector. "
20 template <
class MeanType
25 util::is_shape_v<MeanType> &&
26 util::is_shape_v<SigmaType> &&
27 util::is_scl_v<MeanType> &&
28 util::is_scl_v<SigmaType>;
35 template <
class MeanType
40 util::is_shape_v<MeanType> &&
41 util::is_shape_v<SigmaType> &&
42 !util::is_mat_v<MeanType>;
48 template <
class VarType
54 util::is_shape_v<VarType> &&
56 (util::is_scl_v<VarType> &&
58 (util::is_vec_v<VarType> &&
63 template <
class MeanType
68 template <
class MeanType
73 template <
class VarType
93 template <
class MeanType
99 using mean_t = MeanType;
100 using sigma_t = SigmaType;
102 static_assert(util::is_var_expr_v<mean_t>);
103 static_assert(util::is_var_expr_v<sigma_t>);
104 static_assert(details::normal_valid_param_dim_case_2_v<mean_t, sigma_t>,
115 const sigma_t& sigma)
116 : mean_{mean}, sigma_{sigma}
119 template <
class XType>
122 static_assert(util::is_dist_assignable_v<XType>);
123 static_assert(details::normal_valid_dim_v<XType, mean_t, sigma_t>,
128 template <
class XType>
131 static_assert(util::is_dist_assignable_v<XType>);
132 static_assert(details::normal_valid_dim_v<XType, mean_t, sigma_t>,
137 template <
class XType
140 const PtrPackType& pack)
const
142 static_assert(util::is_dist_assignable_v<XType>);
143 static_assert(details::normal_valid_dim_v<XType, mean_t, sigma_t>,
145 return ad::normal_adj_log_pdf(x.ad(pack),
150 template <
class PtrPackType>
151 void bind(
const PtrPackType& pack)
153 static_cast<void>(pack);
154 if constexpr (mean_t::has_param) {
157 if constexpr (sigma_t::has_param) {
164 mean_.activate_refcnt();
165 sigma_.activate_refcnt();
168 template <
class XType,
class GenType>
169 bool prune(XType&, GenType&)
const {
return false; }
184 template <
class MeanType,
class SDType
185 ,
class = std::enable_if_t<
186 util::is_valid_dist_param_v<MeanType> &&
187 util::is_valid_dist_param_v<SDType>
189 inline constexpr
auto normal(
const MeanType& mean_expr,
190 const SDType& sd_expr)
195 mean_t wrap_mean_expr = mean_expr;
196 sd_t wrap_sd_expr = sd_expr;
203 #undef PPL_NORMAL_PARAM_SHAPE
void activate_refcnt() const
Definition: normal.hpp:162
Normal(const mean_t &mean, const sigma_t &sigma)
Definition: normal.hpp:114
#define PPL_NORMAL_PARAM_SHAPE
Definition: normal.hpp:8
static constexpr bool value
Definition: normal.hpp:24
Definition: normal.hpp:52
void bind(const PtrPackType &pack)
Definition: normal.hpp:151
dist_value_t normal_pdf(const XType &x, const MeanType &mean, const SigmaType &sigma)
Definition: density.hpp:40
Definition: dist_expr_traits.hpp:24
constexpr auto normal(const MeanType &mean_expr, const SDType &sd_expr)
Definition: normal.hpp:189
static constexpr bool value
Definition: normal.hpp:53
util::cont_param_t value_t
Definition: normal.hpp:110
Definition: normal.hpp:97
Definition: normal.hpp:38
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
dist_value_t log_pdf(const XType &x)
Definition: normal.hpp:129
constexpr bool normal_valid_param_dim_case_1_v
Definition: normal.hpp:65
Definition: normal.hpp:23
constexpr bool normal_valid_param_dim_case_2_v
Definition: normal.hpp:70
constexpr bool normal_valid_dim_v
Definition: normal.hpp:76
dist_value_t normal_log_pdf(const XType &x, const MeanType &mean, const SigmaType &sigma)
Definition: density.hpp:149
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
double cont_param_t
Definition: dist_expr_traits.hpp:14
static constexpr bool value
Definition: normal.hpp:39
Definition: bounded.hpp:11
bool prune(XType &, GenType &) const
Definition: normal.hpp:169
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
auto ad_log_pdf(const XType &x, const PtrPackType &pack) const
Definition: normal.hpp:139
#define PPL_DIST_SHAPE_MISMATCH
Definition: dist_utils.hpp:2
dist_value_t pdf(const XType &x)
Definition: normal.hpp:120