4 #include <fastad_bits/reverse/stat/uniform.hpp>
10 #define PPL_UNIFORM_PARAM_SHAPE \
11 "Uniform parameters min and max must be either scalar or vector. "
22 template <
class MinType
27 util::is_shape_v<MinType> &&
28 util::is_shape_v<MaxType> &&
29 !util::is_mat_v<MinType> &&
30 !util::is_mat_v<MaxType>;
37 template <
class VarType
43 util::is_shape_v<VarType> &&
45 (util::is_scl_v<VarType> &&
46 util::is_scl_v<MinType> &&
47 util::is_scl_v<MaxType>) ||
48 (util::is_vec_v<VarType> &&
53 template <
class MinType
58 template <
class VarType
78 template <
class MinType
83 using min_t = MinType;
84 using max_t = MaxType;
86 static_assert(util::is_var_expr_v<min_t>);
87 static_assert(util::is_var_expr_v<max_t>);
88 static_assert(details::uniform_valid_param_dim_v<min_t, max_t>,
100 : min_{min}, max_{max}
103 template <
class XType>
106 static_assert(util::is_dist_assignable_v<XType>);
107 static_assert(details::uniform_valid_dim_v<XType, min_t, max_t>,
112 template <
class XType>
115 static_assert(util::is_dist_assignable_v<XType>);
116 static_assert(details::uniform_valid_dim_v<XType, min_t, max_t>,
121 template <
class XType
124 const PtrPackType& pack)
const
126 return ad::uniform_adj_log_pdf(x.ad(pack),
131 template <
class PtrPackType>
132 void bind(
const PtrPackType& pack)
134 static_cast<void>(pack);
135 if constexpr (min_t::has_param) {
138 if constexpr (max_t::has_param) {
145 min_.activate_refcnt();
146 max_.activate_refcnt();
150 template <
class XType,
class GenType>
151 bool prune(XType& x, GenType& gen)
const {
152 using x_t = std::decay_t<XType>;
153 static_assert(util::is_param_v<x_t>);
157 std::uniform_real_distribution<dist_value_t> dist(0.,1.);
159 if constexpr (util::is_scl_v<x_t>) {
160 bool needs_prune = (x.get() <= m) || (x.get() >= M);
162 x.get() = (M-m) * dist(gen) + m;
166 }
else if constexpr (util::is_vec_v<x_t>) {
167 auto get = [](
const auto& v,
size_t i=0,
size_t j=0) {
168 using v_t = std::decay_t<decltype(v)>;
169 static_cast<void>(i);
170 static_cast<void>(j);
171 if constexpr (!ad::util::is_eigen_v<v_t>) {
178 using v_t = std::decay_t<decltype(v)>;
179 if constexpr (!ad::util::is_eigen_v<v_t>) {
186 auto xa = x.get().array();
189 using vec_t = std::decay_t<decltype(x.get())>;
190 x.get() = vec_t::NullaryExpr(x.get().size(),
192 return (get(M, i) - get(m, i)) * dist(gen) + get(m, i);
198 static_assert(util::is_scl_v<x_t> ||
200 "x must be a scalar or vector shape.");
217 template <
class MinType,
class MaxType
218 ,
class = std::enable_if_t<
219 util::is_valid_dist_param_v<MinType> &&
220 util::is_valid_dist_param_v<MaxType>
222 inline constexpr
auto uniform(
const MinType& min_expr,
223 const MaxType& max_expr)
228 min_t wrap_min_expr = min_expr;
229 max_t wrap_max_expr = max_expr;
236 #undef PPL_UNIFORM_PARAM_SHAPE