5 #include <fastad_bits/util/type_traits.hpp>
10 #define M_PI 3.14159265358979323846
23 2.506628274631000502415765284811045;
25 0.918938533204672741780329736405617;
35 ,
class = std::enable_if_t<
36 std::is_arithmetic_v<XType> &&
37 std::is_arithmetic_v<MeanType> &&
38 std::is_arithmetic_v<SigmaType>
42 const SigmaType& sigma)
44 if (sigma <= 0)
return math::neg_inf<dist_value_t>;
46 return std::exp(-0.5 * z * z) / (sigma *
SQRT_TWO_PI);
53 ,
class = std::enable_if_t<
54 std::is_arithmetic_v<MeanType> &&
55 std::is_arithmetic_v<SigmaType>
59 const SigmaType& sigma)
61 if (sigma <= 0)
return math::neg_inf<dist_value_t>;
62 dist_value_t z_sq = (x.array() - mean).matrix().squaredNorm() / (sigma * sigma);
63 return std::exp(-0.5 * z_sq) / std::pow(sigma *
SQRT_TWO_PI, x.size());
70 ,
class = std::enable_if_t<
71 std::is_arithmetic_v<SigmaType>
74 const Eigen::MatrixBase<MeanType>& mean,
75 const SigmaType& sigma)
77 static_assert(ad::util::is_eigen_vector_v<MeanType>);
78 assert(x.size() == mean.size());
79 if (sigma <= 0)
return math::neg_inf<dist_value_t>;
80 dist_value_t z_sq = (x.array() - mean.array()).matrix().squaredNorm() / (sigma * sigma);
81 return std::exp(-0.5 * z_sq) / std::pow(sigma *
SQRT_TWO_PI, x.size());
88 ,
class = std::enable_if_t<
89 std::is_arithmetic_v<MeanType>
93 const Eigen::MatrixBase<SigmaType>& sigma)
95 if constexpr (ad::util::is_eigen_vector_v<SigmaType>) {
96 assert(x.size() == sigma.size());
97 if ((sigma.array() <= 0).any())
return math::neg_inf<dist_value_t>;
98 dist_value_t z_sq = ((x.array() - mean)/sigma.array()).matrix().squaredNorm();
99 return std::exp(-0.5 * z_sq) / (std::pow(
SQRT_TWO_PI, x.size()) * sigma.array().prod());
101 }
else if constexpr (ad::util::is_eigen_matrix_v<SigmaType>) {
102 assert(x.size() == sigma.rows());
103 Eigen::LLT<Eigen::MatrixXd> llt(sigma);
104 if (llt.info() != Eigen::Success)
return math::neg_inf<dist_value_t>;
105 dist_value_t z_sq = llt.matrixL().solve((x.array() - mean).matrix()).squaredNorm();
107 return std::exp(-0.5 * z_sq) / (std::pow(
SQRT_TWO_PI, x.size()) * det);
112 template <
class XType
116 const Eigen::MatrixBase<MeanType>& mean,
117 const Eigen::MatrixBase<SigmaType>& sigma)
119 static_assert(ad::util::is_eigen_vector_v<MeanType>);
121 if constexpr (ad::util::is_eigen_vector_v<SigmaType>) {
122 assert(x.size() == sigma.size());
123 assert(x.size() == mean.size());
124 if ((sigma.array() <= 0).any())
return math::neg_inf<dist_value_t>;
125 dist_value_t z_sq = ((x.array() - mean.array())/sigma.array()).matrix().squaredNorm();
126 return std::exp(-0.5 * z_sq) / (std::pow(
SQRT_TWO_PI, x.size()) * sigma.array().prod());
129 else if constexpr (ad::util::is_eigen_matrix_v<SigmaType>) {
130 assert(x.size() == sigma.rows());
131 assert(x.size() == mean.size());
132 Eigen::LLT<Eigen::MatrixXd> llt(sigma);
133 if (llt.info() != Eigen::Success)
return math::neg_inf<dist_value_t>;
134 dist_value_t z_sq = llt.matrixL().solve((x.array() - mean.array()).matrix()).squaredNorm();
136 return std::exp(-0.5 * z_sq) / (std::pow(
SQRT_TWO_PI, x.size()) * det);
141 template <
class XType
144 ,
class = std::enable_if_t<
145 std::is_arithmetic_v<XType> &&
146 std::is_arithmetic_v<MeanType> &&
147 std::is_arithmetic_v<SigmaType>
150 const MeanType& mean,
151 const SigmaType& sigma)
153 if (sigma <= 0)
return math::neg_inf<dist_value_t>;
159 template <
class XType
162 ,
class = std::enable_if_t<
163 std::is_arithmetic_v<MeanType> &&
164 std::is_arithmetic_v<SigmaType>
167 const MeanType& mean,
168 const SigmaType& sigma)
170 if (sigma <= 0)
return math::neg_inf<dist_value_t>;
171 dist_value_t z_sq = (x.array() - mean).matrix().squaredNorm() / (sigma * sigma);
176 template <
class XType
179 ,
class = std::enable_if_t<
180 std::is_arithmetic_v<SigmaType>
183 const Eigen::MatrixBase<MeanType>& mean,
184 const SigmaType& sigma)
186 static_assert(ad::util::is_eigen_vector_v<MeanType>);
187 assert(x.size() == mean.size());
188 if (sigma <= 0)
return math::neg_inf<dist_value_t>;
189 dist_value_t z_sq = (x.array() - mean.array()).matrix().squaredNorm() / (sigma * sigma);
194 template <
class XType
197 ,
class = std::enable_if_t<
198 std::is_arithmetic_v<MeanType>
201 const MeanType& mean,
202 const Eigen::MatrixBase<SigmaType>& sigma)
204 if constexpr (ad::util::is_eigen_vector_v<SigmaType>) {
205 assert(x.size() == sigma.size());
206 if ((sigma.array() <= 0).any())
return math::neg_inf<dist_value_t>;
207 dist_value_t z_sq = ((x.array() - mean)/sigma.array()).matrix().squaredNorm();
208 return -0.5 * z_sq - (x.size() *
LOG_SQRT_TWO_PI) - std::log(sigma.array().prod());
211 else if constexpr (ad::util::is_eigen_matrix_v<SigmaType>) {
212 assert(x.size() == sigma.rows());
213 Eigen::LLT<Eigen::MatrixXd> llt(sigma);
214 if (llt.info() != Eigen::Success)
return math::neg_inf<dist_value_t>;
215 dist_value_t z_sq = llt.matrixL().solve((x.array() - mean).matrix()).squaredNorm();
222 template <
class XType
226 const Eigen::MatrixBase<MeanType>& mean,
227 const Eigen::MatrixBase<SigmaType>& sigma)
229 if constexpr (ad::util::is_eigen_vector_v<SigmaType>) {
230 assert(x.size() == sigma.size());
231 assert(x.size() == mean.size());
232 if ((sigma.array() <= 0).any())
return math::neg_inf<dist_value_t>;
233 dist_value_t z_sq = ((x.array() - mean.array())/sigma.array()).matrix().squaredNorm();
234 return -0.5 * z_sq - (x.size() *
LOG_SQRT_TWO_PI) - std::log(sigma.array().prod());
237 else if constexpr (ad::util::is_eigen_matrix_v<SigmaType>) {
238 assert(x.size() == sigma.rows());
239 assert(x.size() == mean.size());
240 Eigen::LLT<Eigen::MatrixXd> llt(sigma);
241 if (llt.info() != Eigen::Success)
return math::neg_inf<dist_value_t>;
242 dist_value_t z_sq = llt.matrixL().solve((x.array() - mean.array()).matrix()).squaredNorm();
253 template <
class XType
256 ,
class = std::enable_if_t<
257 std::is_arithmetic_v<XType> &&
258 std::is_arithmetic_v<LocType> &&
259 std::is_arithmetic_v<ScaleType>
263 const ScaleType& scale)
266 return (scale > 0) ? -std::log(scale + diff * diff / scale) : neg_inf<double>;
270 template <
class XType
273 ,
class = std::enable_if_t<
274 std::is_arithmetic_v<LocType> &&
275 std::is_arithmetic_v<ScaleType>
279 const ScaleType& scale)
281 bool cond = scale > 0.;
282 auto diff = x.array() - loc;
283 return cond ? -(scale + (1./scale) * diff * diff).log().sum() :
288 template <
class XType
291 ,
class = std::enable_if_t<
292 std::is_arithmetic_v<ScaleType>
295 const Eigen::MatrixBase<LocType>& loc,
296 const ScaleType& scale)
298 bool cond = scale > 0.;
299 auto diff = x.array() - loc.array();
300 return cond ? -(scale + (1./scale) * diff * diff).log().sum() : neg_inf<double>;
304 template <
class XType
307 ,
class = std::enable_if_t<
308 std::is_arithmetic_v<LocType>
312 const Eigen::MatrixBase<ScaleType>& scale)
314 bool cond = (scale.array() > 0.).all();
315 auto diff = x.array() - loc;
316 auto gamma = scale.array();
317 return cond ? -(gamma + (1./gamma) * diff * diff).log().sum() : neg_inf<double>;
321 template <
class XType
325 const Eigen::MatrixBase<LocType>& loc,
326 const Eigen::MatrixBase<ScaleType>& scale)
328 bool cond = (scale.array() > 0.).all();
329 auto diff = x.array() - loc.array();
330 auto gamma = scale.array();
331 return cond ? -(gamma + (1./gamma) * diff * diff).log().sum() : neg_inf<double>;
339 template <
class XType
342 ,
class = std::enable_if_t<
343 std::is_arithmetic_v<XType> &&
344 std::is_arithmetic_v<MinType> &&
345 std::is_arithmetic_v<MaxType>
351 return (min < x && x < max) ? 1. / (max - min) : 0;
355 template <
class XType
358 ,
class = std::enable_if_t<
359 std::is_arithmetic_v<MinType> &&
360 std::is_arithmetic_v<MaxType>
366 bool cond = (min < x.array()).all() && (x.array() < max).all();
367 return cond ? std::pow(1./(max-min), x.size()) : 0;
371 template <
class XType
374 ,
class = std::enable_if_t<
375 std::is_arithmetic_v<MaxType>
378 const Eigen::MatrixBase<MinType>& min,
381 bool cond = (min.array() < x.array()).all() && (x.array() < max).all();
382 return cond ? (1./(max-min.array())).prod() : 0;
386 template <
class XType
389 ,
class = std::enable_if_t<
390 std::is_arithmetic_v<MinType>
394 const Eigen::MatrixBase<MaxType>& max)
396 bool cond = (min < x.array()).all() && (x.array() < max.array()).all();
397 return cond ? (1./(max.array()-min)).prod() : 0;
401 template <
class XType
405 const Eigen::MatrixBase<MinType>& min,
406 const Eigen::MatrixBase<MaxType>& max)
408 bool cond = (min.array() < x.array()).all() && (x.array() < max.array()).all();
409 return cond ? (1./(max.array()-min.array())).prod() : 0;
413 template <
class XType
416 ,
class = std::enable_if_t<
417 std::is_arithmetic_v<XType> &&
418 std::is_arithmetic_v<MinType> &&
419 std::is_arithmetic_v<MaxType>
425 return (min < x && x < max) ? -std::log(max - min) : neg_inf<double>;
429 template <
class XType
432 ,
class = std::enable_if_t<
433 std::is_arithmetic_v<MinType> &&
434 std::is_arithmetic_v<MaxType>
440 bool cond = (min < x.array()).all() && (x.array() < max).all();
442 static_cast<dist_value_t>(x.size())*(-std::log(max-min)) :
447 template <
class XType
450 ,
class = std::enable_if_t<
451 std::is_arithmetic_v<MaxType>
454 const Eigen::MatrixBase<MinType>& min,
457 bool cond = (min.array() < x.array()).all() && (x.array() < max).all();
458 return cond ? -(max-min.array()).log().sum() : neg_inf<double>;
462 template <
class XType
465 ,
class = std::enable_if_t<
466 std::is_arithmetic_v<MinType>
470 const Eigen::MatrixBase<MaxType>& max)
472 bool cond = (min < x.array()).all() && (x.array() < max.array()).all();
473 return cond ? -(max.array()-min).log().sum() : neg_inf<double>;
477 template <
class XType
481 const Eigen::MatrixBase<MinType>& min,
482 const Eigen::MatrixBase<MaxType>& max)
484 bool cond = (min.array() < x.array()).all() && (x.array() < max.array()).all();
485 return cond ? -(max.array()-min.array()).log().sum() : neg_inf<double>;
500 template <
class XType
502 ,
class = std::enable_if_t<
503 std::is_arithmetic_v<XType> &&
504 std::is_arithmetic_v<PType>
509 if (p <= 0)
return (x == 0) + 0.;
510 else if (p >= 1)
return (x == 1) + 0.;
512 if (x == 1)
return p;
513 else if (x == 0)
return 1. - p;
518 template <
class XType
520 ,
class = std::enable_if_t<
521 std::is_arithmetic_v<PType>
527 for (
int i = 0; i < x.size(); ++i) {
534 template <
class XType
537 const Eigen::MatrixBase<PType>& p)
539 assert(x.size() == p.size());
541 for (
int i = 0; i < x.size(); ++i) {
548 template <
class XType
550 ,
class = std::enable_if_t<
551 std::is_arithmetic_v<XType> &&
552 std::is_arithmetic_v<PType>
558 if (x == 0)
return 0.;
559 else return neg_inf<PType>;
562 if (x == 1)
return 0.;
563 else return neg_inf<PType>;
566 if (x == 1)
return std::log(p);
567 else if (x == 0)
return std::log(1. - p);
568 else return neg_inf<PType>;
572 template <
class XType
574 ,
class = std::enable_if_t<
575 std::is_arithmetic_v<PType>
581 for (
int i = 0; i < x.size(); ++i) {
588 template <
class XType
591 const Eigen::MatrixBase<PType>& p)
593 assert(x.size() == p.size());
595 for (
int i = 0; i < x.size(); ++i) {
605 template <
class XType
609 const Eigen::MatrixBase<VType>& v,
612 Eigen::LLT<Eigen::MatrixXd> x_llt(x);
613 Eigen::LLT<Eigen::MatrixXd> v_llt(v);
614 auto log_det_x = std::log(x_llt.matrixL().determinant());
615 auto log_det_v = std::log(v_llt.matrixL().determinant());
616 auto tr = v_llt.solve(x).trace();
618 return (n - p - 1.) * log_det_x - 0.5 * tr - n * log_det_v;