autoppl  v0.8
A C++ template library for probabilistic programming
for_each.hpp
Go to the documentation of this file.
1 #pragma once
2 #include <fastad_bits/reverse/core/for_each.hpp>
4 
5 namespace ppl {
6 namespace expr {
7 namespace var {
8 
16 template <class VecExprType>
17 struct ForEachNode:
18  util::VarExprBase<ForEachNode<VecExprType>>
19 {
20 private:
21  using vec_expr_t = VecExprType;
22  using elt_t = std::decay_t<typename VecExprType::value_type>;
23 
24  static_assert(util::is_var_expr_v<elt_t>);
25 
26 public:
29  static constexpr bool has_param =
31 
32  ForEachNode(const vec_expr_t& vec_expr)
33  : vec_expr_{vec_expr}
34  {}
35 
36  template <class Func>
37  void traverse(Func&& f)
38  {
39  for (auto& expr: vec_expr_) expr.traverse(f);
40  }
41 
42  template <class Func>
43  void traverse(Func&& f) const
44  {
45  for (const auto& expr: vec_expr_) expr.traverse(f);
46  }
47 
48  auto get() const {
49  assert(!vec_expr_.empty());
50  return vec_expr_.back().get();
51  }
52 
53  auto eval() {
54  for (auto& expr : vec_expr_) expr.eval();
55  return get();
56  }
57 
58  constexpr size_t size() const { return vec_expr_.empty() ? 0 : vec_expr_[0].size(); }
59  constexpr size_t rows() const { return vec_expr_.empty() ? 0 : vec_expr_[0].rows(); }
60  constexpr size_t cols() const { return vec_expr_.empty() ? 0 : vec_expr_[0].cols(); }
61 
62  template <class PtrPackType>
63  auto ad(const PtrPackType& pack) const
64  {
65  return ad::for_each(vec_expr_.begin(),
66  vec_expr_.end(),
67  [&](const auto& expr) {
68  return expr.ad(pack);
69  });
70  }
71 
72  template <class PtrPackType>
73  void bind(const PtrPackType& pack)
74  {
75  if constexpr (elt_t::has_param) {
76  for (auto& expr : vec_expr_) expr.bind(pack);
77  }
78  }
79 
80  void activate_refcnt() const {
81  for (const auto& expr : vec_expr_) expr.activate_refcnt();
82  }
83 
84 private:
85  vec_expr_t vec_expr_;
86 };
87 
88 } // namespace var
89 } // namespace expr
90 
91 template <class Iter
92  , class F>
93 inline constexpr auto for_each(Iter begin,
94  Iter end,
95  F f)
96 {
97  using iter_elt_t = std::decay_t<typename std::iterator_traits<Iter>::value_type>;
98  using ret_t = std::invoke_result_t<F, iter_elt_t>;
99  using expr_t = util::convert_to_param_t<ret_t>;
100  std::vector<expr_t> exprs;
101  exprs.reserve(std::distance(begin, end));
102 
103  std::for_each(begin, end, [&](auto&& x) { exprs.emplace_back(f(x)); });
104 
106 }
107 
108 } // namespace ppl
ppl::expr::var::ForEachNode::has_param
static constexpr bool has_param
Definition: for_each.hpp:29
ppl::util::var_expr_traits
Definition: var_expr_traits.hpp:28
ppl::expr::var::ForEachNode::ForEachNode
ForEachNode(const vec_expr_t &vec_expr)
Definition: for_each.hpp:32
ppl::for_each
constexpr auto for_each(Iter begin, Iter end, F f)
Definition: for_each.hpp:93
ppl::expr::var::ForEachNode
Definition: for_each.hpp:19
ppl::expr::var::ForEachNode::ad
auto ad(const PtrPackType &pack) const
Definition: for_each.hpp:63
ppl::expr::var::ForEachNode::traverse
void traverse(Func &&f) const
Definition: for_each.hpp:43
ppl::expr::var::ForEachNode::eval
auto eval()
Definition: for_each.hpp:53
ppl::expr::var::ForEachNode::get
auto get() const
Definition: for_each.hpp:48
ppl::util::var_expr_traits::value_t
typename VarExprType::value_t value_t
Definition: var_expr_traits.hpp:29
ppl::expr::var::ForEachNode::rows
constexpr size_t rows() const
Definition: for_each.hpp:59
ppl::expr::var::ForEachNode::value_t
typename util::var_expr_traits< elt_t >::value_t value_t
Definition: for_each.hpp:27
ppl::expr::var::ForEachNode::activate_refcnt
void activate_refcnt() const
Definition: for_each.hpp:80
ppl::util::shape_traits
ad::util::shape_traits< T > shape_traits
Definition: shape_traits.hpp:23
ppl
Definition: bounded.hpp:11
ppl::util::VarExprBase
Definition: var_expr_traits.hpp:20
ppl::expr::var::ForEachNode::shape_t
typename util::shape_traits< elt_t >::shape_t shape_t
Definition: for_each.hpp:28
ppl::util::convert_to_param_t
typename details::convert_to_param< T >::type convert_to_param_t
Definition: traits.hpp:148
ppl::expr::var::ForEachNode::size
constexpr size_t size() const
Definition: for_each.hpp:58
traits.hpp
ppl::expr::var::ForEachNode::cols
constexpr size_t cols() const
Definition: for_each.hpp:60
ppl::expr::var::ForEachNode::bind
void bind(const PtrPackType &pack)
Definition: for_each.hpp:73
ppl::expr::var::ForEachNode::traverse
void traverse(Func &&f)
Definition: for_each.hpp:37