diff --git a/CMakeLists.txt b/CMakeLists.txt index 16fe2f9..0a80126 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,6 +83,7 @@ message(STATUS "Include ${CMAKE_CURRENT_SOURCE_DIR}/cmake/parameters.cmake") include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/parameters.cmake) message(STATUS "${Blue}${PROJECT_NAME}-3. Setting up system libraries ...${ColorReset}") +set(BUILD_SHARED_LIBS OFF) find_package(Torch REQUIRED) find_package(Disort REQUIRED) find_package(Harp REQUIRED) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 384b51a..efaff58 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -74,7 +74,7 @@ target_link_libraries(${namel}_${buildl} ${TORCH_LIBRARY} ${TORCH_CPU_LIBRARY} ${C10_LIBRARY} - archive + archive_static fmt::fmt yaml-cpp::yaml-cpp ) diff --git a/src/implicit/implicit_hydro.cpp b/src/implicit/implicit_hydro.cpp index c36d906..28f0e05 100644 --- a/src/implicit/implicit_hydro.cpp +++ b/src/implicit/implicit_hydro.cpp @@ -13,6 +13,36 @@ namespace snap { +ImplicitOptions ImplicitOptionsImpl::from_yaml(const std::string& filename, + bool /*verbose*/) { + auto config = YAML::LoadFile(filename); + if (!config["integration"]) return nullptr; + if (!config["integration"]["implicit-scheme"]) return nullptr; + return from_yaml(config["integration"]["implicit-scheme"]); +} + +ImplicitOptions ImplicitOptionsImpl::from_yaml(const YAML::Node& node) { + auto op = ImplicitOptionsImpl::create(); + op->scheme(node.as()); + return op; +} + +std::string ImplicitOptionsImpl::type() const { + switch (scheme()) { + case 0: + return "none"; + break; + case 1: + return "vic-partial"; + break; + case 9: + return "vic-full"; + break; + default: + TORCH_CHECK(false, "Unsupported implicit scheme"); + } +} + ImplicitHydroImpl::ImplicitHydroImpl(ImplicitOptions const& options_, torch::nn::Module* p) : options(options_) { diff --git a/src/implicit/implicit_hydro.hpp b/src/implicit/implicit_hydro.hpp index 381c63a..f6f7429 100644 --- a/src/implicit/implicit_hydro.hpp +++ b/src/implicit/implicit_hydro.hpp @@ -36,7 +36,7 @@ struct ImplicitOptionsImpl { } } - ADD_ARG(std::string, type) = "none"; + std::string type() const; ADD_ARG(int, scheme) = 0; }; using ImplicitOptions = std::shared_ptr; diff --git a/src/implicit/implicit_options.cpp b/src/implicit/implicit_options.cpp deleted file mode 100644 index 41f72db..0000000 --- a/src/implicit/implicit_options.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// yaml -#include - -// snap -#include "implicit_hydro.hpp" - -namespace snap { - -ImplicitOptions ImplicitOptionsImpl::from_yaml(const std::string& filename, - bool /*verbose*/) { - auto config = YAML::LoadFile(filename); - if (!config["integration"]) return nullptr; - if (!config["integration"]["implicit-scheme"]) return nullptr; - return from_yaml(config["integration"]["implicit-scheme"]); -} - -ImplicitOptions ImplicitOptionsImpl::from_yaml(const YAML::Node& node) { - auto op = ImplicitOptionsImpl::create(); - op->scheme(node.as()); - - switch (op->scheme()) { - case 0: - op->type("none"); - break; - case 1: - op->type("vic-partial"); - break; - case 9: - op->type("vic-full"); - break; - default: - TORCH_CHECK(false, "Unsupported implicit scheme"); - } - - return op; -} - -} // namespace snap diff --git a/src/utils/refine.cpp b/src/utils/refine.cpp new file mode 100644 index 0000000..19d6617 --- /dev/null +++ b/src/utils/refine.cpp @@ -0,0 +1,54 @@ +#include "refine.hpp" + +namespace snap { + +torch::Tensor conservative_refine(torch::Tensor x) { + auto opts_y = torch::nn::functional::InterpolateFuncOptions() + .scale_factor(std::vector({2.0, 2.0, 1.0})) + .mode(torch::kTrilinear) + .align_corners(false); + + auto opts_x = torch::nn::functional::InterpolateFuncOptions() + .scale_factor(std::vector({0.5, 0.5, 1.0})) + .mode(torch::kArea); + + auto opts_dy = torch::nn::functional::InterpolateFuncOptions() + .scale_factor(std::vector({2.0, 2.0, 1.0})) + .mode(torch::kArea); + + // bilinear refine + int dim = 0; + while (x.dim() < 5) { + ++dim; + x = x.unsqueeze(0); + } + auto y1 = torch::nn::functional::interpolate(x, opts_y); + + // conservative coarsen + auto x1 = torch::nn::functional::interpolate(y1, opts_x); + + // conservative correction + auto dy = torch::nn::functional::interpolate(x - x1, opts_dy); + auto y = y1 + dy; + + for (int i = 0; i < dim; ++i) y = y.squeeze(0); + return y; +} + +torch::Tensor conservative_coarsen(torch::Tensor x) { + auto opts = torch::nn::functional::InterpolateFuncOptions() + .scale_factor(std::vector({0.5, 0.5, 1.0})) + .mode(torch::kArea); + int dim = 0; + while (x.dim() < 5) { + ++dim; + x = x.unsqueeze(0); + } + + auto y = torch::nn::functional::interpolate(x, opts); + + for (int i = 0; i < dim; ++i) y = y.squeeze(0); + return y; +} + +} // namespace snap diff --git a/src/utils/refine.hpp b/src/utils/refine.hpp new file mode 100644 index 0000000..1f5bde0 --- /dev/null +++ b/src/utils/refine.hpp @@ -0,0 +1,11 @@ +#pragma once + +// torch +#include + +namespace snap { + +torch::Tensor conservative_refine(torch::Tensor x); +torch::Tensor conservative_coarsen(torch::Tensor x); + +} // namespace snap diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 04e9112..7e3a2a4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -17,6 +17,7 @@ setup_test(test_slab) setup_test(test_cubed) setup_test(test_cubed_sphere) setup_test(test_coordinate) +setup_test(test_refine) #setup_test(test_read_topo) #setup_cuda_test(test_thomas_solver) #setup_test(test_aneos) diff --git a/tests/test_refine.cpp b/tests/test_refine.cpp new file mode 100644 index 0000000..7e2b62b --- /dev/null +++ b/tests/test_refine.cpp @@ -0,0 +1,44 @@ +// external +#include +#include + +// torch +#include + +// snapy +#include + +#include + +// tests +#include "device_testing.hpp" + +using namespace snap; + +TEST_P(DeviceTest, refine_funcs) { + int nc1 = 2; + int nc2 = 3; + int nc3 = 3; + int nvar = 1; + + auto x = torch::empty({nvar, nc3, nc2, nc1}, torch::dtype(dtype)); + + for (int n = 0; n < nvar; ++n) + for (int k = 0; k < nc3; ++k) + for (int j = 0; j < nc2; ++j) + for (int i = 0; i < nc1; ++i) { + x[n][k][j][i] = static_cast(n + k + j + i + 1); + } + + x = x.to(device); + auto y = conservative_refine(x); + auto z = conservative_coarsen(y); + + EXPECT_TRUE(torch::allclose(x, z, 1.E-6, 1.E-6)); +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +}