autoppl
v0.8
A C++ template library for probabilistic programming
|
Go to the documentation of this file.
10 template <
class LHSType
12 constexpr
static LHSType&
eval(LHSType& lhs,
const RHSType& rhs)
15 template <
class LHSType
17 constexpr
static auto eval(
const LHSType& lhs,
const RHSType& rhs)
23 template <
class LHSType
25 constexpr
static LHSType&
eval(LHSType& lhs,
const RHSType& rhs)
26 {
return lhs += rhs; }
28 template <
class LHSType
30 constexpr
static auto eval(
const LHSType& lhs,
const RHSType& rhs)
31 {
return lhs += rhs; }
36 template <
class LHSType
38 constexpr
static LHSType&
eval(LHSType& lhs,
const RHSType& rhs)
39 {
return lhs -= rhs; }
41 template <
class LHSType
43 constexpr
static auto eval(
const LHSType& lhs,
const RHSType& rhs)
44 {
return lhs -= rhs; }
49 template <
class LHSType
51 constexpr
static LHSType&
eval(LHSType& lhs,
const RHSType& rhs)
52 {
return lhs *= rhs; }
54 template <
class LHSType
56 constexpr
static auto eval(
const LHSType& lhs,
const RHSType& rhs)
57 {
return lhs *= rhs; }
62 template <
class LHSType
64 constexpr
static LHSType&
eval(LHSType& lhs,
const RHSType& rhs)
65 {
return lhs /= rhs; }
67 template <
class LHSType
69 constexpr
static auto eval(
const LHSType& lhs,
const RHSType& rhs)
70 {
return lhs /= rhs; }
74 ,
class TParamViewType
81 using tp_view_t = TParamViewType;
82 using var_expr_t = VarExprType;
84 static_assert(util::is_tparam_v<tp_view_t>);
85 static_assert(util::is_var_expr_v<var_expr_t>);
87 static_assert(std::is_same_v<
91 static_assert(util::is_scl_v<var_expr_t> ||
102 const var_expr_t& expr)
103 : tp_view_{tp_view}, expr_{expr}
106 template <
class Func>
109 static_cast<void>(f);
110 if constexpr (std::is_same_v<Op, Eq>) {
115 template <
class Func>
118 static_cast<void>(f);
119 if constexpr (std::is_same_v<Op, Eq>) {
124 auto get()
const {
return tp_view_.get(); }
127 if constexpr (util::is_scl_v<tp_view_t> &&
128 util::is_scl_v<var_expr_t>) {
129 return op_t::eval(tp_view_.get(), expr_.eval());
130 }
else if constexpr (!util::is_scl_v<tp_view_t> &&
131 util::is_scl_v<var_expr_t>) {
132 auto tpa = tp_view_.get().array();
133 return op_t::eval(tpa, expr_.eval());
135 auto tpa = tp_view_.get().array();
136 return op_t::eval(tpa, expr_.eval().array());
140 constexpr
size_t size()
const {
return tp_view_.size(); }
141 constexpr
size_t rows()
const {
return tp_view_.rows(); }
142 constexpr
size_t cols()
const {
return tp_view_.cols(); }
144 template <
class PtrPackType>
145 auto ad(
const PtrPackType& pack)
const
147 return op_t::eval(tp_view_.ad(pack), expr_.ad(pack));
150 template <
class PtrPackType>
151 void bind(
const PtrPackType& pack)
153 if constexpr (tp_view_t::has_param) {
156 if constexpr (var_expr_t::has_param) {
162 tp_view_.activate_refcnt();
163 expr_.activate_refcnt();
177 ,
class TParamViewType
180 const VarExprType& expr)
185 tp_view_t wrap_tp_view = tp_view;
186 expr_t wrap_expr = expr;
void traverse(Func &&f)
Definition: op_eq.hpp:107
constexpr static auto eval(const LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:17
constexpr auto opeq_helper(const TParamViewType &tp_view, const VarExprType &expr)
Definition: op_eq.hpp:179
auto eval()
Definition: op_eq.hpp:126
void traverse(Func &&f) const
Definition: op_eq.hpp:116
constexpr size_t size() const
Definition: op_eq.hpp:140
constexpr static auto eval(const LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:56
static constexpr bool has_param
Definition: op_eq.hpp:99
auto ad(const PtrPackType &pack) const
Definition: op_eq.hpp:145
constexpr static LHSType & eval(LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:64
OpEqNode(const tp_view_t &tp_view, const var_expr_t &expr)
Definition: op_eq.hpp:101
constexpr static LHSType & eval(LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:25
typename util::shape_traits< tp_view_t >::shape_t shape_t
Definition: op_eq.hpp:98
constexpr static auto eval(const LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:43
auto get() const
Definition: op_eq.hpp:124
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
constexpr size_t rows() const
Definition: op_eq.hpp:141
const auto & get_variable() const
Definition: op_eq.hpp:167
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
constexpr static auto eval(const LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:69
void activate_refcnt() const
Definition: op_eq.hpp:161
Definition: bounded.hpp:11
constexpr static LHSType & eval(LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:51
constexpr static LHSType & eval(LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:12
constexpr size_t cols() const
Definition: op_eq.hpp:142
Definition: var_expr_traits.hpp:20
typename util::var_expr_traits< tp_view_t >::value_t value_t
Definition: op_eq.hpp:97
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
auto & get_variable()
Definition: op_eq.hpp:166
constexpr static LHSType & eval(LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:38
void bind(const PtrPackType &pack)
Definition: op_eq.hpp:151
constexpr static auto eval(const LHSType &lhs, const RHSType &rhs)
Definition: op_eq.hpp:30