4#include "interpolation/rbf_incremental_fitter.h"
5#include "point_cloud/distance_filter.h"
6#include "boost/range/irange.hpp"
7#include "common/zip_sort.h"
8#include "fmm/fmm_tree_height.h"
9#include "interpolation/rbf_incremental_fitter.h"
10#include "interpolation/rbf_solver.h"
15#include <unordered_set>
19 rbf_incremental_fitter::rbf_incremental_fitter(
const model& model,
const geometry::points3d& points)
22 n_points_(points.rows()),
23 n_poly_basis_(model.poly_basis_size()),
24 bbox_(geometry::bbox3d::from_points(points)) {}
26 std::pair<std::vector<index_t>, valuesd> rbf_incremental_fitter::fit(
27 const valuesd& values,
double absolute_tolerance,
int max_iter)
const {
28 auto filtering_distance = bbox_.size().mean() / 4.0;
30 auto centers = point_cloud::distance_filter(points_, filtering_distance).filtered_indices();
31 auto n_centers =
static_cast<index_t
>(centers.size());
32 valuesd center_weights = valuesd::Zero(n_centers + n_poly_basis_);
34 std::unique_ptr<rbf_solver> solver;
35 std::unique_ptr<rbf_evaluator<>> res_eval;
36 auto last_tree_height = 0;
39 std::cout <<
"Number of RBF centers: " << n_centers <<
" / " << n_points_ << std::endl;
41 auto tree_height = fmm::fmm_tree_height(n_centers);
42 if (tree_height != last_tree_height) {
43 solver = std::make_unique<rbf_solver>(model_, tree_height, bbox_);
44 res_eval = std::make_unique<rbf_evaluator<>>(model_, tree_height, bbox_);
45 last_tree_height = tree_height;
48 geometry::points3d center_points = points_(centers, Eigen::all);
50 solver->set_points(center_points);
52 solver->solve(values(centers, Eigen::all), absolute_tolerance, max_iter, center_weights);
54 if (n_centers == n_points_) {
60 auto c_centers = complementary_indices(centers);
61 geometry::points3d c_center_points = points_(c_centers, Eigen::all);
63 res_eval->set_source_points(center_points);
64 res_eval->set_weights(center_weights);
66 auto c_values_fit = res_eval->evaluate(c_center_points);
67 valuesd c_values = values(c_centers, Eigen::all);
68 std::vector<double> c_residuals(c_centers.size());
69 valuesd::Map(c_residuals.data(),
static_cast<index_t
>(c_centers.size())) =
70 (c_values_fit - c_values).cwiseAbs();
74 common::zip_sort(c_centers.begin(), c_centers.end(), c_residuals.begin(), c_residuals.end(),
75 [](
const auto& a,
const auto& b) { return a.second < b.second; });
79 auto lb = std::lower_bound(c_residuals.begin(), c_residuals.end(), absolute_tolerance);
80 auto n_points_need_fitting =
static_cast<index_t
>(std::distance(lb, c_residuals.end()));
81 std::cout <<
"Number of points to fit: " << n_points_need_fitting << std::endl;
83 if (n_points_need_fitting == 0) {
89 auto n_last_centers = n_centers;
91 std::vector<index_t> indices(centers);
92 std::copy(c_centers.rbegin(), c_centers.rend(), std::back_inserter(indices));
93 point_cloud::distance_filter filter(points_, filtering_distance, indices);
94 std::unordered_set<index_t> filtered_indices(filter.filtered_indices().begin(),
95 filter.filtered_indices().end());
97 for (
auto it = c_centers.rbegin(); it != c_centers.rbegin() + n_points_need_fitting; ++it) {
98 if (filtered_indices.contains(*it)) {
99 centers.push_back(*it);
103 n_centers =
static_cast<index_t
>(centers.size());
105 auto last_center_weights = center_weights;
106 center_weights = valuesd::Zero(n_centers + n_poly_basis_);
107 center_weights.head(n_last_centers) = last_center_weights.head(n_last_centers);
108 center_weights.tail(n_poly_basis_) = last_center_weights.tail(n_poly_basis_);
110 filtering_distance *= 0.5;
113 return {std::move(centers), std::move(center_weights)};
116 std::vector<index_t> rbf_incremental_fitter::complementary_indices(
117 const std::vector<index_t>& indices)
const {
118 std::vector<index_t> c_idcs(n_points_ - indices.size());
120 auto universe = boost::irange<index_t>(index_t{0}, n_points_);
122 std::sort(idcs.begin(), idcs.end());
123 std::set_difference(universe.begin(), universe.end(), idcs.begin(), idcs.end(), c_idcs.begin());