autoppl
v0.8
A C++ template library for probabilistic programming
|
Go to the documentation of this file.
7 #include <fastad_bits/reverse/core/var_view.hpp>
9 #define PPL_PARAMVIEW_SHAPE_UNSUPPORTED \
10 "Unsupported shape for ParamView. "
11 #define PPL_PARAM_SHAPE_UNSUPPORTED \
12 "Unsupported shape for Param. "
41 template <
class ValueType
77 transformer_.inv_transform(i_pack_->
refcnt);
78 return transformer_.get_c();
88 template <
class TPValPtrType
95 auto curr_pack = pack;
100 return transformer_.inv_transform_ad(curr_pack, pack,
104 template <
class TPValPtrType
105 ,
class TPAdjPtrType>
111 auto curr_pack = pack;
116 return transformer_.logj_inv_transform_ad(curr_pack, pack);
124 template <
class GenType,
class ContDist>
125 void init(GenType& gen, ContDist& dist)
126 { transformer_.init(gen, dist); }
139 pack.
uc_offset += transformer_.bind_size_uc();
140 pack.
c_offset += transformer_.bind_size_c();
141 pack.
v_offset += transformer_.bind_size_v();
154 transformer_.activate_refcnt(i_pack_->
refcnt);
163 template <
class PtrPackType>
164 void bind(
const PtrPackType& pack)
166 static_cast<void>(pack);
167 if constexpr (std::is_convertible_v<typename PtrPackType::uc_val_ptr_t, value_t*> &&
168 std::is_convertible_v<typename PtrPackType::c_val_ptr_t, value_t*>) {
171 size_t* vp = pack.v_val;
172 util::PtrPack curr_pack(ucp,
nullptr,
nullptr,
nullptr, cp, vp);
176 transformer_.bind(curr_pack, pack);
181 const var_t&
get()
const {
return transformer_.get_c(); }
182 constexpr
size_t size()
const {
return transformer_.size_c(); }
183 constexpr
size_t rows()
const {
return transformer_.rows_c(); }
184 constexpr
size_t cols()
const {
return transformer_.cols_c(); }
190 constexpr
size_t size_uc()
const {
return transformer_.size_uc(); }
191 constexpr
size_t size_c()
const {
return transformer_.size_c(); }
208 template <
class ValueType
210 ,
class ConstraintType = expr::constraint::Unconstrained>
213 static_assert(util::is_shape_v<ShapeType>,
217 template <
class ValueType,
class Constra
intType>
219 ParamView<ValueType, ppl::scl, ConstraintType>,
228 :
base_t(&i_pack_, 1, 1, c)
235 template <
class ValueType,
class Constra
intType>
237 ParamView<ValueType, ppl::vec, ConstraintType>,
246 :
base_t(&i_pack_, n, 1, c)
253 template <
class ValueType,
class Constra
intType>
255 ParamView<ValueType, ppl::mat, ConstraintType>,
272 template <
class ValueType
273 ,
class ShapeType =
scl
274 ,
class ConstraintType = expr::constraint::Unconstrained>
277 const ConstraintType& c = ConstraintType())
282 template <
class ValueType
283 ,
class ShapeType =
scl
284 ,
class ConstraintType = expr::constraint::Unconstrained>
286 const ConstraintType& c = ConstraintType())
291 template <
class ValueType
292 ,
class ShapeType =
scl
293 ,
class ConstraintType = expr::constraint::Unconstrained>
294 constexpr
inline auto make_param(
const ConstraintType& c = ConstraintType())
301 #undef PPL_PARAMVIEW_SHAPE_UNSUPPORTED
302 #undef PPL_PARAM_SHAPE_UNSUPPORTED
Param(size_t n, size_t=1, constraint_t c=constraint_t())
Definition: param.hpp:243
index_t c_offset
Definition: offset_pack.hpp:12
Definition: var_traits.hpp:16
void activate_refcnt() const
Definition: param.hpp:152
Definition: param.hpp:212
constexpr size_t size_uc() const
Definition: param.hpp:190
const void * id_t
Definition: param.hpp:52
util::var_t< value_t, shape_t > var_t
Definition: param.hpp:51
CValPtrType c_val
Definition: ptr_pack.hpp:37
constexpr auto make_param(size_t rows, size_t cols, const ConstraintType &c=ConstraintType())
Definition: param.hpp:275
auto & offset()
Definition: param.hpp:188
void traverse(Func &&) const
Definition: param.hpp:65
ppl::vec shape_t
Definition: param.hpp:49
util::OffsetPack off_pack
Definition: param.hpp:20
const var_t & eval()
Definition: param.hpp:76
#define PPL_PARAM_SHAPE_UNSUPPORTED
Definition: param.hpp:11
ad::vec vec
Definition: shape_traits.hpp:17
Definition: unconstrained.hpp:11
void init(GenType &gen, ContDist &dist)
Definition: param.hpp:125
auto offset() const
Definition: param.hpp:189
index_t v_offset
Definition: offset_pack.hpp:13
constexpr size_t size_c() const
Definition: param.hpp:191
auto ad(const util::PtrPack< value_t *, value_t *, TPValPtrType, TPAdjPtrType, value_t * > &pack) const
Definition: param.hpp:90
constexpr size_t cols(const T &x)
Definition: value.hpp:42
auto logj_ad(const util::PtrPack< value_t *, value_t *, TPValPtrType, TPAdjPtrType, value_t * > &pack) const
Definition: param.hpp:106
ad::scl scl
Definition: shape_traits.hpp:16
static constexpr bool has_param
Definition: param.hpp:53
ParamView(details::ParamInfoPack *i_pack, size_t rows=1, size_t cols=1, constraint_t c=constraint_t()) noexcept
Definition: param.hpp:55
constexpr size_t cols() const
Definition: param.hpp:184
UCValPtrType uc_val
Definition: ptr_pack.hpp:33
typename details::var< V, T >::type var_t
Definition: shape_traits.hpp:132
void inv_eval()
Definition: param.hpp:86
void activate(util::OffsetPack &pack) const
Definition: param.hpp:136
Param(size_t=1, size_t=1, constraint_t c=constraint_t()) noexcept
Definition: param.hpp:225
const var_t & get() const
Definition: param.hpp:181
constexpr size_t rows() const
Definition: param.hpp:183
index_t uc_offset
Definition: offset_pack.hpp:11
constexpr size_t size() const
Definition: param.hpp:182
ConstraintType constraint_t
Definition: param.hpp:50
Definition: bounded.hpp:11
Definition: var_expr_traits.hpp:20
ValueType value_t
Definition: param.hpp:48
ad::mat mat
Definition: shape_traits.hpp:18
Param(size_t rows, size_t cols, constraint_t c=constraint_t())
Definition: param.hpp:261
var_t & get()
Definition: param.hpp:180
constexpr size_t rows(const T &x)
Definition: value.hpp:32
Definition: offset_pack.hpp:9
size_t refcnt
Definition: param.hpp:19
Definition: ptr_pack.hpp:14
id_t id() const
Definition: param.hpp:185
void bind(const PtrPackType &pack)
Definition: param.hpp:164
size_t * v_val
Definition: ptr_pack.hpp:38