From e603a7e32bc6812e06824210aab5dcc3f7750978 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 5 Feb 2026 05:20:05 -0800 Subject: [PATCH 1/2] changing _precompute_latitudes and _precompute_longitudes to their non-underscored counterparts. Furthermore, removing unused imports from several files --- examples/segmentation/train.py | 3 +-- notebooks/equivariance_test.ipynb | 12 +++++----- notebooks/resample_sphere.ipynb | 12 +++++----- tests/test_convolution.py | 10 ++++---- torch_harmonics/attention/attention.py | 9 ++++---- torch_harmonics/disco/convolution.py | 13 ++++------- .../distributed/distributed_convolution.py | 7 +----- .../distributed/distributed_resample.py | 10 ++++---- torch_harmonics/distributed/primitives.py | 2 +- torch_harmonics/examples/losses.py | 4 ++-- torch_harmonics/examples/metrics.py | 1 - torch_harmonics/examples/models/_layers.py | 23 ------------------- torch_harmonics/examples/models/lsno.py | 2 +- .../examples/models/s2segformer.py | 4 +--- .../examples/models/s2transformer.py | 3 +-- torch_harmonics/examples/models/s2unet.py | 5 +--- torch_harmonics/examples/models/sfno.py | 2 -- torch_harmonics/examples/pde_sphere.py | 4 ++-- .../examples/shallow_water_equations.py | 4 ++-- .../examples/stanford_2d3ds_dataset.py | 6 ++--- torch_harmonics/quadrature.py | 10 ++++---- torch_harmonics/resample.py | 14 +++++------ 22 files changed, 59 insertions(+), 101 deletions(-) diff --git a/examples/segmentation/train.py b/examples/segmentation/train.py index cf05918a..1346069a 100644 --- a/examples/segmentation/train.py +++ b/examples/segmentation/train.py @@ -51,8 +51,7 @@ import matplotlib.pyplot as plt from torch_harmonics.examples import StanfordSegmentationDataset, Stanford2D3DSDownloader, StanfordDatasetSubset, compute_stats_s2 -from torch_harmonics.quadrature import _precompute_latitudes -from torch_harmonics.examples.losses import DiceLossS2, CrossEntropyLossS2, FocalLossS2 +from torch_harmonics.examples.losses import CrossEntropyLossS2 from torch_harmonics.examples.metrics import IntersectionOverUnionS2, AccuracyS2 from torch_harmonics.plotting import plot_sphere, imshow_sphere diff --git a/notebooks/equivariance_test.ipynb b/notebooks/equivariance_test.ipynb index 35fe25ba..eb0cbb2d 100644 --- a/notebooks/equivariance_test.ipynb +++ b/notebooks/equivariance_test.ipynb @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "40e4a61f", "metadata": {}, "outputs": [], @@ -35,7 +35,7 @@ "\n", "\n", "from torch_harmonics import AttentionS2\n", - "from torch_harmonics.quadrature import _precompute_latitudes\n", + "from torch_harmonics.quadrature import precompute_latitudes\n", "from torch_harmonics.plotting import plot_sphere" ] }, @@ -340,7 +340,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "281ebfe8", "metadata": {}, "outputs": [ @@ -415,7 +415,7 @@ "\n", " # Assume quadrature weights are given or use Gaussian quadrature as done in torch-harmonics for accuracy\n", " # quadrature_weights = np.ones((nlat, nlon)) * (4*np.pi) / (nlat*nlon) # Placeholder\n", - " _, quad_weights = _precompute_latitudes(nlat, grid=grid)\n", + " _, quad_weights = precompute_latitudes(nlat, grid=grid)\n", " quad_weights = 2 * np.pi * quad_weights.reshape(-1, 1) / nlon\n", "\n", " # input signal is a spherical harmonic\n", @@ -460,7 +460,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "16f492be", "metadata": {}, "outputs": [ @@ -543,7 +543,7 @@ "\n", " # Assume quadrature weights are given or use Gaussian quadrature as done in torch-harmonics for accuracy\n", " # quadrature_weights = np.ones((nlat, nlon)) * (4*np.pi) / (nlat*nlon) # Placeholder\n", - " _, quad_weights = _precompute_latitudes(nlat, grid=grid)\n", + " _, quad_weights = precompute_latitudes(nlat, grid=grid)\n", " quad_weights = 2 * np.pi * quad_weights.reshape(-1, 1) / nlon\n", "\n", " # input signal is a spherical harmonic\n", diff --git a/notebooks/resample_sphere.ipynb b/notebooks/resample_sphere.ipynb index a87a15ec..914266d7 100644 --- a/notebooks/resample_sphere.ipynb +++ b/notebooks/resample_sphere.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -22,7 +22,7 @@ "import torch.nn as nn\n", "\n", "from torch_harmonics import ResampleS2\n", - "from torch_harmonics.quadrature import _precompute_latitudes\n", + "from torch_harmonics.quadrature import precompute_latitudes\n", "\n", "import matplotlib.pyplot as plt\n", "from torch_harmonics.plotting import plot_sphere" @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -45,7 +45,7 @@ "nlon = 2*(nlat-1)\n", "grid = \"equiangular\"\n", "\n", - "xq, wq = _precompute_latitudes(nlat, grid=grid)" + "xq, wq = precompute_latitudes(nlat, grid=grid)" ] }, { @@ -351,7 +351,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "dace", "language": "python", "name": "python3" }, @@ -365,7 +365,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/tests/test_convolution.py b/tests/test_convolution.py index fcbf4a6d..4e9f02f9 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -39,7 +39,7 @@ from torch.library import opcheck from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 -from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes +from torch_harmonics.quadrature import precompute_latitudes, precompute_longitudes from torch_harmonics.disco import cuda_kernels_is_available, optimized_kernels_is_available from testutils import disable_tf32, set_seed, compare_tensors @@ -120,12 +120,12 @@ def _precompute_convolution_tensor_dense( nlat_in, nlon_in = in_shape nlat_out, nlon_out = out_shape - lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in) - lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out) + lats_in, win = precompute_latitudes(nlat_in, grid=grid_in) + lats_out, wout = precompute_latitudes(nlat_out, grid=grid_out) # compute the phi differences. - lons_in = _precompute_longitudes(nlon_in) - lons_out = _precompute_longitudes(nlon_out) + lons_in = precompute_longitudes(nlon_in) + lons_out = precompute_longitudes(nlon_out) # effective theta cutoff if multiplied with a fudge factor to avoid aliasing with grid width (especially near poles) theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff diff --git a/torch_harmonics/attention/attention.py b/torch_harmonics/attention/attention.py index d40fdc9d..42841a38 100644 --- a/torch_harmonics/attention/attention.py +++ b/torch_harmonics/attention/attention.py @@ -29,15 +29,14 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -from typing import List, Tuple, Union, Optional -from warnings import warn +from typing import Tuple, Union, Optional import math import torch import torch.nn as nn -from torch_harmonics.quadrature import _precompute_latitudes +from torch_harmonics.quadrature import precompute_latitudes from torch_harmonics.disco.convolution import _precompute_convolution_tensor_s2 from torch_harmonics.attention._attention_utils import _neighborhood_s2_attention_torch, _neighborhood_s2_attention_optimized from torch_harmonics.filter_basis import get_filter_basis @@ -96,7 +95,7 @@ def __init__( self.scale = scale # integration weights - _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in) + _, wgl = precompute_latitudes(self.nlat_in, grid=grid_in) quad_weights = 2.0 * torch.pi * wgl.to(dtype=torch.float32) / self.nlon_in # we need to tile and flatten them accordingly quad_weights = torch.tile(quad_weights.reshape(-1, 1), (1, self.nlon_in)).flatten() @@ -248,7 +247,7 @@ def __init__( raise ValueError("Error, theta_cutoff has to be positive.") # integration weights - _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in) + _, wgl = precompute_latitudes(self.nlat_in, grid=grid_in) quad_weights = 2.0 * torch.pi * wgl.to(dtype=torch.float32) / self.nlon_in self.register_buffer("quad_weights", quad_weights, persistent=False) diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index 30c891e7..7c716bb8 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -30,18 +30,15 @@ # import abc -from typing import List, Tuple, Union, Optional -from warnings import warn +from typing import Tuple, Union, Optional import math import torch import torch.nn as nn -from functools import partial - from torch_harmonics.cache import lru_cache -from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes +from torch_harmonics.quadrature import precompute_latitudes, precompute_longitudes from ._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch from ._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized from torch_harmonics.filter_basis import FilterBasis, get_filter_basis @@ -231,12 +228,12 @@ def _precompute_convolution_tensor_s2( nlat_out, nlon_out = out_shape # precompute input and output grids - lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in) - lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out) + lats_in, win = precompute_latitudes(nlat_in, grid=grid_in) + lats_out, wout = precompute_latitudes(nlat_out, grid=grid_out) # compute the phi differences # It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0 - lons_in = _precompute_longitudes(nlon_in) + lons_in = precompute_longitudes(nlon_in) # compute quadrature weights and merge them into the convolution tensor. # These quadrature integrate to 1 over the sphere. diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py index 7d2bd568..ba611035 100644 --- a/torch_harmonics/distributed/distributed_convolution.py +++ b/torch_harmonics/distributed/distributed_convolution.py @@ -29,18 +29,13 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -from typing import List, Tuple, Union, Optional +from typing import Tuple, Union, Optional from itertools import accumulate import torch -import torch.nn as nn -from functools import partial - -from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes from torch_harmonics.disco._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch from torch_harmonics.disco._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized -from torch_harmonics.filter_basis import get_filter_basis from disco_helpers import optimized_kernels_is_available, preprocess_psi from torch_harmonics.disco.convolution import ( _precompute_convolution_tensor_s2, diff --git a/torch_harmonics/distributed/distributed_resample.py b/torch_harmonics/distributed/distributed_resample.py index 114f57a0..76978ffd 100644 --- a/torch_harmonics/distributed/distributed_resample.py +++ b/torch_harmonics/distributed/distributed_resample.py @@ -35,7 +35,7 @@ import torch import torch.nn as nn -from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes +from torch_harmonics.quadrature import precompute_latitudes, precompute_longitudes from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar from torch_harmonics.distributed import reduce_from_azimuth_region, copy_to_azimuth_region from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank @@ -107,10 +107,10 @@ def __init__( self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth) # for upscaling the latitudes we will use interpolation - self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in) - self.lons_in = _precompute_longitudes(nlon_in) - self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out) - self.lons_out = _precompute_longitudes(nlon_out) + self.lats_in, _ = precompute_latitudes(nlat_in, grid=grid_in) + self.lons_in = precompute_longitudes(nlon_in) + self.lats_out, _ = precompute_latitudes(nlat_out, grid=grid_out) + self.lons_out = precompute_longitudes(nlon_out) # in the case where some points lie outside of the range spanned by lats_in, # we need to expand the solution to the poles before interpolating diff --git a/torch_harmonics/distributed/primitives.py b/torch_harmonics/distributed/primitives.py index 50b76ace..e1e29abb 100644 --- a/torch_harmonics/distributed/primitives.py +++ b/torch_harmonics/distributed/primitives.py @@ -35,7 +35,7 @@ from torch.amp import custom_fwd, custom_bwd from .utils import polar_group, azimuth_group, polar_group_size -from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth +from .utils import is_distributed_polar, is_distributed_azimuth # helper routine to compute uneven splitting in balanced way: def compute_split_shapes(size: int, num_chunks: int) -> List[int]: diff --git a/torch_harmonics/examples/losses.py b/torch_harmonics/examples/losses.py index c3c2065b..58f75567 100644 --- a/torch_harmonics/examples/losses.py +++ b/torch_harmonics/examples/losses.py @@ -36,12 +36,12 @@ from typing import Optional from abc import ABC, abstractmethod -from torch_harmonics.quadrature import _precompute_latitudes +from torch_harmonics.quadrature import precompute_latitudes def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False, normalized: bool = True) -> torch.Tensor: # area weights - _, q = _precompute_latitudes(nlat=nlat, grid=grid) + _, q = precompute_latitudes(nlat=nlat, grid=grid) q = q.reshape(-1, 1) * 2 * torch.pi / nlon # numerical precision can be an issue here, make sure it sums to 1: diff --git a/torch_harmonics/examples/metrics.py b/torch_harmonics/examples/metrics.py index 9d98bc0a..ed0facf1 100644 --- a/torch_harmonics/examples/metrics.py +++ b/torch_harmonics/examples/metrics.py @@ -34,7 +34,6 @@ import torch import torch.nn as nn -from torch_harmonics.quadrature import _precompute_latitudes from .losses import get_quadrature_weights diff --git a/torch_harmonics/examples/models/_layers.py b/torch_harmonics/examples/models/_layers.py index 4c3bed7f..779d8af2 100644 --- a/torch_harmonics/examples/models/_layers.py +++ b/torch_harmonics/examples/models/_layers.py @@ -596,26 +596,3 @@ def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_ else: raise ValueError(f"Unknown learnable position embedding type {embed_type}") -# class SpiralPositionEmbedding(PositionEmbedding): -# """ -# Returns position embeddings on the torus -# """ - -# def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): - -# super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) - -# with torch.no_grad(): - -# # alternating custom position embeddings -# lats, _ = _precompute_latitudes(img_shape[0], grid=grid) -# lats = lats.reshape(-1, 1) -# lons = torch.linspace(0, 2 * math.pi, img_shape[1] + 1)[:-1] -# lons = lons.reshape(1, -1) - -# # channel index -# k = torch.arange(self.num_chans).reshape(1, -1, 1, 1) -# pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats))) - -# # register tensor -# self.register_buffer("position_embeddings", pos_embed.float()) diff --git a/torch_harmonics/examples/models/lsno.py b/torch_harmonics/examples/models/lsno.py index 9c1cc339..53b3be21 100644 --- a/torch_harmonics/examples/models/lsno.py +++ b/torch_harmonics/examples/models/lsno.py @@ -36,7 +36,7 @@ import torch.amp as amp from torch_harmonics import RealSHT, InverseRealSHT -from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 +from torch_harmonics import DiscreteContinuousConvS2 from torch_harmonics import ResampleS2 from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding diff --git a/torch_harmonics/examples/models/s2segformer.py b/torch_harmonics/examples/models/s2segformer.py index 7f96355b..6db85900 100644 --- a/torch_harmonics/examples/models/s2segformer.py +++ b/torch_harmonics/examples/models/s2segformer.py @@ -38,10 +38,8 @@ from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 from torch_harmonics import AttentionS2, NeighborhoodAttentionS2 from torch_harmonics import ResampleS2 -from torch_harmonics import RealSHT, InverseRealSHT -from torch_harmonics.quadrature import _precompute_latitudes -from torch_harmonics.examples.models._layers import MLP, LayerNorm, DropPath +from torch_harmonics.examples.models._layers import MLP, DropPath from functools import partial diff --git a/torch_harmonics/examples/models/s2transformer.py b/torch_harmonics/examples/models/s2transformer.py index 5c232c45..8c492de7 100644 --- a/torch_harmonics/examples/models/s2transformer.py +++ b/torch_harmonics/examples/models/s2transformer.py @@ -35,11 +35,10 @@ import torch.nn as nn import torch.amp as amp -from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 +from torch_harmonics import DiscreteContinuousConvS2 from torch_harmonics import NeighborhoodAttentionS2, AttentionS2 from torch_harmonics import ResampleS2 from torch_harmonics import RealSHT, InverseRealSHT -from torch_harmonics.quadrature import _precompute_latitudes from torch_harmonics.examples.models._layers import MLP, DropPath, LayerNorm, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding diff --git a/torch_harmonics/examples/models/s2unet.py b/torch_harmonics/examples/models/s2unet.py index 27d99f7b..b8ad39e8 100644 --- a/torch_harmonics/examples/models/s2unet.py +++ b/torch_harmonics/examples/models/s2unet.py @@ -36,12 +36,9 @@ import torch.amp as amp from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 -from torch_harmonics import NeighborhoodAttentionS2 from torch_harmonics import ResampleS2 -from torch_harmonics import RealSHT, InverseRealSHT -from torch_harmonics.quadrature import _precompute_latitudes -from torch_harmonics.examples.models._layers import MLP, DropPath +from torch_harmonics.examples.models._layers import DropPath from functools import partial diff --git a/torch_harmonics/examples/models/sfno.py b/torch_harmonics/examples/models/sfno.py index 72ed11fe..506076cc 100644 --- a/torch_harmonics/examples/models/sfno.py +++ b/torch_harmonics/examples/models/sfno.py @@ -37,8 +37,6 @@ from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding, DropPath -from functools import partial - class SphericalFourierNeuralOperatorBlock(nn.Module): """ diff --git a/torch_harmonics/examples/pde_sphere.py b/torch_harmonics/examples/pde_sphere.py index 18468ea0..1b84ece4 100644 --- a/torch_harmonics/examples/pde_sphere.py +++ b/torch_harmonics/examples/pde_sphere.py @@ -33,7 +33,7 @@ import torch import torch.nn as nn import torch_harmonics as th -from torch_harmonics.quadrature import _precompute_longitudes +from torch_harmonics.quadrature import precompute_longitudes import math import numpy as np @@ -97,7 +97,7 @@ def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", rad # apply cosine transform and flip them lats = -torch.arcsin(cost) - lons = _precompute_longitudes(self.nlon) + lons = precompute_longitudes(self.nlon) self.lmax = self.sht.lmax self.mmax = self.sht.mmax diff --git a/torch_harmonics/examples/shallow_water_equations.py b/torch_harmonics/examples/shallow_water_equations.py index 73962d3b..f2c36a1d 100644 --- a/torch_harmonics/examples/shallow_water_equations.py +++ b/torch_harmonics/examples/shallow_water_equations.py @@ -33,7 +33,7 @@ import torch import torch.nn as nn import torch_harmonics as th -from torch_harmonics.quadrature import _precompute_longitudes +from torch_harmonics.quadrature import precompute_longitudes import math import numpy as np @@ -112,7 +112,7 @@ def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", rad # apply cosine transform and flip them lats = -torch.arcsin(cost) - lons = _precompute_longitudes(self.nlon) + lons = precompute_longitudes(self.nlon) self.lmax = self.sht.lmax self.mmax = self.sht.mmax diff --git a/torch_harmonics/examples/stanford_2d3ds_dataset.py b/torch_harmonics/examples/stanford_2d3ds_dataset.py index 43f0f683..811d42ff 100644 --- a/torch_harmonics/examples/stanford_2d3ds_dataset.py +++ b/torch_harmonics/examples/stanford_2d3ds_dataset.py @@ -33,11 +33,11 @@ import math import torch -from torch.utils.data import Dataset, DataLoader, Subset +from torch.utils.data import Dataset, Subset import numpy as np -from torch_harmonics.quadrature import _precompute_latitudes +from torch_harmonics.quadrature import precompute_latitudes from torch_harmonics.examples.losses import get_quadrature_weights # some specifiers where to find the dataset @@ -304,7 +304,7 @@ def convert_dataset( # prepare computation of the class histogram class_histogram = np.zeros(num_classes) - _, quad_weights = _precompute_latitudes(nlat=img_shape[0], grid="equiangular") + _, quad_weights = precompute_latitudes(nlat=img_shape[0], grid="equiangular") quad_weights = quad_weights.reshape(-1, 1) * 2 * torch.pi / float(img_shape[1]) quad_weights = quad_weights.tile(1, img_shape[1]) quad_weights /= torch.sum(quad_weights) diff --git a/torch_harmonics/quadrature.py b/torch_harmonics/quadrature.py index e401b0a9..7824169a 100644 --- a/torch_harmonics/quadrature.py +++ b/torch_harmonics/quadrature.py @@ -35,10 +35,10 @@ import numpy as np import torch -def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0, +def _precompute_quadrature_weights(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0, periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]: """ - Precompute grid points and weights for various quadrature rules. + Precompute grid points and quadrature weights for various quadrature rules. Parameters ----------- @@ -82,17 +82,17 @@ def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[floa return xlg, wlg @lru_cache(typed=True, copy=True) -def _precompute_longitudes(nlon: int): +def precompute_longitudes(nlon: int): lons = torch.linspace(0, 2 * math.pi, nlon+1, dtype=torch.float64, requires_grad=False)[:-1] return lons @lru_cache(typed=True, copy=True) -def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple[torch.Tensor, torch.Tensor]: +def precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple[torch.Tensor, torch.Tensor]: # compute coordinates in the cosine theta domain - xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False) + xlg, wlg = _precompute_quadrature_weights(nlat, grid=grid, a=-1.0, b=1.0, periodic=False) # to perform the quadrature and account for the jacobian of the sphere, the quadrature rule # is formulated in the cosine theta domain, which is designed to integrate functions of cos theta diff --git a/torch_harmonics/resample.py b/torch_harmonics/resample.py index d1b2ffae..160735cd 100644 --- a/torch_harmonics/resample.py +++ b/torch_harmonics/resample.py @@ -15,7 +15,7 @@ # # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. +# this softwrae without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE @@ -29,14 +29,14 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -from typing import List, Tuple, Union, Optional +from typing import Optional import math #import numpy as np import torch import torch.nn as nn -from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes +from torch_harmonics.quadrature import precompute_latitudes, precompute_longitudes class ResampleS2(nn.Module): @@ -90,10 +90,10 @@ def __init__( self.grid_out = grid_out # for upscaling the latitudes we will use interpolation - self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in) - self.lons_in = _precompute_longitudes(nlon_in) - self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out) - self.lons_out = _precompute_longitudes(nlon_out) + self.lats_in, _ = precompute_latitudes(nlat_in, grid=grid_in) + self.lons_in = precompute_longitudes(nlon_in) + self.lats_out, _ = precompute_latitudes(nlat_out, grid=grid_out) + self.lons_out = precompute_longitudes(nlon_out) # in the case where some points lie outside of the range spanned by lats_in, # we need to expand the solution to the poles before interpolating From 3ab5a6d3155b11a94cacbc41ab69bd40f0f3c5d7 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 5 Feb 2026 08:11:11 -0800 Subject: [PATCH 2/2] typo fix --- torch_harmonics/resample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_harmonics/resample.py b/torch_harmonics/resample.py index 160735cd..13133215 100644 --- a/torch_harmonics/resample.py +++ b/torch_harmonics/resample.py @@ -15,7 +15,7 @@ # # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from -# this softwrae without specific prior written permission. +# this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE