Commit 6ff38911 authored by Lukas Riedel's avatar Lukas Riedel

Extend factory to create interpolators from config nodes directly

Move code previously located in the ScalingAdapter base class into the
InterpolatorFactory. Destinguish between YAML nodes and Dune
ParameterTree nodes.
parent fb61a631
......@@ -10,9 +10,13 @@
#include <dune/common/fvector.hh>
#include <dune/common/exceptions.hh>
#include <dune/common/parametertree.hh>
#include <dune/dorie/common/logging.hh>
#include <dune/dorie/common/utility.hh>
#include <dune/dorie/common/h5file.hh>
#include <yaml-cpp/yaml.h>
namespace Dune {
namespace Dorie {
......@@ -175,6 +179,9 @@ template<typename Traits>
struct InterpolatorFactory
{
using RF = typename Traits::RF;
using Domain = typename Traits::Domain;
/// Create the interpolator.
/** Use perfect forwarding for the data.
* \param type The type of interpolator to use
......@@ -254,6 +261,135 @@ static auto create (
log);
}
/// Create an interpolator from a ini file config
/** \param config Dune config file tree
* \param grid_view The grid view to use for determining extensions
* \param log The logger for the created object
*/
template<typename GridView>
static auto create (
const Dune::ParameterTree& config,
const GridView& grid_view,
const std::shared_ptr<spdlog::logger> log=get_logger(log_base)
)
-> std::shared_ptr<Interpolator<RF, Traits>>
{
const auto filename = config.get<std::string>("file");
const auto dataset = config.get<std::string>("dataset");
const auto [data, shape] = read_h5_dataset(filename, dataset, log);
const auto interpolation = config.get<std::string>("interpolation");
return create(interpolation,
data,
shape,
grid_view,
log);
}
/// Create an interpolator from a YAML config node
/// Create an interpolator from a ini file config
/** \param cfg YAML config tree node
* \param node_name Name of the YAML config node for better error messages
* \param grid_view The grid view to use for determining extensions
* \param log The logger for the created object
*/
template<typename GridView>
static auto create (
const YAML::Node& cfg,
const std::string& node_name,
const GridView& grid_view,
const std::shared_ptr<spdlog::logger> log=get_logger(log_base)
)
-> std::shared_ptr<Interpolator<RF, Traits>>
{
log->trace("Creating interpolator from YAML cfg node: {}", node_name);
const auto dim = GridView::Grid::dimension;
// open the H5 file
const auto filename = cfg["file"].as<std::string>();
const auto dataset = cfg["dataset"].as<std::string>();
const auto [data, shape] = read_h5_dataset(filename, dataset, log);
// read extensions and offset
const auto cfg_ext = cfg["extensions"];
const auto cfg_off = cfg["offset"];
if (cfg_ext and cfg_off)
{
const auto ext = cfg_ext.as<std::vector<RF>>();
const auto off = cfg_off.as<std::vector<RF>>();
if (shape.size() != dim) {
log->error("Expected {}-dimensional dataset. File: {}, "
"Dataset: {}, Dimensions: {}",
dim, filename, dataset, shape.size());
DUNE_THROW(Dune::IOError, "Invalid dataset for interpolator");
}
if (ext.size() != dim) {
log->error("Expected {}-dimensional sequence as "
"'extensions' at scaling section node: {}",
dim, node_name);
DUNE_THROW(Dune::IOError, "Invalid interpolator input");
}
if (off.size() != dim) {
log->error("Expected {}-dimensional sequence as 'offset' "
"at scaling section node: {}",
dim, node_name);
DUNE_THROW(Dune::IOError, "Invalid interpolator input");
}
// copy into correct data structure
Domain extensions;
std::copy(begin(ext), end(ext), extensions.begin());
Domain offset;
std::copy(begin(off), end(off), offset.begin());
// call other factory function
return create(
cfg["interpolation"].template as<std::string>(),
data,
shape,
extensions,
offset,
log
);
}
// extensions or offset not given
else
{
// call other factory function
return create(
cfg["interpolation"].template as<std::string>(),
data,
shape,
grid_view,
log
);
}
}
private:
/// Read a dataset from an H5 file
/** \param file_path Path to the H5 file
* \param dataset_name Name of the dataset
* \return Pair: 1D data, dataset shape.
*/
static std::pair<std::vector<RF>, std::vector<size_t>> read_h5_dataset (
const std::string& file_path,
const std::string& dataset_name,
const std::shared_ptr<spdlog::logger> log
)
{
// open file
H5File h5file(file_path, log);
// read the dataset
std::vector<RF> data;
std::vector<size_t> shape;
h5file.read_dataset(dataset_name, H5T_NATIVE_DOUBLE, data, shape);
std::reverse(begin(shape), end(shape));
return std::make_pair(data, shape);
}
};
} // namespace Dorie
......
......@@ -90,78 +90,11 @@ protected:
*/
void add_interpolator (const YAML::Node& cfg, const std::string node_name)
{
_log->trace("Adding interpolator from data node: {}", node_name);
const auto dim = Traits::dim;
// open the H5 file
const auto filename = cfg["file"].as<std::string>();
const auto dataset = cfg["dataset"].as<std::string>();
H5File h5file(filename, _log);
// read the dataset
std::vector<RF> data;
std::vector<size_t> shape;
h5file.read_dataset(dataset, H5T_NATIVE_DOUBLE, data, shape);
std::reverse(begin(shape), end(shape));
// read extensions and offset
const auto cfg_ext = cfg["extensions"];
const auto cfg_off = cfg["offset"];
if (cfg_ext and cfg_off)
{
const auto ext = cfg_ext.as<std::vector<RF>>();
const auto off = cfg_off.as<std::vector<RF>>();
if (shape.size() != dim) {
_log->error("Expected {}-dimensional dataset. File: {}, "
"Dataset: {}, Dimensions: {}",
dim, filename, dataset, shape.size());
DUNE_THROW(Dune::IOError, "Invalid dataset for interpolator");
}
if (ext.size() != dim) {
_log->error("Expected {}-dimensional sequence as "
"'extensions' at scaling section node: {}",
dim, node_name);
DUNE_THROW(Dune::IOError, "Invalid interpolator input");
}
if (off.size() != dim) {
_log->error("Expected {}-dimensional sequence as 'offset' "
"at scaling section node: {}",
dim, node_name);
DUNE_THROW(Dune::IOError, "Invalid interpolator input");
}
// copy into correct data structure
Domain extensions;
std::copy(begin(ext), end(ext), extensions.begin());
Domain offset;
std::copy(begin(off), end(off), offset.begin());
// build interpolator
this->_interpolators.emplace_back(
InterpolatorFactory<Traits>::create(
cfg["interpolation"].template as<std::string>(),
data,
shape,
extensions,
offset,
_log
)
);
}
// extensions or offset not given
else
{
// build interpolator
this->_interpolators.emplace_back(
InterpolatorFactory<Traits>::create(
cfg["interpolation"].template as<std::string>(),
data,
shape,
_grid_view,
_log
)
);
}
this->_interpolators.emplace_back(
InterpolatorFactory<Traits>::create(cfg,
node_name,
_grid_view,
_log));
}
};
......
......@@ -21,6 +21,7 @@ struct InterpolatorTraits
static constexpr int dim = dimension;
using Domain = std::vector<double>;
using DF = double;
using RF = double;
};
/// Test the nearest neighbor interpolator.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment