autoppl
v0.8
A C++ template library for probabilistic programming
|
Go to the documentation of this file.
2 #include <fastad_bits/reverse/core/binary.hpp>
5 #define PPL_BINOP_EQUAL_FIXED_SIZE \
6 "If both lhs and rhs are of fixed size, " \
7 "then they must have the same size. "
8 #define PPL_BINOP_NO_MAT_SUPPORT \
9 "Binary operations with matrices are not supported yet. "
31 template <
class BinaryOp
32 ,
class LHSVarExprType
33 ,
class RHSVarExprType>
38 using lhs_t = LHSVarExprType;
39 using rhs_t = RHSVarExprType;
41 static_assert(util::is_var_expr_v<lhs_t>);
42 static_assert(util::is_var_expr_v<rhs_t>);
54 lhs_t::has_param || rhs_t::has_param;
58 : lhs_{lhs}, rhs_{rhs}
65 auto&& lhs = lhs_.get();
66 auto&& rhs = rhs_.get();
67 return eval_helper(lhs, rhs);
72 auto&& lhs = lhs_.eval();
73 auto&& rhs = rhs_.eval();
74 return eval_helper(lhs, rhs);
77 size_t size()
const {
return std::max(lhs_.size(), rhs_.size()); }
78 size_t rows()
const {
return std::max(lhs_.rows(), rhs_.rows()); }
79 size_t cols()
const {
return std::max(lhs_.cols(), rhs_.cols()); }
81 template <
class PtrPackType>
82 auto ad(
const PtrPackType& pack)
const
84 return BinaryOp::fmap(lhs_.ad(pack),
88 template <
class PtrPackType>
89 void bind(
const PtrPackType& pack)
91 if constexpr (lhs_t::has_param) {
94 if constexpr (rhs_t::has_param) {
100 lhs_.activate_refcnt();
101 rhs_.activate_refcnt();
106 template <
class LHSType,
class RHSType>
107 auto eval_helper(
const LHSType& lhs,
108 const RHSType& rhs)
const {
109 if constexpr (util::is_scl_v<lhs_t> &&
110 util::is_scl_v<rhs_t>) {
111 return BinaryOp::fmap(lhs, rhs);
112 }
else if constexpr (util::is_scl_v<lhs_t>) {
113 return BinaryOp::fmap(lhs, rhs.array()).matrix();
114 }
else if constexpr (util::is_scl_v<rhs_t>) {
115 return BinaryOp::fmap(lhs.array(), rhs).matrix();
117 return BinaryOp::fmap(lhs.array(), rhs.array()).matrix();
127 template <
class Op,
class LHSType,
class RHSType>
135 lhs_t wrap_lhs_expr = lhs;
136 rhs_t wrap_rhs_expr = rhs;
140 return binary_t(wrap_lhs_expr, wrap_rhs_expr);
148 #undef PPL_BINOP_EQUAL_FIXED_SIZE
149 #undef PPL_BINOP_NO_MAT_SUPPORT
static constexpr bool has_param
Definition: binary.hpp:53
size_t cols() const
Definition: binary.hpp:79
size_t rows() const
Definition: binary.hpp:78
auto ad(const PtrPackType &pack) const
Definition: binary.hpp:82
constexpr auto operator_helper(const LHSType &lhs, const RHSType &rhs)
Definition: binary.hpp:128
void traverse(Func &&) const
Definition: binary.hpp:62
ad::util::max_shape_t< typename util::shape_traits< lhs_t >::shape_t, typename util::shape_traits< rhs_t >::shape_t > shape_t
Definition: binary.hpp:52
Definition: binary.hpp:36
auto get() const
Definition: binary.hpp:64
void activate_refcnt() const
Definition: binary.hpp:99
void bind(const PtrPackType &pack)
Definition: binary.hpp:89
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
std::common_type_t< typename util::var_expr_traits< lhs_t >::value_t, typename util::var_expr_traits< rhs_t >::value_t > value_t
Definition: binary.hpp:48
Definition: bounded.hpp:11
Definition: var_expr_traits.hpp:20
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
BinaryNode(const lhs_t &lhs, const rhs_t &rhs)
Definition: binary.hpp:56
size_t size() const
Definition: binary.hpp:77
auto eval()
Definition: binary.hpp:70