3 changes: 1 addition & 2 deletions examples/segmentation/train.py
@@ -51,8 +51,7 @@
import matplotlib.pyplot as plt

from torch_harmonics.examples import StanfordSegmentationDataset, Stanford2D3DSDownloader, StanfordDatasetSubset, compute_stats_s2
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.examples.losses import DiceLossS2, CrossEntropyLossS2, FocalLossS2
from torch_harmonics.examples.losses import CrossEntropyLossS2
from torch_harmonics.examples.metrics import IntersectionOverUnionS2, AccuracyS2
from torch_harmonics.plotting import plot_sphere, imshow_sphere

12 changes: 6 additions & 6 deletions notebooks/equivariance_test.ipynb
@@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "40e4a61f",
"metadata": {},
"outputs": [],
@@ -35,7 +35,7 @@
"\n",
"\n",
"from torch_harmonics import AttentionS2\n",
"from torch_harmonics.quadrature import _precompute_latitudes\n",
"from torch_harmonics.quadrature import precompute_latitudes\n",
"from torch_harmonics.plotting import plot_sphere"
]
},
@@ -340,7 +340,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "281ebfe8",
"metadata": {},
"outputs": [
@@ -415,7 +415,7 @@
"\n",
" # Assume quadrature weights are given or use Gaussian quadrature as done in torch-harmonics for accuracy\n",
" # quadrature_weights = np.ones((nlat, nlon)) * (4*np.pi) / (nlat*nlon) # Placeholder\n",
" _, quad_weights = _precompute_latitudes(nlat, grid=grid)\n",
" _, quad_weights = precompute_latitudes(nlat, grid=grid)\n",
" quad_weights = 2 * np.pi * quad_weights.reshape(-1, 1) / nlon\n",
"\n",
" # input signal is a spherical harmonic\n",
@@ -460,7 +460,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "16f492be",
"metadata": {},
"outputs": [
@@ -543,7 +543,7 @@
"\n",
" # Assume quadrature weights are given or use Gaussian quadrature as done in torch-harmonics for accuracy\n",
" # quadrature_weights = np.ones((nlat, nlon)) * (4*np.pi) / (nlat*nlon) # Placeholder\n",
" _, quad_weights = _precompute_latitudes(nlat, grid=grid)\n",
" _, quad_weights = precompute_latitudes(nlat, grid=grid)\n",
" quad_weights = 2 * np.pi * quad_weights.reshape(-1, 1) / nlon\n",
"\n",
" # input signal is a spherical harmonic\n",
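The notebook cells above switch from the private `_precompute_latitudes` to the public `precompute_latitudes` and scale the returned latitudinal weights by 2*pi/nlon to obtain per-pixel quadrature weights. A minimal standalone sketch of that pattern (not part of this PR), assuming the renamed function keeps the `(latitudes, weights)` return of the old helper and that the resulting grid of weights integrates a constant to roughly 4*pi, the surface area of the unit sphere:

```python
import torch
from torch_harmonics.quadrature import precompute_latitudes

nlat, nlon = 65, 128
grid = "equiangular"

# colatitudes and latitudinal quadrature weights (assumed return order, as in the old private helper)
lats, wlat = precompute_latitudes(nlat, grid=grid)

# per-pixel weights: the uniform longitudinal spacing contributes a factor of 2*pi/nlon, as in the cells above
quad_weights = 2 * torch.pi * torch.as_tensor(wlat).reshape(-1, 1) / nlon

# sanity check: integrating the constant function 1 over the grid should be close to 4*pi
ones = torch.ones(nlat, nlon)
print((quad_weights * ones).sum())  # expected ~12.566 under the assumptions above
```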
12 changes: 6 additions & 6 deletions notebooks/resample_sphere.ipynb
@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -22,7 +22,7 @@
"import torch.nn as nn\n",
"\n",
"from torch_harmonics import ResampleS2\n",
"from torch_harmonics.quadrature import _precompute_latitudes\n",
"from torch_harmonics.quadrature import precompute_latitudes\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from torch_harmonics.plotting import plot_sphere"
@@ -37,15 +37,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nlat = 257\n",
"nlon = 2*(nlat-1)\n",
"grid = \"equiangular\"\n",
"\n",
"xq, wq = _precompute_latitudes(nlat, grid=grid)"
"xq, wq = precompute_latitudes(nlat, grid=grid)"
]
},
{
@@ -351,7 +351,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "dace",
"language": "python",
"name": "python3"
},
@@ -365,7 +365,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.10"
}
},
"nbformat": 4,
10 changes: 5 additions & 5 deletions tests/test_convolution.py
@@ -39,7 +39,7 @@
from torch.library import opcheck
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2

from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
from torch_harmonics.quadrature import precompute_latitudes, precompute_longitudes
from torch_harmonics.disco import cuda_kernels_is_available, optimized_kernels_is_available

from testutils import disable_tf32, set_seed, compare_tensors
@@ -120,12 +120,12 @@ def _precompute_convolution_tensor_dense(
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape

lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_in, win = precompute_latitudes(nlat_in, grid=grid_in)
lats_out, wout = precompute_latitudes(nlat_out, grid=grid_out)

# compute the phi differences.
lons_in = _precompute_longitudes(nlon_in)
lons_out = _precompute_longitudes(nlon_out)
lons_in = precompute_longitudes(nlon_in)
lons_out = precompute_longitudes(nlon_out)

# the effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
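The dense reference helper above rebuilds the input and output grids from the renamed public functions. A short sketch of that grid construction, assuming `precompute_longitudes(nlon)` returns nlon equispaced longitudes in [0, 2*pi) with the endpoint excluded (consistent with the comment in the convolution hunk further down):

```python
import torch
from torch_harmonics.quadrature import precompute_latitudes, precompute_longitudes

nlat_in, nlon_in = 33, 64
grid_in = "equiangular"

# colatitudes with quadrature weights, and equispaced longitudes (assumed semantics, see above)
lats_in, win = precompute_latitudes(nlat_in, grid=grid_in)
lons_in = precompute_longitudes(nlon_in)

# dense (theta, lambda) coordinate grid, e.g. for evaluating filter supports pointwise
theta, lam = torch.meshgrid(torch.as_tensor(lats_in), torch.as_tensor(lons_in), indexing="ij")
print(theta.shape, lam.shape)  # both (nlat_in, nlon_in)
```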
9 changes: 4 additions & 5 deletions torch_harmonics/attention/attention.py
@@ -29,15 +29,14 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import List, Tuple, Union, Optional
from warnings import warn
from typing import Tuple, Union, Optional

import math

import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.quadrature import precompute_latitudes
from torch_harmonics.disco.convolution import _precompute_convolution_tensor_s2
from torch_harmonics.attention._attention_utils import _neighborhood_s2_attention_torch, _neighborhood_s2_attention_optimized
from torch_harmonics.filter_basis import get_filter_basis
@@ -96,7 +95,7 @@ def __init__(
self.scale = scale

# integration weights
_, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
_, wgl = precompute_latitudes(self.nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * wgl.to(dtype=torch.float32) / self.nlon_in
# we need to tile and flatten them accordingly
quad_weights = torch.tile(quad_weights.reshape(-1, 1), (1, self.nlon_in)).flatten()
@@ -248,7 +247,7 @@ def __init__(
raise ValueError("Error, theta_cutoff has to be positive.")

# integration weights
_, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
_, wgl = precompute_latitudes(self.nlat_in, grid=grid_in)
quad_weights = 2.0 * torch.pi * wgl.to(dtype=torch.float32) / self.nlon_in
self.register_buffer("quad_weights", quad_weights, persistent=False)

13 changes: 5 additions & 8 deletions torch_harmonics/disco/convolution.py
@@ -30,18 +30,15 @@
#

import abc
from typing import List, Tuple, Union, Optional
from warnings import warn
from typing import Tuple, Union, Optional

import math

import torch
import torch.nn as nn

from functools import partial

from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics.quadrature import precompute_latitudes, precompute_longitudes
from ._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from ._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized
from torch_harmonics.filter_basis import FilterBasis, get_filter_basis
@@ -231,12 +228,12 @@ def _precompute_convolution_tensor_s2(
nlat_out, nlon_out = out_shape

# precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_in, win = precompute_latitudes(nlat_in, grid=grid_in)
lats_out, wout = precompute_latitudes(nlat_out, grid=grid_out)

# compute the phi differences
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = _precompute_longitudes(nlon_in)
lons_in = precompute_longitudes(nlon_in)

# compute quadrature weights and merge them into the convolution tensor.
# These quadrature weights integrate to 1 over the sphere.
7 changes: 1 addition & 6 deletions torch_harmonics/distributed/distributed_convolution.py
@@ -29,18 +29,13 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import List, Tuple, Union, Optional
from typing import Tuple, Union, Optional
from itertools import accumulate

import torch
import torch.nn as nn

from functools import partial

from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics.disco._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics.disco._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized
from torch_harmonics.filter_basis import get_filter_basis
from disco_helpers import optimized_kernels_is_available, preprocess_psi
from torch_harmonics.disco.convolution import (
_precompute_convolution_tensor_s2,
10 changes: 5 additions & 5 deletions torch_harmonics/distributed/distributed_resample.py
@@ -35,7 +35,7 @@
import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
from torch_harmonics.quadrature import precompute_latitudes, precompute_longitudes
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import reduce_from_azimuth_region, copy_to_azimuth_region
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
@@ -107,10 +107,10 @@ def __init__(
self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)

# for upscaling the latitudes we will use interpolation
self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
self.lons_in = _precompute_longitudes(nlon_in)
self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
self.lons_out = _precompute_longitudes(nlon_out)
self.lats_in, _ = precompute_latitudes(nlat_in, grid=grid_in)
self.lons_in = precompute_longitudes(nlon_in)
self.lats_out, _ = precompute_latitudes(nlat_out, grid=grid_out)
self.lons_out = precompute_longitudes(nlon_out)

# in the case where some points lie outside of the range spanned by lats_in,
# we need to expand the solution to the poles before interpolating
2 changes: 1 addition & 1 deletion torch_harmonics/distributed/primitives.py
@@ -35,7 +35,7 @@
from torch.amp import custom_fwd, custom_bwd

from .utils import polar_group, azimuth_group, polar_group_size
from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth
from .utils import is_distributed_polar, is_distributed_azimuth

# helper routine to compute uneven splitting in balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
4 changes: 2 additions & 2 deletions torch_harmonics/examples/losses.py
@@ -36,12 +36,12 @@
from typing import Optional
from abc import ABC, abstractmethod

from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.quadrature import precompute_latitudes


def get_quadrature_weights(nlat: int, nlon: int, grid: str, tile: bool = False, normalized: bool = True) -> torch.Tensor:
# area weights
_, q = _precompute_latitudes(nlat=nlat, grid=grid)
_, q = precompute_latitudes(nlat=nlat, grid=grid)
q = q.reshape(-1, 1) * 2 * torch.pi / nlon

# numerical precision can be an issue here, make sure it sums to 1:
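The `get_quadrature_weights` helper above now builds its area weights from the public `precompute_latitudes` and renormalizes them for numerical precision. A hedged usage sketch relying only on the signature shown in the hunk; the exact effect of `tile` and `normalized` is assumed from their names and the normalization comment:

```python
import torch
from torch_harmonics.examples.losses import get_quadrature_weights

# signature as shown in the hunk above; tile/normalized semantics are assumptions
q = get_quadrature_weights(nlat=65, nlon=128, grid="equiangular", tile=True, normalized=True)
print(q.shape, q.sum())  # with normalized=True the weights are expected to sum to ~1
```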
1 change: 0 additions & 1 deletion torch_harmonics/examples/metrics.py
@@ -34,7 +34,6 @@
import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_latitudes
from .losses import get_quadrature_weights


23 changes: 0 additions & 23 deletions torch_harmonics/examples/models/_layers.py
@@ -596,26 +596,3 @@ def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_
else:
raise ValueError(f"Unknown learnable position embedding type {embed_type}")

# class SpiralPositionEmbedding(PositionEmbedding):
# """
# Returns position embeddings on the torus
# """

# def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):

# super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

# with torch.no_grad():

# # alternating custom position embeddings
# lats, _ = _precompute_latitudes(img_shape[0], grid=grid)
# lats = lats.reshape(-1, 1)
# lons = torch.linspace(0, 2 * math.pi, img_shape[1] + 1)[:-1]
# lons = lons.reshape(1, -1)

# # channel index
# k = torch.arange(self.num_chans).reshape(1, -1, 1, 1)
# pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))

# # register tensor
# self.register_buffer("position_embeddings", pos_embed.float())
2 changes: 1 addition & 1 deletion torch_harmonics/examples/models/lsno.py
@@ -36,7 +36,7 @@
import torch.amp as amp

from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import DiscreteContinuousConvS2
from torch_harmonics import ResampleS2

from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
4 changes: 1 addition & 3 deletions torch_harmonics/examples/models/s2segformer.py
@@ -38,10 +38,8 @@
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import AttentionS2, NeighborhoodAttentionS2
from torch_harmonics import ResampleS2
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.quadrature import _precompute_latitudes

from torch_harmonics.examples.models._layers import MLP, LayerNorm, DropPath
from torch_harmonics.examples.models._layers import MLP, DropPath

from functools import partial

3 changes: 1 addition & 2 deletions torch_harmonics/examples/models/s2transformer.py
@@ -35,11 +35,10 @@
import torch.nn as nn
import torch.amp as amp

from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import DiscreteContinuousConvS2
from torch_harmonics import NeighborhoodAttentionS2, AttentionS2
from torch_harmonics import ResampleS2
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.quadrature import _precompute_latitudes

from torch_harmonics.examples.models._layers import MLP, DropPath, LayerNorm, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding

5 changes: 1 addition & 4 deletions torch_harmonics/examples/models/s2unet.py
@@ -36,12 +36,9 @@
import torch.amp as amp

from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import NeighborhoodAttentionS2
from torch_harmonics import ResampleS2
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.quadrature import _precompute_latitudes

from torch_harmonics.examples.models._layers import MLP, DropPath
from torch_harmonics.examples.models._layers import DropPath

from functools import partial

2 changes: 0 additions & 2 deletions torch_harmonics/examples/models/sfno.py
@@ -37,8 +37,6 @@

from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding, DropPath

from functools import partial


class SphericalFourierNeuralOperatorBlock(nn.Module):
"""
4 changes: 2 additions & 2 deletions torch_harmonics/examples/pde_sphere.py
@@ -33,7 +33,7 @@
import torch
import torch.nn as nn
import torch_harmonics as th
from torch_harmonics.quadrature import _precompute_longitudes
from torch_harmonics.quadrature import precompute_longitudes

import math
import numpy as np
@@ -97,7 +97,7 @@ def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", rad

# apply cosine transform and flip them
lats = -torch.arcsin(cost)
lons = _precompute_longitudes(self.nlon)
lons = precompute_longitudes(self.nlon)

self.lmax = self.sht.lmax
self.mmax = self.sht.mmax