// NOTE(review): this listing is an excerpt. Every line is prefixed with its
// original line number, and several original lines (the function signature,
// the remaining template parameters, closing braces, and parts of some bodies)
// are not shown. The comments below describe only what the visible lines do;
// anything that depends on a hidden line is marked as an assumption.
//
// Visible flow:
//   1. Call var.init(gen, cont_dist) on every parameter variable of the model.
//   2. Traverse the model again with a counting visitor (n_param_entities).
//   3. While program.log_pdf() equals math::neg_inf<double>, and for at most
//      n_param_entities iterations, call dist.prune(var, gen) on the
//      parameter variables and re-evaluate the log-pdf.
//   4. Assert that the final log-pdf is not negative infinity.
20 template <
class ProgramType
// The model is taken by reference from the program; both traversals below
// operate on this same object.
27 auto& model = program.get_model();
// Uniform real distribution over [-radius, radius); the value type is deduced
// from `radius` (declared outside the visible lines).
30 std::uniform_real_distribution cont_dist(-radius, radius);
// Visitor: for nodes whose variable type satisfies util::is_param_v, call
// var.init(gen, cont_dist). The check is `if constexpr`, so the call is only
// instantiated for parameter variable types.
// NOTE(review): presumably init() draws the starting values from cont_dist
// using gen — confirm against the variable type's init().
31 auto init_params__ = [&](
auto& eq_node) {
32 auto& var = eq_node.get_variable();
33 using var_t = std::decay_t<decltype(var)>;
34 if constexpr (util::is_param_v<var_t>) {
35 var.init(gen, cont_dist);
38 model.traverse(init_params__);
// Counter that bounds the number of pruning rounds in the loop below.
43 int n_param_entities = 0;
// Visitor taking the node by const reference. The body of the `if constexpr`
// branch is not visible in this excerpt.
// NOTE(review): presumably it adds to n_param_entities for each parameter
// variable — confirm against the full source.
44 auto n_param_entities__ = [&](
const auto& eq_node) {
45 auto& var = eq_node.get_variable();
46 using var_t = std::decay_t<decltype(var)>;
47 if constexpr (util::is_param_v<var_t>) {
51 model.traverse(n_param_entities__);
// Log-pdf of the program after the initialization pass above.
53 auto log_pdf = program.log_pdf();
// Keep going only while the log-pdf compares equal to math::neg_inf<double>,
// and never for more than n_param_entities iterations.
54 for (
int i = 0; i < n_param_entities && log_pdf == math::neg_inf<double>; ++i) {
// Records whether any prune() call in this traversal returned true.
// NOTE(review): how `modified` is consumed afterwards is not visible here.
59 bool modified =
false;
// Visitor: for parameter variables, ask the node's distribution to prune the
// variable (passing gen) and OR the returned bool into `modified`.
// NOTE(review): presumably prune() moves the variable to a value the
// distribution accepts — confirm against the distribution's prune().
60 auto prune_params__ = [&](
auto& eq_node) {
61 auto& var = eq_node.get_variable();
62 const auto& dist = eq_node.get_distribution();
63 using var_t = std::decay_t<decltype(var)>;
64 if constexpr (util::is_param_v<var_t>) {
65 bool curr_modified = dist.prune(var, gen);
69 modified = modified || curr_modified;
72 model.traverse(prune_params__);
// Re-evaluate the log-pdf so the loop condition sees the pruned values.
78 log_pdf = program.log_pdf();
// Checked with assert only, so this is not enforced when NDEBUG is defined.
81 assert(log_pdf != math::neg_inf<double>);