Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6,191 changes: 5,828 additions & 363 deletions pixi.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ platforms = ["linux-64", "osx-arm64", "osx-64"]
gdal = ">=3.0.0"
netcdf4 = "<1.7.1"
scikit-image = ">=0.24.0,<0.27"
matplotlib = ">=3.9.4,<4"
ipykernel = ">=6.30.1,<8"
scikit-learn = ">=1.6.1,<2"

[tool.pixi.pypi-dependencies]
spectral-util = { path = ".", editable = true }
Expand Down
5 changes: 3 additions & 2 deletions spectral_util/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Centralized CLI interface for SpectralUtil."""

import click
from spectral_util.common import quicklooks
from spectral_util import common
from spectral_util.mosaic import mosaic
from spectral_util.ea_assist import earthaccess_helpers_AV3, earthaccess_helpers_EMIT

Expand All @@ -13,7 +13,8 @@ def cli():
cli.add_command(earthaccess_helpers_AV3.find_download_and_combine, name='av3-download')
cli.add_command(earthaccess_helpers_EMIT.find_download_and_combine_EMIT, name='emit-download')
cli.add_command(mosaic.cli, name='mosaic')
cli.add_command(quicklooks.cli, name='quicklooks')
cli.add_command(common.cli_quicklook, name='quicklooks')
cli.add_command(common.cli_plot, name='plot')

if __name__ == '__main__':
cli()
19 changes: 18 additions & 1 deletion spectral_util/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,18 @@
# common module initialization
# common module initialization

import click
from .quicklooks import ndvi, rgb, nbr
from .plotting import plot_basic_overview

@click.group()
def cli_quicklook():
pass

@click.group()
def cli_plot():
pass

cli_quicklook.add_command(rgb)
cli_quicklook.add_command(nbr)
cli_quicklook.add_command(ndvi)
cli_plot.add_command(plot_basic_overview)
103 changes: 103 additions & 0 deletions spectral_util/common/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#!/usr/bin/env python3
"""Simple plotting routines."""

import click
import numpy as np
from spectral_util.spec_io import load_data
from spectral_util.common.quicklooks import get_rgb
import matplotlib.pyplot as plt
from sklearn.cluster import MiniBatchKMeans

def plot_spectra(input_file, n_points=10, method='even', seed=13):
"""
Plots an RGB image on the left and selected spectra on the right.

Args:
input_file (str): Path to the input spectral file.
n_points (int): Number of spectra to plot.
method (str): 'random' or 'kmeans'.
"""
# Load data
meta, data = load_data(input_file)
rgb = get_rgb(data, meta)

# Select Spectra
spectra_to_plot = []
labels = []

# Remove nodata/NaNs for clustering/sampling if possible
# Simple valid mask:
valid_mask = np.logical_not(np.all(rgb == 0, axis=-1))
if np.sum(valid_mask) == 0:
print("No valid data found.")
return

np.random.seed(seed)
indices = None
if method == 'kmeans':
kmeans = MiniBatchKMeans(n_clusters=n_points, n_init=3, batch_size=1024).fit(np.array(data)[valid_mask,:])
spectra_to_plot = kmeans.cluster_centers_
labels = [f"Cluster {i}" for i in range(n_points)]
elif method == 'random':
# Random
indices = np.where(valid_mask)
perm = np.random.permutation(len(indices[0]))[:n_points]
indices = tuple(idx[perm] for idx in indices)

# inefficient to cast this, but the netcdf subsetting is odd, should re-examine
spectra_to_plot = np.array(data)[indices[0],indices[1],:]
labels = [f"Point {i}" for i in range(n_points)]
elif method == 'even':
indices = [np.linspace(0, data.shape[0], n_points+2, dtype=int)[1:-1],
np.linspace(0, data.shape[0], n_points+2, dtype=int)[1:-1]]
spectra_to_plot = np.array(data)[indices[0],indices[1],:]
labels = [f"Point {i}" for i in range(n_points)]

# Plotting
cmap = plt.get_cmap('Dark2')

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: RGB
axes[0].imshow(rgb)
axes[0].set_title(f"RGB Preview")
if indices is not None:
for _i, (y, x) in enumerate(zip(indices[0], indices[1])):
axes[0].plot(x, y, 'o', markerfacecolor='none', markeredgecolor=cmap(_i % cmap.N),
markersize=10, markeredgewidth=2, label='Selected Points')
axes[0].axis('off')

# Right: Spectra
axes[1].set_xlabel("Wavelength")

wl_nan = meta.wavelengths.copy()
print(spectra_to_plot[0,:])
wl_nan[np.isclose(spectra_to_plot[0,:], -0.01, atol=1e-5)] = np.nan
for i, spec in enumerate(spectra_to_plot):
axes[1].plot(wl_nan, spec, label=labels[i], c=cmap(i % cmap.N))

axes[1].set_title(f"{method.capitalize()} Spectra")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

@click.command()
@click.argument('input_file', type=click.Path(exists=True))
@click.option('--output_file', '-o', default=None, help='Path to save the plot (if not provided, will display instead)')
@click.option('--n_points', '-n', default=5, help='Number of spectra to plot')
@click.option('--method', '-m', type=click.Choice(['random', 'kmeans', 'even']), default='even', help='Selection method')
def plot_basic_overview(input_file, output_file, n_points, method):
"""
Visualizes an image and consistent spectra.
"""
# Click passes None if bands is not provided, handled in function
plot_spectra(input_file, n_points, method)
if output_file:
plt.savefig(output_file, bbox_inches='tight', dpi=300)
click.echo(f"Plot saved to {output_file}")

if __name__ == '__main__':
plot_basic_overview()

125 changes: 73 additions & 52 deletions spectral_util/common/quicklooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,32 @@ def shared_options(f):
f = click.option('--ortho', is_flag=True, help='Orthorectify the output; only relevant if the input format is non-orthod')
return f

def calc_index(data, meta, first_wl, second_wl, first_width=0, second_width=0, nodata=-9999):
"""
Calculate a spectral index.
Args:
data (numpy like): Input data array.
meta (Metadata): Metadata object containing wavelength information.
first_wl (int): Wavelength for the first band [nm].
second_wl (int): Wavelength for the second band [nm].
first_width (int): Width for the first band [nm]; 0 = single wavelength.
second_width (int): Width for the second band [nm]; 0 = single wavelength.
"""
first = data[..., meta.wl_index(first_wl, first_width)]
second = data[..., meta.wl_index(second_wl, second_width)]
if len(first.shape) == 3:
first = np.mean(first, axis=-1)
if len(second.shape) == 3:
second = np.mean(second, axis=-1)

index = (first - second) / (first + second)
index = index.squeeze()
index[first == meta.nodata_value] = nodata
index[np.isfinite(index) == False] = nodata

return index


@click.command()
@common_arguments
@click.option('--red_wl', default=660, help='Red band wavelength [nm]')
Expand All @@ -37,16 +63,8 @@ def ndvi(input_file, output_file, ortho, red_wl, nir_wl, red_width, nir_width):
"""
click.echo(f"Running NDVI Calculation on {input_file}")
meta, rfl = load_data(input_file, lazy=True, load_glt=ortho)

red = rfl[..., meta.wl_index(red_wl, red_width)]
nir = rfl[..., meta.wl_index(nir_wl, nir_width)]

ndvi = (nir - red) / (nir + red)
ndvi = ndvi.squeeze()
ndvi[nir == meta.nodata_value] = -9999
ndvi[np.isfinite(ndvi) == False] = -9999
ndvi = calc_index(rfl, meta, red_wl, nir_wl, red_width, nir_width)
ndvi = ndvi.reshape((ndvi.shape[0], ndvi.shape[1], 1))

write_cog(output_file, ndvi, meta, ortho=ortho)

@click.command()
Expand All @@ -71,50 +89,34 @@ def nbr(input_file, output_file, ortho, nir_wl, swir_wl, nir_width, swir_width):

click.echo(f"Running NBR Calculation on {input_file}")
meta, rfl = load_data(input_file, lazy=True, load_glt=ortho)

nir = rfl[..., meta.wl_index(nir_wl)]
swir = rfl[..., meta.wl_index(swir_wl)]

nbr = (nir - swir) / (swir + nir)
nbr = nbr.squeeze().astype(np.float32)
nbr[nir == meta.nodata_value] = -9999
nbr[np.isfinite(nbr) == False] = -9999
nbr = calc_index(rfl, meta, nir_wl, swir_wl, nir_width, swir_width)
nbr = nbr.reshape((nbr.shape[0], nbr.shape[1], 1))

write_cog(output_file, nbr, meta, ortho=ortho, nodata_value=-9999)

@click.command()
@common_arguments
@click.option('--red_wl', default=650, help='Red band wavelength [nm]')
@click.option('--green_wl', default=560, help='Green band wavelength [nm]')
@click.option('--blue_wl', default=460, help='Blue band width [nm]')
@click.option('--stretch', default=[2,98], nargs=2, type=int, help='stretch the rgb; set to -1 -1 to not stretch')
@click.option('--scale', default=[-1,-1,-1,-1,-1,-1], nargs=6, type=float, help='scale the rgb to these min, max pairs')
def rgb(input_file, output_file, ortho, red_wl, green_wl, blue_wl, stretch, scale):

def get_rgb(rfl, meta, red_wl=650, green_wl=560, blue_wl=460, percentile_stretch=[2,98], scale=[-1,-1,-1,-1,-1,-1]):
"""
Calculate RGB composite.
Get RGB composite from reflectance data.

Args:
input_file (str): Path to the input file.
output_file (str): Path to the output file.
ortho (bool): Orthorectify the output.
rfl (numpy like): Reflectance data array.
meta (Metadata): Metadata object containing wavelength information.
red_wl (int): Red band wavelength [nm].
green_wl (int): Green band wavelength [nm].
blue_wl (int): Blue band wavelength [nm].
stretch [(int), (int)]: Stretch the RGB values to the percentile min & max listed here. Set to -1, -1 to not stretch.
percentile_stretch [(int), (int)]: Stretch the RGB values to the percentile min & max listed here. Set to -1, -1 to not stretch.
scale [(int), (int), (int), (int), (int), (int)]: Scale the RGB values to the min & max listed here. Set to -1s to not scale (default).
"""
if np.all(np.array(scale) != -1) and np.all(np.array(stretch) != -1):
raise ValueError("Cannot set both stretch and scale")

click.echo(f"Running RGB Calculation on {input_file}")
meta, rfl = load_data(input_file, lazy=True, load_glt=ortho)
Returns:
numpy array: RGB composite array.

IF percentile_stretch or scale is used, return is a uint8, otherwise it is the same dtype as the input reflectance data.
"""
rgb = rfl[..., np.array([meta.wl_index(x) for x in [red_wl, green_wl, blue_wl]])]
if stretch[0] != -1 and stretch[1] != -1:
if percentile_stretch[0] != -1 and percentile_stretch[1] != -1:
rgb[rgb == meta.nodata_value] = np.nan
rgb -= np.nanpercentile(rgb, stretch[0], axis=(0, 1))
rgb /= np.nanpercentile(rgb, stretch[1], axis=(0, 1))
rgb -= np.nanpercentile(rgb, percentile_stretch[0], axis=(0, 1))
rgb /= np.nanpercentile(rgb, percentile_stretch[1], axis=(0, 1))
rgb[rgb < 0] = 0
rgb[rgb > 1] = 1
mask = np.isfinite(rgb[...,0]) == False
Expand All @@ -123,7 +125,6 @@ def rgb(input_file, output_file, ortho, red_wl, green_wl, blue_wl, stretch, scal

rgb[rgb == 0] = 1
rgb[mask,:] = 0
nodata_value = 0
elif np.all(np.array(scale) != -1):
mask = rgb[...,0] == meta.nodata_value
rgb[...,0] = (np.clip(rgb[...,0], scale[0], scale[1]) - scale[0]) / (scale[1] - scale[0])
Expand All @@ -134,19 +135,39 @@ def rgb(input_file, output_file, ortho, red_wl, green_wl, blue_wl, stretch, scal
rgb[rgb == 0] = 1
rgb[mask,:] = 0

nodata_value = 0
else:
nodata_value = meta.nodata_value
return rgb

write_cog(output_file, rgb, meta, ortho=ortho, nodata_value=nodata_value)
@click.command()
@common_arguments
@click.option('--red_wl', default=650, help='Red band wavelength [nm]')
@click.option('--green_wl', default=560, help='Green band wavelength [nm]')
@click.option('--blue_wl', default=460, help='Blue band width [nm]')
@click.option('--stretch', default=[2,98], nargs=2, type=int, help='stretch the rgb; set to -1 -1 to not stretch')
@click.option('--scale', default=[-1,-1,-1,-1,-1,-1], nargs=6, type=float, help='scale the rgb to these min, max pairs')
def rgb(input_file, output_file, ortho, red_wl, green_wl, blue_wl, stretch, scale):
"""
Calculate RGB composite.

Args:
input_file (str): Path to the input file.
output_file (str): Path to the output file.
ortho (bool): Orthorectify the output.
red_wl (int): Red band wavelength [nm].
green_wl (int): Green band wavelength [nm].
blue_wl (int): Blue band wavelength [nm].
stretch [(int), (int)]: Stretch the RGB values to the percentile min & max listed here. Set to -1, -1 to not stretch.
scale [(int), (int), (int), (int), (int), (int)]: Scale the RGB values to the min & max listed here. Set to -1s to not scale (default).
"""
if np.all(np.array(scale) != -1) and np.all(np.array(stretch) != -1):
raise ValueError("Cannot set both stretch and scale")

click.echo(f"Running RGB Calculation on {input_file}")
meta, rfl = load_data(input_file, lazy=True, load_glt=ortho)
rgb = get_rgb(rfl, meta, red_wl, green_wl, blue_wl, stretch, scale)

@click.group()
def cli():
pass
nodata_value = meta.nodata_value
if np.all(np.array(stretch) != -1) or np.all(np.array(scale) != -1):
nodata_value = 0

cli.add_command(rgb)
cli.add_command(nbr)
cli.add_command(ndvi)
write_cog(output_file, rgb, meta, ortho=ortho, nodata_value=nodata_value)

if __name__ == '__main__':
cli()
41 changes: 40 additions & 1 deletion spectral_util/spec_io/spec_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,11 +319,15 @@ def open_netcdf(input_file, lazy=True, load_glt=False, load_loc=False, mask_type
- numpy.ndarray or netCDF4.Variable: The data, either as a lazy-loaded variable or a fully loaded numpy array.
"""
input_filename = os.path.basename(input_file)
if 'EMIT' in input_filename and 'RAD' in input_filename:
if 'emit' in input_filename.lower() and ('rad' in input_filename.lower() or 'rdn' in input_filename.lower()):
if return_loc_from_l1b_rad_nc:
return open_loc_l1b_rad_nc(input_file, lazy=lazy, load_glt=load_glt)
else:
return open_emit_rdn(input_file, lazy=lazy, load_glt=load_glt)

if 'emit' in input_filename.lower() and 'rfl' in input_filename.lower():
return open_emit_rfl(input_file, lazy=lazy, load_glt=load_glt)

elif ('emit' in input_filename.lower() and 'obs' in input_filename.lower()):
return open_emit_obs_nc(input_file, lazy=lazy, load_glt=load_glt, load_loc=load_loc)
elif ('emit' in input_filename.lower() and 'l2a_mask' in input_filename.lower()):
Expand All @@ -345,6 +349,41 @@ def open_netcdf(input_file, lazy=True, load_glt=False, load_loc=False, mask_type
raise ValueError(f'Unknown file type for {input_file}')


def open_emit_rfl(input_file, lazy=True, load_glt=False):
"""
Opens an EMIT reflectance NetCDF file and extracts the spectral metadata and reflectance data.

Args:
input_file (str): Path to the NetCDF file.
lazy (bool, optional): If True, loads the reflectance data lazily. Defaults to True.
load_glt (bool, optional): If True, loads the glt for orthoing. Defaults to False.

Returns:
tuple: A tuple containing:
- SpectralMetadata: An object containing the wavelengths and FWHM.
- numpy.ndarray or netCDF4.Variable: The reflectance data, either as a lazy-loaded variable or a fully loaded numpy array.
"""
ds = nc.Dataset(input_file)
wl = ds['sensor_band_parameters']['wavelengths'][:]
fwhm = ds['sensor_band_parameters']['fwhm'][:]
trans = ds.geotransform
proj = ds.spatial_ref
nodata_value = float(ds['reflectance']._FillValue)

if lazy:
rdn = ds['reflectance']
else:
rdn = np.array(ds['reflectance'][:])

glt = None
if load_glt:
glt = np.stack([ds['location']['glt_x'][:],ds['location']['glt_y'][:]],axis=-1)

meta = SpectralMetadata(wl, fwhm, trans, proj, glt, pre_orthod=False, nodata_value=nodata_value)

return meta, rdn


def open_emit_rdn(input_file, lazy=True, load_glt=False):
"""
Opens an EMIT radiance NetCDF file and extracts the spectral metadata and radiance data.
Expand Down
Loading