Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ $ vim polaris/tasks/ocean/my_overflow/init.py

...

from polaris.ocean.model.eos import compute_density
from polaris.ocean.eos import compute_density
from polaris.ocean.vertical import init_vertical_coord


Expand Down
98 changes: 98 additions & 0 deletions polaris/ocean/eos/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import xarray as xr

from polaris.config import PolarisConfigParser

from .linear import compute_linear_density
from .teos10 import compute_specvol as compute_teos10_specvol


def compute_density(
config: PolarisConfigParser,
temperature: xr.DataArray,
salinity: xr.DataArray,
pressure: xr.DataArray | None = None,
) -> xr.DataArray:
"""
Compute the density of seawater based on the equation of state specified
in the configuration.

Parameters
----------
config : polaris.config.PolarisConfigParser
Configuration object containing ocean parameters.

temperature : float or xarray.DataArray
Temperature (conservative, potential or in-situ) of the seawater.

salinity : float or xarray.DataArray
Salinity (practical or absolute) of the seawater.

pressure : float or xarray.DataArray, optional
Pressure (in-situ or reference) of the seawater.

Returns
-------
density : float or xarray.DataArray
Computed density (in-situ or reference) of the seawater.
"""
eos_type = config.get('ocean', 'eos_type')
if eos_type == 'linear':
density = compute_linear_density(config, temperature, salinity)
elif eos_type == 'teos-10':
if pressure is None:
raise ValueError(
'Pressure must be provided when using the TEOS-10 equation of '
'state.'
)
density = 1.0 / compute_teos10_specvol(
sa=salinity, ct=temperature, p=pressure
)
else:
raise ValueError(f'Unsupported equation of state type: {eos_type}')
return density


def compute_specvol(
config: PolarisConfigParser,
temperature: xr.DataArray,
salinity: xr.DataArray,
pressure: xr.DataArray | None = None,
) -> xr.DataArray:
"""
Compute the specific volume of seawater based on the equation of state
specified in the configuration.

Parameters
----------
config : polaris.config.PolarisConfigParser
Configuration object containing ocean parameters.

temperature : float or xarray.DataArray
Temperature (conservative, potential or in-situ) of the seawater.

salinity : float or xarray.DataArray
Salinity (practical or absolute) of the seawater.

pressure : float or xarray.DataArray, optional
Pressure (in-situ or reference) of the seawater.

Returns
-------
specvol : float or xarray.DataArray
Computed specific volume (in-situ or reference) of the seawater.
"""
eos_type = config.get('ocean', 'eos_type')
if eos_type == 'linear':
specvol = 1.0 / compute_linear_density(config, temperature, salinity)
elif eos_type == 'teos-10':
if pressure is None:
raise ValueError(
'Pressure must be provided when using the TEOS-10 equation of '
'state.'
)
specvol = compute_teos10_specvol(
sa=salinity, ct=temperature, p=pressure
)
else:
raise ValueError(f'Unsupported equation of state type: {eos_type}')
return specvol
47 changes: 47 additions & 0 deletions polaris/ocean/eos/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import xarray as xr

from polaris.config import PolarisConfigParser


def compute_linear_density(
config: PolarisConfigParser,
temperature: xr.DataArray,
salinity: xr.DataArray,
) -> xr.DataArray:
"""
Compute the density of seawater based on the the linear equation of state
with coefficients specified in the configuration. The distinction between
conservative, potential, and in-situ temperature or between absolute and
practical salinity is not relevant for the linear EOS.

Parameters
----------
config : polaris.config.PolarisConfigParser
Configuration object containing ocean parameters.

temperature : float or xarray.DataArray
Temperature of the seawater.

salinity : float or xarray.DataArray
Salinity of the seawater.

Returns
-------
density : float or xarray.DataArray
Computed density of the seawater.
"""
section = config['ocean']
alpha = section.getfloat('eos_linear_alpha')
beta = section.getfloat('eos_linear_beta')
rhoref = section.getfloat('eos_linear_rhoref')
Tref = section.getfloat('eos_linear_Tref')
Sref = section.getfloat('eos_linear_Sref')
assert (
alpha is not None
and beta is not None
and rhoref is not None
and Tref is not None
and Sref is not None
), 'All linear EOS parameters must be specified in the config options.'
density = rhoref + -alpha * (temperature - Tref) + beta * (salinity - Sref)
return density
63 changes: 63 additions & 0 deletions polaris/ocean/eos/teos10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import gsw
import xarray as xr


def compute_specvol(
sa: xr.DataArray, ct: xr.DataArray, p: xr.DataArray
) -> xr.DataArray:
"""
Compute specific volume from co-located p, CT and SA.

Notes
-----
- This function converts inputs to NumPy arrays and calls
``gsw.specvol`` directly for performance. Inputs must fit in
memory.

- Any parallelization should be handled by the caller (e.g., splitting
over outer dimensions and calling this function per chunk).

Parameters
----------
sa : xarray.DataArray
Absolute Salinity at the same points as p and ct.

ct : xarray.DataArray
Conservative Temperature at the same points as p and sa.

p : xarray.DataArray
Sea pressure in Pascals (Pa) at the same points as ct and sa.

Returns
-------
xarray.DataArray
Specific volume with the same dims/coords as ct and sa (m^3/kg).
"""

# Check sizes/dims match exactly
if not (p.sizes == ct.sizes == sa.sizes):
raise ValueError(
'p, ct and sa must have identical dimensions and sizes; '
f'got p={p.sizes}, ct={ct.sizes}, sa={sa.sizes}'
)

# Ensure coordinates align identically (names and labels)
p, ct, sa = xr.align(p, ct, sa, join='exact')

# Convert to NumPy and call gsw directly for performance
p_dbar = (p / 1.0e4).to_numpy()
ct_np = ct.to_numpy()
sa_np = sa.to_numpy()
specvol_np = gsw.specvol(sa_np, ct_np, p_dbar)

specvol = xr.DataArray(
specvol_np, dims=ct.dims, coords=ct.coords, name='specvol'
)

return specvol.assign_attrs(
{
'long_name': 'specific volume',
'units': 'm^3 kg^-1',
'standard_name': 'specific_volume',
}
)
16 changes: 0 additions & 16 deletions polaris/ocean/model/eos.py

This file was deleted.

8 changes: 7 additions & 1 deletion polaris/ocean/model/mpaso_to_omega.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ variables:
tracer3: Debug3

# state
layerThickness: LayerThickness
layerThickness: GeometricLayerThickness
normalVelocity: NormalVelocity

# auxiliary state
Expand All @@ -39,6 +39,12 @@ variables:
windStressMeridional: WindStressMeridional
relativeVorticity: RelVortVertex

# vertical coordinate
zMid: ZMid
zInterface: ZInterface
pressure: PressureMid


config:
- section:
time_management: TimeIntegration
Expand Down
103 changes: 67 additions & 36 deletions polaris/ocean/vertical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def init_vertical_coord(config, ds):

if coord_type == 'z-level':
init_z_level_vertical_coord(config, ds)
elif coord_type == 'z-star':
elif coord_type == 'z-star' or coord_type == 'z-tilde':
init_z_star_vertical_coord(config, ds)
elif coord_type == 'sigma':
init_sigma_vertical_coord(config, ds)
Expand All @@ -110,8 +110,11 @@ def init_vertical_coord(config, ds):
dim='Time', axis=0
)

ds['zMid'] = _compute_zmid_from_layer_thickness(
ds.layerThickness, ds.ssh, ds.cellMask
ds['zInterface'], ds['zMid'] = compute_zint_zmid_from_layer_thickness(
layer_thickness=ds.layerThickness,
bottom_depth=ds.bottomDepth,
min_level_cell=ds.minLevelCell,
max_level_cell=ds.maxLevelCell,
)

# fortran 1-based indexing
Expand Down Expand Up @@ -148,7 +151,7 @@ def update_layer_thickness(config, ds):

if coord_type == 'z-level':
update_z_level_layer_thickness(config, ds)
elif coord_type == 'z-star':
elif coord_type == 'z-star' or coord_type == 'z-tilde':
update_z_star_layer_thickness(config, ds)
elif coord_type == 'sigma':
update_sigma_layer_thickness(config, ds)
Expand All @@ -162,47 +165,75 @@ def update_layer_thickness(config, ds):
ds['layerThickness'] = ds.layerThickness.expand_dims(dim='Time', axis=0)


def _compute_cell_mask(minLevelCell, maxLevelCell, nVertLevels):
cellMask = []
for zIndex in range(nVertLevels):
mask = np.logical_and(zIndex >= minLevelCell, zIndex <= maxLevelCell)
cellMask.append(mask)
cellMaskArray = xr.DataArray(cellMask, dims=['nVertLevels', 'nCells'])
cellMaskArray = cellMaskArray.transpose('nCells', 'nVertLevels')
return cellMaskArray


def _compute_zmid_from_layer_thickness(layerThickness, ssh, cellMask):
def compute_zint_zmid_from_layer_thickness(
layer_thickness: xr.DataArray,
bottom_depth: xr.DataArray,
min_level_cell: xr.DataArray,
max_level_cell: xr.DataArray,
) -> tuple[xr.DataArray, xr.DataArray]:
"""
Compute zMid from ssh and layerThickness for any vertical coordinate
Compute height z at layer interfaces and midpoints given layer thicknesses
and bottom depth.

Parameters
----------
layerThickness : xarray.DataArray
The thickness of each layer
layer_thickness : xarray.DataArray
The layer thickness of each layer.

bottom_depth : xarray.DataArray
The positive-down depth of the seafloor.

ssh : xarray.DataArray
The sea surface height
min_level_cell : xarray.DataArray
The zero-based minimum vertical index from each column.

cellMask : xarray.DataArray
A boolean mask of where there are valid cells
max_level_cell : xarray.DataArray
The zero-based maximum vertical index from each column.

Returns
-------
zMid : xarray.DataArray
The elevation of layer centers
z_interface : xarray.DataArray
The elevation of layer interfaces.

z_mid : xarray.DataArray
The elevation of layer midpoints.
"""

zTop = ssh.copy()
nVertLevels = layerThickness.sizes['nVertLevels']
zMid = []
n_vert_levels = layer_thickness.sizes['nVertLevels']

z_bot = -bottom_depth
k = n_vert_levels
mask_bot = np.logical_and(k >= min_level_cell, k - 1 <= max_level_cell)
z_interface_list = [z_bot.where(mask_bot)]
z_mid_list = []

for k in range(n_vert_levels - 1, -1, -1):
dz = layer_thickness.isel(nVertLevels=k)
mask_mid = np.logical_and(k >= min_level_cell, k <= max_level_cell)
mask_top = np.logical_and(k >= min_level_cell, k - 1 <= max_level_cell)
dz = dz.where(mask_mid, 0.0)
z_top = z_bot + dz
z_interface_list.append(z_top.where(mask_top))
z_mid = (z_bot + 0.5 * dz).where(mask_mid)
z_mid_list.append(z_mid)
z_bot = z_top

dims = list(layer_thickness.dims)
interface_dims = list(dims) + ['nVertLevelsP1']
interface_dims.remove('nVertLevels')

z_interface = xr.concat(
reversed(z_interface_list), dim='nVertLevelsP1'
).transpose(*interface_dims)
z_mid = xr.concat(reversed(z_mid_list), dim='nVertLevels').transpose(*dims)

return z_interface, z_mid


def _compute_cell_mask(minLevelCell, maxLevelCell, nVertLevels):
cellMask = []
for zIndex in range(nVertLevels):
mask = cellMask.isel(nVertLevels=zIndex)
thickness = layerThickness.isel(nVertLevels=zIndex).where(mask, 0.0)
z = (zTop - 0.5 * thickness).where(mask)
zMid.append(z)
zTop -= thickness
zMid = xr.concat(zMid, dim='nVertLevels').transpose(
'Time', 'nCells', 'nVertLevels'
)
return zMid
mask = np.logical_and(zIndex >= minLevelCell, zIndex <= maxLevelCell)
cellMask.append(mask)
cellMaskArray = xr.DataArray(cellMask, dims=['nVertLevels', 'nCells'])
cellMaskArray = cellMaskArray.transpose('nCells', 'nVertLevels')
return cellMaskArray
Loading
Loading