autoppl  v0.8
A C++ template library for probabilistic programming
uniform.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <cassert>
3 #include <random>
4 #include <fastad_bits/reverse/stat/uniform.hpp>
8 #include <autoppl/math/math.hpp>
9 
10 #define PPL_UNIFORM_PARAM_SHAPE \
11  "Uniform parameters min and max must be either scalar or vector. "
12 
13 namespace ppl {
14 namespace expr {
15 namespace dist {
16 namespace details {
17 
22 template <class MinType
23  , class MaxType>
25 {
26  static constexpr bool value =
27  util::is_shape_v<MinType> &&
28  util::is_shape_v<MaxType> &&
29  !util::is_mat_v<MinType> &&
30  !util::is_mat_v<MaxType>;
31 };
32 
37 template <class VarType
38  , class MinType
39  , class MaxType>
41 {
42  static constexpr bool value =
43  util::is_shape_v<VarType> &&
44  (
45  (util::is_scl_v<VarType> &&
46  util::is_scl_v<MinType> &&
47  util::is_scl_v<MaxType>) ||
48  (util::is_vec_v<VarType> &&
50  );
51 };
52 
53 template <class MinType
54  , class MaxType>
55 inline constexpr bool uniform_valid_param_dim_v =
57 
58 template <class VarType
59  , class MinType
60  , class MaxType>
61 inline constexpr bool uniform_valid_dim_v =
63 
64 } // namespace details
65 
78 template <class MinType
79  , class MaxType>
80 struct Uniform: util::DistExprBase<Uniform<MinType, MaxType>>
81 {
82 private:
83  using min_t = MinType;
84  using max_t = MaxType;
85 
86  static_assert(util::is_var_expr_v<min_t>);
87  static_assert(util::is_var_expr_v<max_t>);
88  static_assert(details::uniform_valid_param_dim_v<min_t, max_t>,
91  );
92 
93 public:
96  using typename base_t::dist_value_t;
97 
98  Uniform(const min_t& min,
99  const max_t& max)
100  : min_{min}, max_{max}
101  {}
102 
103  template <class XType>
104  dist_value_t pdf(const XType& x)
105  {
106  static_assert(util::is_dist_assignable_v<XType>);
107  static_assert(details::uniform_valid_dim_v<XType, min_t, max_t>,
109  return math::uniform_pdf(x.get(), min_.eval(), max_.eval());
110  }
111 
112  template <class XType>
113  dist_value_t log_pdf(const XType& x)
114  {
115  static_assert(util::is_dist_assignable_v<XType>);
116  static_assert(details::uniform_valid_dim_v<XType, min_t, max_t>,
118  return math::uniform_log_pdf(x.get(), min_.eval(), max_.eval());
119  }
120 
121  template <class XType
122  , class PtrPackType>
123  auto ad_log_pdf(const XType& x,
124  const PtrPackType& pack) const
125  {
126  return ad::uniform_adj_log_pdf(x.ad(pack),
127  min_.ad(pack),
128  max_.ad(pack));
129  }
130 
131  template <class PtrPackType>
132  void bind(const PtrPackType& pack)
133  {
134  static_cast<void>(pack);
135  if constexpr (min_t::has_param) {
136  min_.bind(pack);
137  }
138  if constexpr (max_t::has_param) {
139  max_.bind(pack);
140  }
141  }
142 
143  void activate_refcnt() const
144  {
145  min_.activate_refcnt();
146  max_.activate_refcnt();
147  }
148 
149  // Note: assumes that min_ and max_ have already been evaluated!
150  template <class XType, class GenType>
151  bool prune(XType& x, GenType& gen) const {
152  using x_t = std::decay_t<XType>;
153  static_assert(util::is_param_v<x_t>);
154 
155  auto m = min_.get();
156  auto M = max_.get();
157  std::uniform_real_distribution<dist_value_t> dist(0.,1.);
158 
159  if constexpr (util::is_scl_v<x_t>) {
160  bool needs_prune = (x.get() <= m) || (x.get() >= M);
161  if (needs_prune) {
162  x.get() = (M-m) * dist(gen) + m;
163  }
164  return needs_prune;
165 
166  } else if constexpr (util::is_vec_v<x_t>) {
167  auto get = [](const auto& v, size_t i=0, size_t j=0) {
168  using v_t = std::decay_t<decltype(v)>;
169  static_cast<void>(i);
170  static_cast<void>(j);
171  if constexpr (!ad::util::is_eigen_v<v_t>) {
172  return v;
173  } else {
174  return v(i,j);
175  }
176  };
177  auto to_array = [](const auto& v) {
178  using v_t = std::decay_t<decltype(v)>;
179  if constexpr (!ad::util::is_eigen_v<v_t>) {
180  return v;
181  } else {
182  return v.array();
183  }
184  };
185 
186  auto xa = x.get().array();
187  bool needs_prune = (xa <= to_array(m)).max(xa >= to_array(M)).any();
188  if (needs_prune) {
189  using vec_t = std::decay_t<decltype(x.get())>;
190  x.get() = vec_t::NullaryExpr(x.get().size(),
191  [&](size_t i) {
192  return (get(M, i) - get(m, i)) * dist(gen) + get(m, i);
193  });
194  }
195  return needs_prune;
196 
197  } else {
198  static_assert(util::is_scl_v<x_t> ||
199  util::is_vec_v<x_t>,
200  "x must be a scalar or vector shape.");
201  }
202  }
203 
204 private:
205  min_t min_;
206  max_t max_;
207 };
208 
209 } // namespace dist
210 } // namespace expr
211 
217 template <class MinType, class MaxType
218  , class = std::enable_if_t<
219  util::is_valid_dist_param_v<MinType> &&
220  util::is_valid_dist_param_v<MaxType>
221  > >
222 inline constexpr auto uniform(const MinType& min_expr,
223  const MaxType& max_expr)
224 {
225  using min_t = util::convert_to_param_t<MinType>;
226  using max_t = util::convert_to_param_t<MaxType>;
227 
228  min_t wrap_min_expr = min_expr;
229  max_t wrap_max_expr = max_expr;
230 
231  return expr::dist::Uniform(wrap_min_expr, wrap_max_expr);
232 }
233 
234 } // namespace ppl
235 
236 #undef PPL_UNIFORM_PARAM_SHAPE
ppl::expr::dist::details::uniform_valid_param_dim::value
static constexpr bool value
Definition: uniform.hpp:26
ppl::expr::dist::details::uniform_valid_param_dim
Definition: uniform.hpp:25
ppl::expr::dist::Uniform::dist_value_t
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
ppl::util::get
auto & get(T &&x)
Definition: value.hpp:52
PPL_UNIFORM_PARAM_SHAPE
#define PPL_UNIFORM_PARAM_SHAPE
Definition: uniform.hpp:10
ppl::util::DistExprBase
Definition: dist_expr_traits.hpp:24
ppl::expr::dist::details::uniform_valid_dim::value
static constexpr bool value
Definition: uniform.hpp:42
ppl::uniform
constexpr auto uniform(const MinType &min_expr, const MaxType &max_expr)
Definition: uniform.hpp:222
ppl::expr::dist::Uniform::Uniform
Uniform(const min_t &min, const max_t &max)
Definition: uniform.hpp:98
ppl::expr::dist::Uniform::value_t
util::cont_param_t value_t
Definition: uniform.hpp:94
ppl::expr::dist::details::uniform_valid_dim
Definition: uniform.hpp:41
ppl::math::uniform_log_pdf
dist_value_t uniform_log_pdf(const XType &x, const MinType &min, const MaxType &max)
Definition: density.hpp:421
ppl::expr::dist::Uniform
Definition: uniform.hpp:81
ppl::expr::dist::Uniform::activate_refcnt
void activate_refcnt() const
Definition: uniform.hpp:143
ppl::expr::dist::Uniform::prune
bool prune(XType &x, GenType &gen) const
Definition: uniform.hpp:151
ppl::expr::dist::Uniform::pdf
dist_value_t pdf(const XType &x)
Definition: uniform.hpp:104
ppl::expr::dist::Uniform::ad_log_pdf
auto ad_log_pdf(const XType &x, const PtrPackType &pack) const
Definition: uniform.hpp:123
ppl::util::DistExprBase::dist_value_t
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
ppl::util::cont_param_t
double cont_param_t
Definition: dist_expr_traits.hpp:14
dist_utils.hpp
ppl::util::to_array
constexpr auto to_array(const T &x)
Definition: value.hpp:74
ppl
Definition: bounded.hpp:11
ppl::math::uniform_pdf
dist_value_t uniform_pdf(const XType &x, const MinType &min, const MaxType &max)
Definition: density.hpp:347
ppl::expr::dist::details::uniform_valid_param_dim_v
constexpr bool uniform_valid_param_dim_v
Definition: uniform.hpp:55
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::expr::dist::Uniform::bind
void bind(const PtrPackType &pack)
Definition: uniform.hpp:132
ppl::expr::dist::details::uniform_valid_dim_v
constexpr bool uniform_valid_dim_v
Definition: uniform.hpp:61
density.hpp
traits.hpp
math.hpp
PPL_DIST_SHAPE_MISMATCH
#define PPL_DIST_SHAPE_MISMATCH
Definition: dist_utils.hpp:2
ppl::expr::dist::Uniform::log_pdf
dist_value_t log_pdf(const XType &x)
Definition: uniform.hpp:113