autoppl v0.8
A C++ template library for probabilistic programming
wishart.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/reverse/stat/wishart.hpp>
6 #include <autoppl/math/math.hpp>
7 
// Diagnostic text reported when Wishart's parameter shapes are invalid
// (scale matrix not mat/selfadjmat, or n not a scalar). It is a macro rather
// than a constexpr variable so it can be spliced into a static_assert message,
// which C++17 requires to be a string literal. It is #undef'd at the end of
// this header so it does not leak to includers.
#define PPL_WISHART_PARAM_SHAPE "Wishart distribution scale matrix must have shape mat or selfadjmat and n must be a scalar. "
11 
12 namespace ppl {
13 namespace expr {
14 namespace dist {
15 
24 template <class VType
25  , class NType>
26 struct Wishart:
27  util::DistExprBase<Wishart<VType, NType>>
28 {
29 private:
30  using v_t = VType;
31  using n_t = NType;
32 
33  static_assert(util::is_var_expr_v<v_t>);
34  static_assert(util::is_var_expr_v<n_t>);
35  static_assert(util::is_mat_v<v_t> &&
36  util::is_scl_v<n_t>,
39  );
40 
41 public:
44  using typename base_t::dist_value_t;
45 
46  Wishart(const v_t& v,
47  const n_t& n)
48  : v_{v}, n_{n}
49  {}
50 
51  //template <class XType>
52  //dist_value_t pdf(const XType& x) const
53  //{
54  // static_assert(util::is_var_v<XType>);
55  // static_assert(util::is_mat_v<XType>,
56  // PPL_DIST_SHAPE_MISMATCH);
57  // return math::wishart_pdf(x.get(), v_.eval(), n_.eval());
58  //}
59 
60  template <class XType>
61  dist_value_t log_pdf(const XType& x) const
62  {
63  static_assert(util::is_dist_assignable_v<XType>);
64  static_assert(util::is_mat_v<XType>,
66  return math::wishart_log_pdf(x.get(), v_.eval(), n_.eval());
67  }
68 
69  template <class XType
70  , class PtrPackType>
71  auto ad_log_pdf(const XType& x,
72  const PtrPackType& pack) const
73  {
74  static_assert(util::is_dist_assignable_v<XType>);
75  static_assert(util::is_mat_v<XType>,
77  return ad::wishart_adj_log_pdf(x.ad(pack),
78  v_.ad(pack),
79  n_.ad(pack));
80  }
81 
82  template <class PtrPackType>
83  void bind(const PtrPackType& pack)
84  {
85  static_cast<void>(pack);
86  if constexpr (v_t::has_param) {
87  v_.bind(pack);
88  }
89  if constexpr (n_t::has_param) {
90  n_.bind(pack);
91  }
92  }
93 
94  void activate_refcnt() const
95  {
96  v_.activate_refcnt();
97  n_.activate_refcnt();
98  }
99 
100  template <class XType, class GenType>
101  bool prune(XType&, GenType&) const { return false; }
102 
103 private:
104  v_t v_;
105  n_t n_;
106 };
107 
108 } // namespace dist
109 } // namespace expr
110 
116 template <class VType, class NType
117  , class = std::enable_if_t<
118  util::is_valid_dist_param_v<VType> &&
119  util::is_valid_dist_param_v<NType>
120  > >
121 inline constexpr auto wishart(const VType& v_expr,
122  const NType& n_expr)
123 {
126 
127  v_t wrap_v_expr = v_expr;
128  n_t wrap_n_expr = n_expr;
129 
130  return expr::dist::Wishart(wrap_v_expr, wrap_n_expr);
131 }
132 
133 } // namespace ppl
134 
135 #undef PPL_WISHART_PARAM_SHAPE
ppl::expr::dist::Wishart::Wishart
Wishart(const v_t &v, const n_t &n)
Definition: wishart.hpp:46
ppl::expr::dist::Wishart::activate_refcnt
void activate_refcnt() const
Definition: wishart.hpp:94
ppl::util::DistExprBase
Definition: dist_expr_traits.hpp:24
ppl::expr::dist::Wishart::bind
void bind(const PtrPackType &pack)
Definition: wishart.hpp:83
ppl::expr::dist::Wishart::prune
bool prune(XType &, GenType &) const
Definition: wishart.hpp:101
PPL_WISHART_PARAM_SHAPE
#define PPL_WISHART_PARAM_SHAPE
Definition: wishart.hpp:8
ppl::wishart
constexpr auto wishart(const VType &v_expr, const NType &n_expr)
Definition: wishart.hpp:121
ppl::expr::dist::Wishart::log_pdf
dist_value_t log_pdf(const XType &x) const
Definition: wishart.hpp:61
ppl::math::wishart_log_pdf
dist_value_t wishart_log_pdf(const Eigen::MatrixBase< XType > &x, const Eigen::MatrixBase< VType > &v, const NType &n)
Definition: density.hpp:608
ppl::util::DistExprBase::dist_value_t
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
ppl::util::cont_param_t
double cont_param_t
Definition: dist_expr_traits.hpp:14
dist_utils.hpp
ppl::expr::dist::Wishart
Definition: wishart.hpp:28
ppl
Definition: bounded.hpp:11
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::expr::dist::Wishart::ad_log_pdf
auto ad_log_pdf(const XType &x, const PtrPackType &pack) const
Definition: wishart.hpp:71
density.hpp
traits.hpp
math.hpp
ppl::expr::dist::Wishart::dist_value_t
util::dist_value_t dist_value_t
Definition: dist_expr_traits.hpp:26
PPL_DIST_SHAPE_MISMATCH
#define PPL_DIST_SHAPE_MISMATCH
Definition: dist_utils.hpp:2
ppl::expr::dist::Wishart::value_t
util::cont_param_t value_t
Definition: wishart.hpp:42