-
Notifications
You must be signed in to change notification settings - Fork 0
Add grid refine #107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add grid refine #107
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,54 @@ | ||||||
| #include "refine.hpp" | ||||||
|
|
||||||
| namespace snap { | ||||||
|
|
||||||
| torch::Tensor conservative_refine(torch::Tensor x) { | ||||||
| auto opts_y = torch::nn::functional::InterpolateFuncOptions() | ||||||
| .scale_factor(std::vector<double>({2.0, 2.0, 1.0})) | ||||||
| .mode(torch::kTrilinear) | ||||||
| .align_corners(false); | ||||||
|
|
||||||
| auto opts_x = torch::nn::functional::InterpolateFuncOptions() | ||||||
| .scale_factor(std::vector<double>({0.5, 0.5, 1.0})) | ||||||
| .mode(torch::kArea); | ||||||
|
|
||||||
| auto opts_dy = torch::nn::functional::InterpolateFuncOptions() | ||||||
| .scale_factor(std::vector<double>({2.0, 2.0, 1.0})) | ||||||
| .mode(torch::kArea); | ||||||
|
|
||||||
| // bilinear refine | ||||||
|
||||||
| // bilinear refine | |
| // trilinear refine |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,11 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #pragma once | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <torch/torch.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace snap { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor conservative_refine(torch::Tensor x); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+7
to
+8
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor conservative_refine(torch::Tensor x); | |
| /** | |
| * @brief Perform conservative refinement of a tensor field. | |
| * | |
| * This utility takes values defined on a coarse resolution grid and produces | |
| * a refined representation while conserving the total quantity represented | |
| * by @p x along the refined dimensions. | |
| * | |
| * The input is typically a tensor of shape (batch, channels, spatial...), | |
| * where the trailing dimensions correspond to spatial axes that will be | |
| * refined. The exact refinement scheme and resulting shape depend on the | |
| * implementation, but the operation is designed to be conservative | |
| * (i.e., the sum over corresponding regions is preserved). | |
| * | |
| * @param x Input tensor containing cell-averaged or cell-integrated values | |
| * on a coarse grid. | |
| * @return A tensor defined on a refined grid, with values adjusted so that | |
| * the total quantity is conserved with respect to @p x. | |
| */ | |
| torch::Tensor conservative_refine(torch::Tensor x); | |
| /** | |
| * @brief Perform conservative coarsening of a tensor field. | |
| * | |
| * This utility takes values defined on a fine resolution grid and produces | |
| * a coarsened representation while conserving the total quantity represented | |
| * by @p x along the coarsened dimensions. | |
| * | |
| * The input is typically a tensor of shape (batch, channels, spatial...), | |
| * where the trailing dimensions correspond to spatial axes that will be | |
| * coarsened (e.g., by aggregating neighboring cells). The exact coarsening | |
| * scheme and resulting shape depend on the implementation, but the operation | |
| * is designed to be conservative (i.e., the sum over aggregated regions is | |
| * preserved). | |
| * | |
| * @param x Input tensor containing cell-averaged or cell-integrated values | |
| * on a fine grid. | |
| * @return A tensor defined on a coarser grid, with values adjusted so that | |
| * the total quantity is conserved with respect to @p x. | |
| */ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| // external | ||
| #include <gtest/gtest.h> | ||
| #include <yaml-cpp/yaml.h> | ||
|
|
||
| // torch | ||
| #include <torch/torch.h> | ||
|
|
||
| // snapy | ||
| #include <snap/snap.h> | ||
|
|
||
| #include <snap/utils/refine.hpp> | ||
|
|
||
| // tests | ||
| #include "device_testing.hpp" | ||
|
|
||
| using namespace snap; | ||
|
|
||
| TEST_P(DeviceTest, refine_funcs) { | ||
| int nc1 = 2; | ||
| int nc2 = 3; | ||
| int nc3 = 3; | ||
| int nvar = 1; | ||
|
|
||
| auto x = torch::empty({nvar, nc3, nc2, nc1}, torch::dtype(dtype)); | ||
|
|
||
| for (int n = 0; n < nvar; ++n) | ||
| for (int k = 0; k < nc3; ++k) | ||
| for (int j = 0; j < nc2; ++j) | ||
| for (int i = 0; i < nc1; ++i) { | ||
| x[n][k][j][i] = static_cast<float>(n + k + j + i + 1); | ||
| } | ||
|
|
||
| x = x.to(device); | ||
| auto y = conservative_refine(x); | ||
| auto z = conservative_coarsen(y); | ||
|
|
||
| EXPECT_TRUE(torch::allclose(x, z, 1.E-6, 1.E-6)); | ||
| } | ||
|
|
||
| int main(int argc, char **argv) { | ||
| testing::InitGoogleTest(&argc, argv); | ||
|
|
||
| return RUN_ALL_TESTS(); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The break statements after return statements are unreachable code. Since each case returns a value, the break statements on lines 34, 37, and 40 will never be executed and should be removed.