autoppl  v0.8
A C++ template library for probabilistic programming
cauchy.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <cassert>
3 #include <random>
4 #include <fastad_bits/reverse/stat/cauchy.hpp>
8 #include <autoppl/math/math.hpp>
9 
10 #define PPL_CAUCHY_PARAM_SHAPE \
11  "Cauchy parameters loc and scale must be either scalar or vector. "
12 
13 namespace ppl {
14 namespace expr {
15 namespace dist {
16 namespace details {
17 
22 template <class LocType
23  , class ScaleType>
25 {
26  static constexpr bool value =
27  util::is_shape_v<LocType> &&
28  util::is_shape_v<ScaleType> &&
29  !util::is_mat_v<LocType> &&
30  !util::is_mat_v<ScaleType>;
31 };
32 
37 template <class VarType
38  , class LocType
39  , class ScaleType>
41 {
42  static constexpr bool value =
43  util::is_shape_v<VarType> &&
44  (
45  (util::is_scl_v<VarType> &&
46  util::is_scl_v<LocType> &&
47  util::is_scl_v<ScaleType>) ||
48  (util::is_vec_v<VarType> &&
50  );
51 };
52 
53 template <class LocType
54  , class ScaleType>
55 inline constexpr bool cauchy_valid_param_dim_v =
57 
58 template <class VarType
59  , class LocType
60  , class ScaleType>
61 inline constexpr bool cauchy_valid_dim_v =
63 
64 } // namespace details
65 
78 template <class LocType
79  , class ScaleType>
80 struct Cauchy: util::DistExprBase<Cauchy<LocType, ScaleType>>
81 {
82 private:
83  using loc_t = LocType;
84  using scale_t = ScaleType;
85 
86  static_assert(util::is_var_expr_v<loc_t>);
87  static_assert(util::is_var_expr_v<scale_t>);
88  static_assert(details::cauchy_valid_param_dim_v<loc_t, scale_t>,
91  );
92 
93 public:
96  using typename base_t::dist_value_t;
97 
98  Cauchy(const loc_t& loc,
99  const scale_t& scale)
100  : loc_{loc}, scale_{scale}
101  {}
102 
103  //template <class XType>
104  //dist_value_t pdf(const XType& x)
105  //{
106  // static_assert(util::is_dist_assignable_v<XType>);
107  // static_assert(details::cauchy_valid_dim_v<XType, loc_t, scale_t>,
108  // PPL_DIST_SHAPE_MISMATCH);
109  // return math::cauchy_pdf(x.get(), loc_.eval(), scale_.eval());
110  //}
111 
112  template <class XType>
113  dist_value_t log_pdf(const XType& x)
114  {
115  static_assert(util::is_dist_assignable_v<XType>);
116  static_assert(details::cauchy_valid_dim_v<XType, loc_t, scale_t>,
118  return math::cauchy_log_pdf(x.get(), loc_.eval(), scale_.eval());
119  }
120 
121  template <class XType
122  , class PtrPackType>
123  auto ad_log_pdf(const XType& x,
124  const PtrPackType& pack) const
125  {
126  return ad::cauchy_adj_log_pdf(x.ad(pack),
127  loc_.ad(pack),
128  scale_.ad(pack));
129  }
130 
131  template <class PtrPackType>
132  void bind(const PtrPackType& pack)
133  {
134  static_cast<void>(pack);
135  if constexpr (loc_t::has_param) {
136  loc_.bind(pack);
137  }
138  if constexpr (scale_t::has_param) {
139  scale_.bind(pack);
140  }
141  }
142 
143  void activate_refcnt() const
144  {
145  loc_.activate_refcnt();
146  scale_.activate_refcnt();
147  }
148 
149  template <class XType, class GenType>
150  constexpr bool prune(XType&, GenType&) const { return false; }
151 
152 private:
153  loc_t loc_;
154  scale_t scale_;
155 };
156 
157 } // namespace dist
158 } // namespace expr
159 
165 template <class LocType, class ScaleType
166  , class = std::enable_if_t<
167  util::is_valid_dist_param_v<LocType> &&
168  util::is_valid_dist_param_v<ScaleType>
169  > >
170 inline constexpr auto cauchy(const LocType& loc_expr,
171  const ScaleType& scale_expr)
172 {
173  using loc_t = util::convert_to_param_t<LocType>;
174  using scale_t = util::convert_to_param_t<ScaleType>;
175 
176  loc_t wrap_loc_expr = loc_expr;
177  scale_t wrap_scale_expr = scale_expr;
178 
179  return expr::dist::Cauchy(wrap_loc_expr, wrap_scale_expr);
180 }
181 
182 } // namespace ppl
183 
184 #undef PPL_CAUCHY_PARAM_SHAPE
ppl::expr::dist::details::cauchy_valid_param_dim::value
static constexpr bool value
Definition: cauchy.hpp:26
ppl::expr::dist::Cauchy::ad_log_pdf
auto ad_log_pdf(const XType &x, const PtrPackType &pack) const
Definition: cauchy.hpp:123
ppl::util::DistExprBase
Definition: dist_expr_traits.hpp:24
ppl::expr::dist::Cauchy
Definition: cauchy.hpp:81
ppl::expr::dist::Cauchy::prune
constexpr bool prune(XType &, GenType &) const
Definition: cauchy.hpp:150
ppl::expr::dist::details::cauchy_valid_param_dim
Definition: cauchy.hpp:25
ppl::expr::dist::details::cauchy_valid_param_dim_v
constexpr bool cauchy_valid_param_dim_v
Definition: cauchy.hpp:55
ppl::expr::dist::details::cauchy_valid_dim_v
constexpr bool cauchy_valid_dim_v
Definition: cauchy.hpp:61
ppl::expr::dist::Cauchy::dist_value_t
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
ppl::expr::dist::Cauchy::activate_refcnt
void activate_refcnt() const
Definition: cauchy.hpp:143
PPL_CAUCHY_PARAM_SHAPE
#define PPL_CAUCHY_PARAM_SHAPE
Definition: cauchy.hpp:10
ppl::cauchy
constexpr auto cauchy(const LocType &loc_expr, const ScaleType &scale_expr)
Definition: cauchy.hpp:170
ppl::util::DistExprBase::dist_value_t
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
ppl::util::cont_param_t
double cont_param_t
Definition: dist_expr_traits.hpp:14
dist_utils.hpp
ppl::math::cauchy_log_pdf
dist_value_t cauchy_log_pdf(const XType &x, const LocType &loc, const ScaleType &scale)
Definition: density.hpp:261
ppl::expr::dist::Cauchy::Cauchy
Cauchy(const loc_t &loc, const scale_t &scale)
Definition: cauchy.hpp:98
ppl
Definition: bounded.hpp:11
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::expr::dist::details::cauchy_valid_dim::value
static constexpr bool value
Definition: cauchy.hpp:42
ppl::expr::dist::Cauchy::log_pdf
dist_value_t log_pdf(const XType &x)
Definition: cauchy.hpp:113
density.hpp
ppl::expr::dist::Cauchy::bind
void bind(const PtrPackType &pack)
Definition: cauchy.hpp:132
traits.hpp
math.hpp
ppl::expr::dist::Cauchy::value_t
util::cont_param_t value_t
Definition: cauchy.hpp:94
PPL_DIST_SHAPE_MISMATCH
#define PPL_DIST_SHAPE_MISMATCH
Definition: dist_utils.hpp:2
ppl::expr::dist::details::cauchy_valid_dim
Definition: cauchy.hpp:41