2 #include <fastad_bits/reverse/core/dot.hpp>
// Diagnostic text for the operand-shape check on the dot-product node.
// NOTE(review): presumably this is the message argument of the
// static_assert(is_mat_v<lhs_t> && (is_vec_v<rhs_t> || is_mat_v<rhs_t>), ...)
// in DotNode; that argument is not visible here. Kept as a macro because a
// C++17 static_assert message must be a string literal.
// The macro is #undef'd at the end of this file, so it does not leak
// past it.
#define PPL_DOT_MAT_VEC                                         \
    "Dot product is only supported for matrix as lhs argument " \
    "and a matrix or vector as rhs argument. "
23 template <
class LHSVarExprType
24 ,
class RHSVarExprType>
28 using lhs_t = LHSVarExprType;
29 using rhs_t = RHSVarExprType;
31 static_assert(util::is_var_expr_v<lhs_t>);
32 static_assert(util::is_var_expr_v<rhs_t>);
33 static_assert(util::is_mat_v<lhs_t> &&
34 (util::is_vec_v<rhs_t> || util::is_mat_v<rhs_t>),
42 using shape_t = ad::core::details::dot_shape_t<lhs_t, rhs_t>;
44 lhs_t::has_param || rhs_t::has_param;
55 auto eval() {
return lhs_.eval() * rhs_.eval(); }
56 auto get() {
return lhs_.get() * rhs_.get(); }
58 size_t rows()
const {
return lhs_.rows(); }
59 size_t cols()
const {
return rhs_.cols(); }
61 template <
class PtrPackType>
62 auto ad(
const PtrPackType& pack)
const
68 template <
class PtrPackType>
69 void bind(
const PtrPackType& pack)
71 if constexpr (lhs_t::has_param) {
74 if constexpr (rhs_t::has_param) {
80 lhs_.activate_refcnt();
81 rhs_.activate_refcnt();
95 template <
class LHSVarExprType
96 ,
class RHSVarExprType
97 ,
class = std::enable_if_t<
98 (util::is_var_v<LHSVarExprType> ||
99 util::is_var_expr_v<LHSVarExprType>) &&
100 (util::is_var_v<RHSVarExprType> ||
101 util::is_var_expr_v<RHSVarExprType>)
103 inline constexpr
auto dot(
const LHSVarExprType& lhs,
104 const RHSVarExprType& rhs)
109 lhs_t wrap_lhs_expr = lhs;
110 rhs_t wrap_rhs_expr = rhs;
113 wrap_lhs_expr, wrap_rhs_expr);
118 #undef PPL_DOT_MAT_VEC