Commit b6d082c7 authored by Lukas Riedel's avatar Lukas Riedel 📝

Add linear interpolator

* Adjust template parameters of interpolators.
* Add linear interpolator for 2D and 3D.
* Adjust unit test.

[skip ci] because code base is not adapted to template parameter
change.
parent b0290153
......@@ -7,8 +7,11 @@
#include <utility>
#include <numeric>
#include <algorithm>
#include <string>
#include <string_view>
#include <dune/common/fvector.hh>
#include <dune/common/fmatrix.hh>
#include <dune/common/exceptions.hh>
#include <dune/common/parametertree.hh>
......@@ -28,14 +31,16 @@ namespace Dorie {
* \author Lukas Riedel
* \date 2018
*/
template<typename T, typename Traits>
template<typename DataType, int dim>
class Interpolator
{
protected:
using Domain = typename Traits::Domain;
static constexpr int dim = Traits::dim;
/// Coordinate Type
using DF = double;
/// Physical vector type
using Domain = Dune::FieldVector<DF, dim>;
const std::vector<T> _data; //!< contiguously stored data of the array
const std::vector<DataType> _data; //!< contiguously stored data of the array
const std::vector<size_t> _shape; //!< inverted shape of the original array
//! physical extensions of the dataset
const Domain _extensions;
......@@ -100,7 +105,7 @@ public:
/** \param pos Position to evaluate
* \return Value of the interpolated function
*/
virtual T evaluate (const Domain& pos) const = 0;
virtual DataType evaluate (const Domain& pos) const = 0;
};
/// A nearest-neighbor interpolator in 2D and 3D
......@@ -110,34 +115,25 @@ public:
*
* \ingroup Interpolators
*/
template<typename T, typename Traits>
class NearestNeighborInterpolator: public Interpolator<T, Traits>
template<typename DataType, int dim>
class NearestNeighborInterpolator: public Interpolator<DataType, dim>
{
private:
using Base = Interpolator<T, Traits>;
using Base = Interpolator<DataType, dim>;
using Domain = typename Base::Domain;
static constexpr int dim = Base::dim;
using DF = typename Traits::DF;
using DF = typename Base::DF;
static constexpr DF eps = 1e-9;
public:
template<typename Data, typename Shape, typename Domain1, typename Domain2>
NearestNeighborInterpolator (Data&& data,
Shape&& shape,
Domain1&& extensions,
Domain2&& offset) :
Base(std::forward<Data>(data),
std::forward<Shape>(shape),
std::forward<Domain1>(extensions),
std::forward<Domain2>(offset))
{ }
/// Re-use the base class constructor
using Base::Base;
/// Export the type of this intepolator
static inline std::string type = "nearest";
/// Evaluate the interpolator
T evaluate (const Domain& x) const override
DataType evaluate (const Domain& x) const override
{
const auto pos = this->sub_offset(x);
if (not this->inside_extensions(pos)) {
......@@ -168,6 +164,163 @@ public:
~NearestNeighborInterpolator() override = default;
};
/// Linear interpolator for 2D and 3D.
/**
* \ingroup Interpolators
* \author Lukas Riedel
* \date 2018
*/
template<typename DataType, int dim>
class LinearInterpolator : public Interpolator<DataType, dim>
{
private:
using Base = Interpolator<DataType, dim>;
using Domain = typename Base::Domain;
using DF = typename Base::DF;
using MultiIdx = std::array<size_t, dim>;
/// Get the position vector in mesh units
/** \param pos Global position vector
* \return Position vector in mesh units
*/
Domain get_pos_unit (const Domain& pos) const
{
Domain pos_unit;
for (size_t i = 0; i < pos.size(); ++i) {
pos_unit[i] = pos[i] * (this->_shape[i] - 1)
/ this->_extensions[i];
}
return pos_unit;
}
/// Return the position difference vector in mesh units
/** The difference is calculated by using the (front) lower left corner
* as origin.
* \param pos_unit Position vector in mesh units, see get_pos_unit()
* \pararm indices The multi index of (front) lower left corner
* \return Position difference vector in mesh units
*/
Domain get_pos_unit_diff (const Domain& pos_unit, const MultiIdx& indices)
const
{
Domain pos_diff;
std::transform(pos_unit.begin(), pos_unit.end(),
begin(indices),
pos_diff.begin(),
std::minus());
return pos_diff;
}
/// Get the indices for accessing the data knots
/** \param pos_unit The position in mesh units
* \return Indices in every dimension of the mesh
*/
std::pair<MultiIdx, MultiIdx> get_knot_indices (
const Domain& pos_unit
) const
{
MultiIdx idx_lower;
std::transform(pos_unit.begin(), pos_unit.end(),
begin(idx_lower),
[](const auto value){
return std::max(std::floor(value), 0.0);
});
MultiIdx idx_upper;
std::transform(pos_unit.begin(), pos_unit.end(),
begin(this->_shape),
begin(idx_upper),
[](const auto value, const auto max){
return std::min<double>(std::ceil(value), max-1);
});
return std::make_pair(idx_lower, idx_upper);
}
/// Transform a multi index into a index for accessing the data array
/** \param indices The multi index to transform
* \return The index to query the data set #_data
*/
size_t to_index (const MultiIdx& indices) const
{
size_t index = indices[0] + indices[1] * this->_shape[0];
if constexpr (dim == 3) {
index += indices[2] * this->_shape[1] * this->_shape[0];
}
return index;
}
public:
/// Export the type of this intepolator
static inline constexpr std::string_view type = "linear";
/// Re-use the base class constructor
using Base::Base;
/// Delete this interpolator
~LinearInterpolator () override = default;
/// Evaluate the data at a global position
DataType evaluate (const Domain& x) const override
{
const auto pos = this->sub_offset(x);
if (not this->inside_extensions(pos)) {
DUNE_THROW(Exception, "Querying interpolator data outside its "
"extensions!");
}
// compute indices and position vectors in mesh units
const auto pos_unit = get_pos_unit(pos);
const auto [idx_lower, idx_upper] = get_knot_indices(pos);
const auto pos_unit_diff = get_pos_unit_diff(pos_unit, idx_lower);
const auto& data = this->_data;
DataType c_00, c_01, c_10, c_11;
DF x_d = 1.0, y_d, z_d;
// actual computation
if constexpr (dim == 2) {
c_00 = data[to_index(idx_lower)];
c_01 = data[to_index({idx_lower[0], idx_upper[1]})];
c_10 = data[to_index({idx_upper[0], idx_lower[1]})];
c_11 = data[to_index(idx_upper)];
// use y, z here for common computation after if-clause
y_d = pos_unit_diff[0];
z_d = pos_unit_diff[1];
}
else {
DataType c_000, c_001, c_010, c_100, c_011, c_110, c_101, c_111;
c_000 = data[to_index(idx_lower)];
c_100 = data[to_index({idx_upper[0], idx_lower[1], idx_lower[2]})];
c_010 = data[to_index({idx_lower[0], idx_upper[1], idx_lower[2]})];
c_001 = data[to_index({idx_lower[0], idx_lower[1], idx_upper[2]})];
c_011 = data[to_index({idx_lower[0], idx_upper[1], idx_upper[2]})];
c_110 = data[to_index({idx_upper[0], idx_upper[1], idx_lower[2]})];
c_101 = data[to_index({idx_upper[0], idx_lower[1], idx_upper[2]})];
c_111 = data[to_index(idx_upper)];
x_d = pos_unit_diff[0];
y_d = pos_unit_diff[1];
z_d = pos_unit_diff[2];
c_00 = c_000 * (1 - x_d) + c_100 * x_d;
c_01 = c_001 * (1 - x_d) + c_101 * x_d;
c_10 = c_010 * (1 - x_d) + c_110 * x_d;
c_11 = c_011 * (1 - x_d) + c_111 * x_d;
}
const DataType c_0 = c_00 * (1 - y_d) + c_10 * y_d;
const DataType c_1 = c_01 * (1 - y_d) + c_11 * y_d;
const DataType c = c_0 * (1 - z_d) + c_1 * z_d;
return c;
}
};
/// Factory for creating interpolators
/**
* \see Interpolator for the base class of created objects
......@@ -203,12 +356,13 @@ static auto create (
const std::shared_ptr<spdlog::logger> log=get_logger(log_base)
)
-> std::shared_ptr<Interpolator<
typename std::remove_reference_t<Data>::value_type,
Traits>>
typename std::decay_t<Data>::value_type,
Traits::dim>>
{
log->debug("Creating interpolator of type: {}", type);
using data_t = typename std::remove_reference_t<Data>::value_type;
using NNInterp = NearestNeighborInterpolator<data_t, Traits>;
using data_t = typename std::decay_t<Data>::value_type;
using NNInterp = NearestNeighborInterpolator<data_t, Traits::dim>;
using LinearInterp = LinearInterpolator<data_t, Traits::dim>;
if (type == NNInterp::type) {
return std::make_shared<NNInterp>(
......@@ -218,6 +372,14 @@ static auto create (
std::forward<Domain>(offset)
);
}
else if (type == LinearInterp::type) {
return std::make_shared<LinearInterp>(
std::forward<Data>(data),
std::forward<Shape>(shape),
std::forward<Domain>(extensions),
std::forward<Domain>(offset)
);
}
else {
log->error("Unknown interpolator type: {}", type);
DUNE_THROW(Dune::NotImplemented, "Unknown interpolator type");
......@@ -244,8 +406,8 @@ static auto create (
const std::shared_ptr<spdlog::logger> log=get_logger(log_base)
)
-> std::shared_ptr<Interpolator<
typename std::remove_reference_t<Data>::value_type,
Traits>>
typename std::decay_t<Data>::value_type,
Traits::dim>>
{
/// retrieve extensions and offset from grid extensions
const auto level_gv = grid_view.grid().levelGridView(0);
......@@ -272,7 +434,7 @@ static auto create (
const GridView& grid_view,
const std::shared_ptr<spdlog::logger> log=get_logger(log_base)
)
-> std::shared_ptr<Interpolator<RF, Traits>>
-> std::shared_ptr<Interpolator<RF, Traits::dim>>
{
const auto filename = config.get<std::string>("file");
const auto dataset = config.get<std::string>("dataset");
......@@ -300,7 +462,7 @@ static auto create (
const GridView& grid_view,
const std::shared_ptr<spdlog::logger> log=get_logger(log_base)
)
-> std::shared_ptr<Interpolator<RF, Traits>>
-> std::shared_ptr<Interpolator<RF, Traits::dim>>
{
log->trace("Creating interpolator from YAML cfg node: {}", node_name);
const auto dim = GridView::Grid::dimension;
......
......@@ -12,6 +12,7 @@
#include <dune/common/exceptions.hh>
#include <dune/common/float_cmp.hh>
#include <dune/common/parallel/mpihelper.hh>
#include <dune/common/fvector.hh>
#include <dune/dorie/common/interpolator.hh>
......@@ -19,7 +20,7 @@ template<int dimension>
struct InterpolatorTraits
{
static constexpr int dim = dimension;
using Domain = std::vector<double>;
using Domain = Dune::FieldVector<double, dim>;
using DF = double;
using RF = double;
};
......@@ -39,11 +40,11 @@ void test_nearest_neighbor ()
std::vector<size_t> shape(dim);
std::fill(begin(shape), end(shape), 3);
std::vector<double> extensions(dim);
std::fill(begin(extensions), end(extensions), 3.0);
Dune::FieldVector<double, dim> extensions(3.0);
// std::fill(begin(extensions), end(extensions), 3.0);
std::vector<double> offset(dim);
std::fill(begin(offset), end(offset), 0.0);
Dune::FieldVector<double, dim> offset(0.0);
// std::fill(begin(offset), end(offset), 0.0);
// build interpolator
auto interp = Dune::Dorie::InterpolatorFactory<InterpolatorTraits<dim>>
......@@ -55,7 +56,7 @@ void test_nearest_neighbor ()
// check without offset
using Dune::FloatCmp::eq; // floating-point comparison
std::vector<std::vector<double>> corners;
std::vector<Dune::FieldVector<double, dim>> corners;
if constexpr (dim == 2) {
corners.resize(4);
corners[0] = {0.0, 0.0};
......@@ -88,7 +89,7 @@ void test_nearest_neighbor ()
}
// check with offset
std::fill(begin(offset), end(offset), -1.0);
std::fill(offset.begin(), offset.end(), -1.0);
interp = Dune::Dorie::InterpolatorFactory<InterpolatorTraits<dim>>
::create("nearest",
data,
......@@ -128,16 +129,130 @@ void test_nearest_neighbor ()
}
}
/// Test the linear interpolator
template<int dim>
void test_linear ();
/// Test the linear interpolator in 2D
template<>
void test_linear<2> ()
{
constexpr int dim = 2;
std::vector<double> data({0.0, 1.0, 1.0, 2.0});
std::vector<size_t> shape(dim);
std::fill(begin(shape), end(shape), 2);
Dune::FieldVector<double, dim> extensions(1.0);
Dune::FieldVector<double, dim> offset(0.0);
// build interpolator
auto interp = Dune::Dorie::InterpolatorFactory<InterpolatorTraits<dim>>
::create("linear",
data,
shape,
extensions,
offset);
// check without offset
using Dune::FloatCmp::eq; // floating-point comparison
std::vector<Dune::FieldVector<double, dim>> points(5);
points[0] = {0.0, 0.0};
points[1] = {1.0, 0.0};
points[2] = {0.0, 1.0};
points[3] = {1.0, 1.0};
points[4] = {0.75, 0.75};
assert(eq(interp->evaluate(points[0]), 0.0));
assert(eq(interp->evaluate(points[1]), 1.0));
assert(eq(interp->evaluate(points[2]), 1.0));
assert(eq(interp->evaluate(points[3]), 2.0));
assert(eq(interp->evaluate(points[4]), 1.5));
// check with offset
std::fill(offset.begin(), offset.end(), -1.0);
interp = Dune::Dorie::InterpolatorFactory<InterpolatorTraits<dim>>
::create("linear",
data,
shape,
extensions,
offset);
points.resize(2);
points[0] = {0.0, 0.0};
points[1] = {-0.25, -0.25};
assert(eq(interp->evaluate(points[0]), 2.0));
assert(eq(interp->evaluate(points[1]), 1.5));
}
/// Test the linear interpolator in 3D
template<>
void test_linear<3> ()
{
constexpr int dim = 3;
std::vector<double> data({0.0, 0.0, 0.0, 0.0,
1.0, 1.0, 2.0, 2.0});
std::vector<size_t> shape(dim);
std::fill(begin(shape), end(shape), 2);
Dune::FieldVector<double, dim> extensions(1.0);
Dune::FieldVector<double, dim> offset(0.0);
// build interpolator
auto interp = Dune::Dorie::InterpolatorFactory<InterpolatorTraits<dim>>
::create("linear",
data,
shape,
extensions,
offset);
// check without offset
using Dune::FloatCmp::eq; // floating-point comparison
std::vector<Dune::FieldVector<double, dim>> points(5);
points[0] = {0.0, 0.0, 0.0};
points[1] = {1.0, 0.0, 0.0};
points[2] = {0.0, 0.0, 1.0};
points[3] = {1.0, 1.0, 1.0};
points[4] = {0.5, 0.5, 1.0};
assert(eq(interp->evaluate(points[0]), 0.0));
assert(eq(interp->evaluate(points[1]), 0.0));
assert(eq(interp->evaluate(points[2]), 1.0));
assert(eq(interp->evaluate(points[3]), 2.0));
assert(eq(interp->evaluate(points[4]), 1.5));
// check with offset
std::fill(offset.begin(), offset.end(), -1.0);
interp = Dune::Dorie::InterpolatorFactory<InterpolatorTraits<dim>>
::create("linear",
data,
shape,
extensions,
offset);
points.resize(2);
points[0] = {0.0, 0.0, 0.0};
points[1] = {-0.5, -0.5, 0.0};
assert(eq(interp->evaluate(points[0]), 2.0));
assert(eq(interp->evaluate(points[1]), 1.5));
}
int main (int argc, char** argv)
{
try{
// initialize MPI if needed
auto& helper = Dune::MPIHelper::instance(argc, argv);
Dune::Dorie::create_base_logger(helper);
auto log = Dune::Dorie::create_base_logger(helper);
log->set_level(spdlog::level::trace);
// test the NearestNeighbor interpolator
test_nearest_neighbor<2>();
test_nearest_neighbor<3>();
test_linear<2>();
test_linear<3>();
}
catch (Dune::Exception &e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment