Source code for skyborn.windspharm.xarray

"""
Spherical harmonic vector wind computations with xarray interface.

This module provides a VectorWind class that works with xarray DataArrays,
preserving coordinate information and metadata throughout the computation process.
It serves as a high-level interface to the standard VectorWind implementation.

Main Class:
    VectorWind: xarray-aware interface for wind field analysis

Example:
    >>> import xarray as xr
    >>> from skyborn.windspharm.xarray import VectorWind
    >>>
    >>> # Load wind data as xarray DataArrays
    >>> u = xr.open_dataarray('u_wind.nc')
    >>> v = xr.open_dataarray('v_wind.nc')
    >>>
    >>> # Create VectorWind instance
    >>> vw = VectorWind(u, v)
    >>>
    >>> # Compute with preserved metadata
    >>> vorticity = vw.vorticity()
    >>> streamfunction = vw.streamfunction()
"""

from __future__ import annotations

import warnings
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

import numpy as np
import xarray as xr

__all__ = ["VectorWind", "ReducedVectorWind"]

from . import reduced, standard
from ._common import get_apiorder, inspect_gridtype

# Type aliases for better readability
DataArray = xr.DataArray
LegFunc = str  # 'stored' or 'computed'


class VectorWind:
    """
    Vector wind analysis using xarray DataArrays.

    This class provides a high-level interface for spherical harmonic wind
    analysis that preserves xarray coordinate information and metadata. It
    wraps the standard VectorWind implementation while attaching descriptive
    attributes (units, standard_name, long_name) to every result.

    Parameters
    ----------
    u, v : xarray.DataArray
        Zonal and meridional wind components. Must have the same dimensions,
        coordinates, and contain no missing values. Should include latitude
        and longitude dimensions with appropriate coordinate information.
    rsphere : float, default 6.3712e6
        Earth radius in meters for spherical harmonic computations.
    legfunc : {'stored', 'computed'}, default 'stored'
        Legendre function computation method:
        - 'stored': precompute and store (faster, more memory)
        - 'computed': compute on-the-fly (slower, less memory)
    gridtype : {'regular', 'gaussian'}, optional
        Explicit grid type override (case-insensitive). If omitted, the grid
        type is inferred from latitude coordinates.
    precision : {'auto', 'single', 'double'}, default 'auto'
        Forwarded unchanged to the underlying standard VectorWind.

    Attributes
    ----------
    _api : standard.VectorWind
        Underlying standard VectorWind instance
    _reorder : tuple
        Original dimension ordering for output reconstruction
    _ishape : tuple
        Shape of the data after reordering for the standard API
    _coords : list
        Coordinates of the reordered data, one per dimension

    Notes
    -----
    If the input latitude is ordered south-to-north, the data are reversed
    along latitude before being stored, so results are returned with latitude
    ordered north-to-south (with correspondingly reversed coordinates).

    Examples
    --------
    >>> import xarray as xr
    >>> from skyborn.windspharm.xarray import VectorWind
    >>>
    >>> # Load wind components
    >>> u = xr.open_dataarray('u850.nc')
    >>> v = xr.open_dataarray('v850.nc')
    >>>
    >>> # Create VectorWind instance
    >>> vw = VectorWind(u, v)
    >>>
    >>> # Compute vorticity with preserved metadata
    >>> vorticity = vw.vorticity()
    >>> print(vorticity.attrs)
    >>>
    >>> # Helmholtz decomposition
    >>> u_chi, v_chi, u_psi, v_psi = vw.helmholtz()
    """

    def __init__(
        self,
        u: DataArray,
        v: DataArray,
        rsphere: float = 6.3712e6,
        legfunc: LegFunc = "stored",
        gridtype: Optional[str] = None,
        precision: Literal["auto", "single", "double"] = "auto",
    ) -> None:
        """
        Initialize VectorWind instance with comprehensive validation.

        Raises
        ------
        TypeError
            If ``u`` or ``v`` is not an xarray.DataArray
        ValueError
            If ``u`` and ``v`` have mismatched dimensions or coordinates, or
            if an explicit ``gridtype`` is not 'regular' or 'gaussian'
        """
        # Validate input types
        if not isinstance(u, xr.DataArray):
            raise TypeError(f"u must be xarray.DataArray, got {type(u).__name__}")
        if not isinstance(v, xr.DataArray):
            raise TypeError(f"v must be xarray.DataArray, got {type(v).__name__}")

        # Validate coordinate compatibility
        self._validate_coordinates(u, v)

        # Find and validate latitude/longitude coordinates
        lat, lat_dim = _find_latitude_coordinate(u)
        lon, lon_dim = _find_longitude_coordinate(u)

        # Ensure north-to-south latitude ordering
        if lat.values[0] < lat.values[1]:
            u = _reverse(u, lat_dim)
            v = _reverse(v, lat_dim)
            # Re-read the coordinate so `lat` holds the reversed values.
            lat, lat_dim = _find_latitude_coordinate(u)

        # Determine grid type, unless the caller needs to override detection
        # for known global grids with latitude conventions outside our strict
        # checks.
        if gridtype is None:
            gridtype = inspect_gridtype(lat.values)
        else:
            gridtype = gridtype.lower()
            if gridtype not in ("regular", "gaussian"):
                raise ValueError(
                    f"Invalid grid type: '{gridtype}'. Must be 'regular' or 'gaussian'"
                )

        # Prepare data for standard API
        apiorder, _ = get_apiorder(u.ndim, lat_dim, lon_dim)
        apiorder = [u.dims[i] for i in apiorder]

        # Store original structure for output reconstruction
        self._reorder = u.dims

        # Reorder dimensions and prepare data
        u = u.transpose(*apiorder)
        v = v.transpose(*apiorder)

        # Store shape and coordinates for reconstruction
        self._ishape = u.shape
        self._coords = [u.coords[name] for name in u.dims]

        # The standard API now accepts (lat, lon, *extra_dims) directly.
        u_data = u.values
        v_data = v.values
        self._u_component_dtype = standard.VectorWind._infer_output_dtype(u_data)
        self._v_component_dtype = standard.VectorWind._infer_output_dtype(v_data)
        self._api = standard.VectorWind._from_owned_input(
            u_data,
            v_data,
            gridtype=gridtype,
            rsphere=rsphere,
            legfunc=legfunc,
            precision=precision,
        )

    def _validate_coordinates(self, u: DataArray, v: DataArray) -> None:
        """
        Validate that u and v have compatible coordinates.

        Parameters
        ----------
        u, v : DataArray
            Wind components to validate

        Raises
        ------
        ValueError
            If dimensions or coordinate values don't match
        """
        # Check dimension names
        if u.dims != v.dims:
            raise ValueError(
                f"u and v must have identical dimensions. "
                f"Got u: {u.dims}, v: {v.dims}"
            )

        # Check coordinate values. ``np.array_equal`` compares shapes first,
        # so a length-1 coordinate can never "match" a longer one through
        # broadcasting, and a shape mismatch is reported instead of raising.
        mismatched_coords = []
        for dim in u.dims:
            try:
                same = np.array_equal(u.coords[dim].values, v.coords[dim].values)
            except (ValueError, TypeError):
                # Values that cannot be compared are treated as a mismatch
                same = False
            if not same:
                mismatched_coords.append(dim)

        if mismatched_coords:
            raise ValueError(
                f"u and v must have identical coordinate values. "
                f"Mismatched coordinates: {mismatched_coords}"
            )

    def _metadata(self, data: Any, name: str, **attributes: Any) -> DataArray:
        """
        Create DataArray with proper metadata and coordinate information.

        Parameters
        ----------
        data : array_like
            Data to wrap in DataArray
        name : str
            Variable name
        **attributes
            Additional attributes to set

        Returns
        -------
        DataArray
            DataArray carrying the stored coordinates, transposed back to the
            stored dimension order, with ``attributes`` set on ``attrs``
        """
        # Reshape to the API-ordered structure recorded at construction
        data = np.asarray(data).reshape(self._ishape)

        # Create DataArray with coordinates
        result = xr.DataArray(data, coords=self._coords, name=name)

        # Restore original dimension order
        result = result.transpose(*self._reorder)

        # Set attributes
        result.attrs.update(attributes)
        return result

    def _component_metadata(
        self, data: Any, dtype: np.dtype, name: str, **attributes: Any
    ) -> DataArray:
        """Wrap stored wind components using the public component dtype."""
        restored = self._api._restore_output_dtype(data, dtype)
        return self._metadata(restored, name, **attributes)

    @staticmethod
    def _to_api_layout(field: DataArray) -> Tuple[DataArray, Tuple[Any, ...]]:
        """
        Orient a scalar field north-to-south and reorder it for the API.

        Parameters
        ----------
        field : DataArray
            Field with recognisable latitude and longitude dimensions

        Returns
        -------
        prepared : DataArray
            Field with latitude north-to-south and dimensions in API order
        dims : tuple
            Dimension order of ``field`` before transposition, used to
            restore the layout of results
        """
        lat, lat_dim = _find_latitude_coordinate(field)
        _, lon_dim = _find_longitude_coordinate(field)

        # Ensure north-to-south latitude ordering (dimension index unchanged)
        if lat.values[0] < lat.values[1]:
            field = _reverse(field, lat_dim)

        order, _ = get_apiorder(field.ndim, lat_dim, lon_dim)
        dims = field.dims
        return field.transpose(*[dims[i] for i in order]), dims

    def u(self) -> DataArray:
        """
        Get zonal component of vector wind.

        Returns
        -------
        DataArray
            Zonal (eastward) wind component with descriptive attributes

        Examples
        --------
        >>> u_wind = vw.u()
        >>> print(u_wind.attrs['standard_name'])  # 'eastward_wind'
        """
        return self._component_metadata(
            self._api.u,
            self._u_component_dtype,
            "u",
            units="m s**-1",
            standard_name="eastward_wind",
            long_name="eastward_component_of_wind",
        )

    def v(self) -> DataArray:
        """
        Get meridional component of vector wind.

        Returns
        -------
        DataArray
            Meridional (northward) wind component with descriptive attributes

        Examples
        --------
        >>> v_wind = vw.v()
        >>> print(v_wind.attrs['standard_name'])  # 'northward_wind'
        """
        return self._component_metadata(
            self._api.v,
            self._v_component_dtype,
            "v",
            units="m s**-1",
            standard_name="northward_wind",
            long_name="northward_component_of_wind",
        )

    def magnitude(self) -> DataArray:
        """
        Calculate wind speed (magnitude of vector wind).

        Returns
        -------
        DataArray
            Wind speed with descriptive attributes

        Examples
        --------
        >>> wind_speed = vw.magnitude()
        >>> print(wind_speed.attrs['standard_name'])  # 'wind_speed'
        """
        magnitude = self._api.magnitude()
        return self._metadata(
            magnitude,
            "speed",
            units="m s**-1",
            standard_name="wind_speed",
            long_name="wind_speed",
        )

    def vrtdiv(self, truncation: Optional[int] = None) -> Tuple[DataArray, DataArray]:
        """
        Calculate relative vorticity and horizontal divergence.

        Parameters
        ----------
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        vorticity : DataArray
            Relative vorticity with descriptive attributes
        divergence : DataArray
            Horizontal divergence with descriptive attributes

        See Also
        --------
        vorticity : Calculate only vorticity
        divergence : Calculate only divergence

        Examples
        --------
        >>> vrt, div = vw.vrtdiv()
        >>> vrt_t13, div_t13 = vw.vrtdiv(truncation=13)
        """
        vrt, div = self._api.vrtdiv(truncation=truncation)
        vrt_da = self._metadata(
            vrt,
            "vorticity",
            units="s**-1",
            standard_name="atmosphere_relative_vorticity",
            long_name="relative_vorticity",
        )
        div_da = self._metadata(
            div,
            "divergence",
            units="s**-1",
            standard_name="divergence_of_wind",
            long_name="horizontal_divergence",
        )
        return vrt_da, div_da

    def vorticity(self, truncation: Optional[int] = None) -> DataArray:
        """
        Calculate relative vorticity.

        Parameters
        ----------
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        DataArray
            Relative vorticity field with descriptive attributes

        See Also
        --------
        vrtdiv : Calculate both vorticity and divergence
        absolutevorticity : Calculate absolute vorticity

        Examples
        --------
        >>> vrt = vw.vorticity()
        >>> vrt_t13 = vw.vorticity(truncation=13)
        """
        vrt = self._api.vorticity(truncation=truncation)
        return self._metadata(
            vrt,
            "vorticity",
            units="s**-1",
            standard_name="atmosphere_relative_vorticity",
            long_name="relative_vorticity",
        )

    def divergence(self, truncation: Optional[int] = None) -> DataArray:
        """
        Calculate horizontal divergence.

        Parameters
        ----------
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        DataArray
            Horizontal divergence field with descriptive attributes

        See Also
        --------
        vrtdiv : Calculate both vorticity and divergence

        Examples
        --------
        >>> div = vw.divergence()
        >>> div_t13 = vw.divergence(truncation=13)
        """
        div = self._api.divergence(truncation=truncation)
        return self._metadata(
            div,
            "divergence",
            units="s**-1",
            standard_name="divergence_of_wind",
            long_name="horizontal_divergence",
        )

    def planetaryvorticity(self, omega: Optional[float] = None) -> DataArray:
        """
        Calculate planetary vorticity (Coriolis parameter).

        Parameters
        ----------
        omega : float, optional
            Earth's angular velocity in rad/s. Default is 7.292e-5 s**-1

        Returns
        -------
        DataArray
            Planetary vorticity (Coriolis parameter) with descriptive
            attributes

        See Also
        --------
        absolutevorticity : Calculate absolute vorticity

        Examples
        --------
        >>> f = vw.planetaryvorticity()
        >>> f_custom = vw.planetaryvorticity(omega=7.2921150e-5)
        """
        f = self._api.planetaryvorticity(omega=omega)
        return self._metadata(
            f,
            "coriolis",
            units="s**-1",
            standard_name="coriolis_parameter",
            long_name="planetary_vorticity",
        )

    def absolutevorticity(
        self, omega: Optional[float] = None, truncation: Optional[int] = None
    ) -> DataArray:
        """
        Calculate absolute vorticity (relative plus planetary vorticity).

        Parameters
        ----------
        omega : float, optional
            Earth's angular velocity in rad/s. Default is 7.292e-5 s**-1
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        DataArray
            Absolute vorticity field with descriptive attributes

        See Also
        --------
        vorticity : Calculate relative vorticity
        planetaryvorticity : Calculate planetary vorticity

        Examples
        --------
        >>> avrt = vw.absolutevorticity()
        >>> avrt_t13 = vw.absolutevorticity(omega=7.2921150e-5, truncation=13)
        """
        avrt = self._api.absolutevorticity(omega=omega, truncation=truncation)
        return self._metadata(
            avrt,
            "absolute_vorticity",
            units="s**-1",
            standard_name="atmosphere_absolute_vorticity",
            long_name="absolute_vorticity",
        )

    def sfvp(self, truncation: Optional[int] = None) -> Tuple[DataArray, DataArray]:
        """
        Calculate streamfunction and velocity potential.

        Parameters
        ----------
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        streamfunction : DataArray
            Streamfunction field with descriptive attributes
        velocity_potential : DataArray
            Velocity potential field with descriptive attributes

        See Also
        --------
        streamfunction : Calculate only streamfunction
        velocitypotential : Calculate only velocity potential

        Examples
        --------
        >>> psi, chi = vw.sfvp()
        >>> psi_t13, chi_t13 = vw.sfvp(truncation=13)
        """
        sf, vp = self._api.sfvp(truncation=truncation)
        sf_da = self._metadata(
            sf,
            "streamfunction",
            units="m**2 s**-1",
            standard_name="atmosphere_horizontal_streamfunction",
            long_name="streamfunction",
        )
        vp_da = self._metadata(
            vp,
            "velocity_potential",
            units="m**2 s**-1",
            standard_name="atmosphere_horizontal_velocity_potential",
            long_name="velocity potential",
        )
        return sf_da, vp_da

    def streamfunction(self, truncation: Optional[int] = None) -> DataArray:
        """
        Calculate streamfunction.

        Parameters
        ----------
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        DataArray
            Streamfunction field with descriptive attributes

        See Also
        --------
        sfvp : Calculate both streamfunction and velocity potential

        Examples
        --------
        >>> sf = vw.streamfunction()
        >>> sf_t13 = vw.streamfunction(truncation=13)
        """
        sf = self._api.streamfunction(truncation=truncation)
        return self._metadata(
            sf,
            "streamfunction",
            units="m**2 s**-1",
            standard_name="atmosphere_horizontal_streamfunction",
            long_name="streamfunction",
        )

    def velocitypotential(self, truncation: Optional[int] = None) -> DataArray:
        """
        Calculate velocity potential.

        Parameters
        ----------
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        DataArray
            Velocity potential field with descriptive attributes

        See Also
        --------
        sfvp : Calculate both streamfunction and velocity potential

        Examples
        --------
        >>> vp = vw.velocitypotential()
        >>> vp_t13 = vw.velocitypotential(truncation=13)
        """
        vp = self._api.velocitypotential(truncation=truncation)
        return self._metadata(
            vp,
            "velocity_potential",
            units="m**2 s**-1",
            standard_name="atmosphere_horizontal_velocity_potential",
            long_name="velocity potential",
        )

    def helmholtz(
        self, truncation: Optional[int] = None
    ) -> Tuple[DataArray, DataArray, DataArray, DataArray]:
        """
        Calculate irrotational and non-divergent components of the wind.

        Parameters
        ----------
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        uchi, vchi : DataArray
            Zonal and meridional components of the irrotational wind
        upsi, vpsi : DataArray
            Zonal and meridional components of the non-divergent wind

        See Also
        --------
        irrotationalcomponent : Calculate only the irrotational component
        nondivergentcomponent : Calculate only the non-divergent component

        Examples
        --------
        >>> uchi, vchi, upsi, vpsi = vw.helmholtz()
        >>> uchi13, vchi13, upsi13, vpsi13 = vw.helmholtz(truncation=13)
        """
        uchi, vchi, upsi, vpsi = self._api.helmholtz(truncation=truncation)
        uchi = self._metadata(
            uchi, "u_chi", units="m s**-1", long_name="irrotational_eastward_wind"
        )
        vchi = self._metadata(
            vchi, "v_chi", units="m s**-1", long_name="irrotational_northward_wind"
        )
        upsi = self._metadata(
            upsi, "u_psi", units="m s**-1", long_name="non_divergent_eastward_wind"
        )
        vpsi = self._metadata(
            vpsi, "v_psi", units="m s**-1", long_name="non_divergent_northward_wind"
        )
        return uchi, vchi, upsi, vpsi

    def irrotationalcomponent(
        self, truncation: Optional[int] = None
    ) -> Tuple[DataArray, DataArray]:
        """
        Calculate the irrotational (divergent) component of the wind.

        If both the irrotational and non-divergent components are required,
        use `helmholtz` instead.

        Parameters
        ----------
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        uchi, vchi : DataArray
            Zonal and meridional components of the irrotational wind

        See Also
        --------
        helmholtz : Calculate both wind components at once

        Examples
        --------
        >>> uchi, vchi = vw.irrotationalcomponent()
        >>> uchi_t13, vchi_t13 = vw.irrotationalcomponent(truncation=13)
        """
        uchi, vchi = self._api.irrotationalcomponent(truncation=truncation)
        uchi = self._metadata(
            uchi, "u_chi", units="m s**-1", long_name="irrotational_eastward_wind"
        )
        vchi = self._metadata(
            vchi, "v_chi", units="m s**-1", long_name="irrotational_northward_wind"
        )
        return uchi, vchi

    def nondivergentcomponent(
        self, truncation: Optional[int] = None
    ) -> Tuple[DataArray, DataArray]:
        """
        Calculate the non-divergent (rotational) component of the wind.

        If both the non-divergent and irrotational components are required,
        use `helmholtz` instead.

        Parameters
        ----------
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        upsi, vpsi : DataArray
            Zonal and meridional components of the non-divergent wind

        See Also
        --------
        helmholtz : Calculate both wind components at once

        Examples
        --------
        >>> upsi, vpsi = vw.nondivergentcomponent()
        >>> upsi_t13, vpsi_t13 = vw.nondivergentcomponent(truncation=13)
        """
        upsi, vpsi = self._api.nondivergentcomponent(truncation=truncation)
        upsi = self._metadata(
            upsi, "u_psi", units="m s**-1", long_name="non_divergent_eastward_wind"
        )
        vpsi = self._metadata(
            vpsi, "v_psi", units="m s**-1", long_name="non_divergent_northward_wind"
        )
        return upsi, vpsi

    def gradient(
        self, chi: DataArray, truncation: Optional[int] = None
    ) -> Tuple[DataArray, DataArray]:
        """
        Calculate vector gradient of a scalar field on the sphere.

        Parameters
        ----------
        chi : DataArray
            Scalar field with same latitude/longitude dimensions as wind
            components
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation

        Returns
        -------
        u_gradient : DataArray
            Zonal component of vector gradient
        v_gradient : DataArray
            Meridional component of vector gradient

        Raises
        ------
        TypeError
            If ``chi`` is not an xarray.DataArray

        Examples
        --------
        >>> abs_vrt = vw.absolutevorticity()
        >>> avrt_u, avrt_v = vw.gradient(abs_vrt)
        >>> avrt_u_t13, avrt_v_t13 = vw.gradient(abs_vrt, truncation=13)
        """
        if not isinstance(chi, xr.DataArray):
            raise TypeError(
                f"Scalar field must be xarray.DataArray, got {type(chi).__name__}"
            )

        name = chi.name or "field"

        # Same orientation/transposition as applied to the winds at init
        chi, reorder = self._to_api_layout(chi)
        ishape = chi.shape
        coords = [chi.coords[n] for n in chi.dims]

        # Compute gradient using standard API
        u_grad, v_grad = self._api.gradient(chi.values, truncation=truncation)

        u_name = f"zonal_gradient_of_{name}"
        v_name = f"meridional_gradient_of_{name}"

        u_da = xr.DataArray(
            u_grad.reshape(ishape),
            coords=coords,
            name=u_name,
            attrs={"long_name": u_name},
        )
        v_da = xr.DataArray(
            v_grad.reshape(ishape),
            coords=coords,
            name=v_name,
            attrs={"long_name": v_name},
        )

        # Restore original dimension order
        return u_da.transpose(*reorder), v_da.transpose(*reorder)

    def rossbywavesource(
        self, truncation: Optional[int] = None, omega: Optional[float] = None
    ) -> DataArray:
        """
        Calculate Rossby wave source.

        The Rossby wave source quantifies the generation of Rossby wave
        activity in the atmosphere through the interaction of divergence with
        absolute vorticity and the advection of absolute vorticity by the
        irrotational wind.

        Parameters
        ----------
        truncation : int, optional
            Triangular truncation limit for spherical harmonic computation.
            If None, uses the default truncation based on grid resolution.
        omega : float, optional
            Earth's angular velocity in rad/s. Default is 7.292e-5 s**-1.

        Returns
        -------
        DataArray
            Rossby wave source field with descriptive attributes

        See Also
        --------
        absolutevorticity : Calculate absolute vorticity
        divergence : Calculate horizontal divergence
        irrotationalcomponent : Calculate irrotational wind component
        gradient : Calculate vector gradient

        Notes
        -----
        The Rossby wave source is defined as:

            S = -ζₐ∇·v - v_χ·∇ζₐ

        where:
        - ζₐ is absolute vorticity (relative + planetary)
        - ∇·v is horizontal divergence
        - v_χ is the irrotational (divergent) wind component
        - ∇ζₐ is the gradient of absolute vorticity

        Positive values indicate Rossby wave generation, while negative
        values indicate wave absorption or dissipation.

        Examples
        --------
        >>> rws = vw.rossbywavesource()
        >>> rws_t21 = vw.rossbywavesource(truncation=21)
        >>> rws_custom_omega = vw.rossbywavesource(omega=7.2921150e-5)

        Create a plot of Rossby wave source:

        >>> import matplotlib.pyplot as plt
        >>> import cartopy.crs as ccrs
        >>>
        >>> ax = plt.axes(projection=ccrs.PlateCarree())
        >>> rws.plot.contourf(ax=ax, transform=ccrs.PlateCarree(),
        ...                   levels=20, cmap='RdBu_r')
        >>> ax.coastlines()
        >>> ax.gridlines()
        >>> plt.title('Rossby Wave Source')
        >>> plt.show()

        References
        ----------
        Sardeshmukh, P. D., & Hoskins, B. J. (1988). The generation of global
        rotational flow by steady idealized tropical heating. Journal of the
        Atmospheric Sciences, 45(7), 1228-1251.
        """
        rws = self._api.rossbywavesource(truncation=truncation, omega=omega)
        return self._metadata(
            rws,
            "rossby_wave_source",
            units="s**-2",
            standard_name="rossby_wave_source",
            long_name="rossby_wave_source_term",
            description="Generation term for Rossby wave activity",
        )

    def truncate(self, field: DataArray, truncation: Optional[int] = None) -> DataArray:
        """
        Apply spectral truncation to a scalar field.

        Parameters
        ----------
        field : DataArray
            Scalar field with same latitude/longitude dimensions as wind
            components
        truncation : int, optional
            Triangular truncation limit. If None, defaults to nlat-1

        Returns
        -------
        DataArray
            Field with spectral truncation applied, keeping the input's name
            and attributes

        Raises
        ------
        TypeError
            If ``field`` is not an xarray.DataArray

        Examples
        --------
        >>> field_trunc = vw.truncate(scalar_field)
        >>> field_t21 = vw.truncate(scalar_field, truncation=21)
        """
        if not isinstance(field, xr.DataArray):
            raise TypeError(
                f"Field must be xarray.DataArray, got {type(field).__name__}"
            )

        # Same orientation/transposition as applied to the winds at init
        field, reorder = self._to_api_layout(field)
        ishape = field.shape
        coords = [field.coords[n] for n in field.dims]

        # Apply truncation using standard API
        field_trunc = self._api.truncate(field.values, truncation=truncation)

        # Restore dimension order without coercing the computed result into
        # the input array's dtype.
        result = xr.DataArray(
            field_trunc.reshape(ishape),
            coords=coords,
            name=field.name,
            attrs=field.attrs,
        )
        return result.transpose(*reorder)
class ReducedVectorWind:
    """
    Xarray interface for packed reduced-Gaussian vector wind analysis.

    The first dimension of ``u`` and ``v`` is interpreted as the packed
    reduced-Gaussian point dimension and must have length ``sum(pl)``. Any
    remaining dimensions are carried through unchanged.

    Parameters
    ----------
    u, v : xarray.DataArray
        Packed zonal and meridional wind components with identical
        dimensions and coordinate values.
    pl : array_like
        Reduced-grid description forwarded to the underlying
        ``reduced.ReducedVectorWind``; ``sum(pl)`` must equal the length of
        the packed dimension.
    rsphere : float, default 6.3712e6
        Sphere radius forwarded to the underlying implementation.
    legfunc : {'stored', 'computed'}, default 'stored'
        Legendre function handling forwarded to the underlying
        implementation.
    precision : {'auto', 'single', 'double'}, default 'auto'
        Forwarded unchanged to the underlying implementation.

    Attributes
    ----------
    pl, gridtype, rsphere, legfunc, s
        Mirrors of the same-named attributes of the underlying
        ``reduced.ReducedVectorWind`` instance.
    """

    def __init__(
        self,
        u: DataArray,
        v: DataArray,
        pl: Any,
        rsphere: float = 6.3712e6,
        legfunc: LegFunc = "stored",
        precision: Literal["auto", "single", "double"] = "auto",
    ) -> None:
        """
        Validate the inputs and build the underlying reduced-grid solver.

        Raises
        ------
        TypeError
            If ``u`` or ``v`` is not an xarray.DataArray
        ValueError
            If ``u`` and ``v`` have mismatched dimensions or coordinates
        """
        if not isinstance(u, xr.DataArray):
            raise TypeError(f"u must be xarray.DataArray, got {type(u).__name__}")
        if not isinstance(v, xr.DataArray):
            raise TypeError(f"v must be xarray.DataArray, got {type(v).__name__}")

        self._validate_coordinates(u, v)

        # The packed layout is used as-is, so outputs reuse u's structure.
        self._reorder = u.dims
        self._ishape = u.shape
        self._coords = [u.coords[name] for name in u.dims]

        self._u_component_dtype = reduced.ReducedVectorWind._infer_output_dtype(
            u.values
        )
        self._v_component_dtype = reduced.ReducedVectorWind._infer_output_dtype(
            v.values
        )
        self._api = reduced.ReducedVectorWind(
            u.values,
            v.values,
            pl,
            rsphere=rsphere,
            legfunc=legfunc,
            precision=precision,
        )

        # Expose the solver's configuration on the wrapper.
        self.pl = self._api.pl
        self.gridtype = self._api.gridtype
        self.rsphere = self._api.rsphere
        self.legfunc = self._api.legfunc
        self.s = self._api.s

    def _validate_coordinates(self, u: DataArray, v: DataArray) -> None:
        """
        Validate that u and v have compatible coordinates.

        Raises
        ------
        ValueError
            If dimensions or coordinate values don't match
        """
        if u.dims != v.dims:
            raise ValueError(
                f"u and v must have identical dimensions. "
                f"Got u: {u.dims}, v: {v.dims}"
            )

        # ``np.array_equal`` compares shapes first, so a length-1 coordinate
        # can never "match" a longer one through broadcasting.
        mismatched_coords = []
        for dim in u.dims:
            try:
                same = np.array_equal(u.coords[dim].values, v.coords[dim].values)
            except (ValueError, TypeError):
                # Values that cannot be compared are treated as a mismatch
                same = False
            if not same:
                mismatched_coords.append(dim)

        if mismatched_coords:
            raise ValueError(
                f"u and v must have identical coordinate values. "
                f"Mismatched coordinates: {mismatched_coords}"
            )

    def _metadata(self, data: Any, name: str, **attributes: Any) -> DataArray:
        """Wrap ``data`` with the stored coordinates, name and attributes."""
        result = xr.DataArray(
            np.asarray(data).reshape(self._ishape), coords=self._coords, name=name
        )
        result = result.transpose(*self._reorder)
        result.attrs.update(attributes)
        return result

    def _component_metadata(
        self, data: Any, dtype: np.dtype, name: str, **attributes: Any
    ) -> DataArray:
        """Wrap stored wind components using the public component dtype."""
        restored = self._api._restore_output_dtype(data, dtype)
        return self._metadata(restored, name, **attributes)

    def u(self) -> DataArray:
        """Return the packed zonal wind component."""
        return self._component_metadata(
            self._api.u,
            self._u_component_dtype,
            "u",
            units="m s**-1",
            standard_name="eastward_wind",
            long_name="eastward_component_of_wind",
        )

    def v(self) -> DataArray:
        """Return the packed meridional wind component."""
        return self._component_metadata(
            self._api.v,
            self._v_component_dtype,
            "v",
            units="m s**-1",
            standard_name="northward_wind",
            long_name="northward_component_of_wind",
        )

    def magnitude(self) -> DataArray:
        """Return wind speed on the packed reduced grid."""
        return self._metadata(
            self._api.magnitude(),
            "speed",
            units="m s**-1",
            standard_name="wind_speed",
            long_name="wind_speed",
        )

    def vrtdiv(self, truncation: Optional[int] = None) -> Tuple[DataArray, DataArray]:
        """Return relative vorticity and horizontal divergence."""
        vrt, div = self._api.vrtdiv(truncation=truncation)
        return (
            self._metadata(
                vrt,
                "vorticity",
                units="s**-1",
                standard_name="atmosphere_relative_vorticity",
                long_name="relative_vorticity",
            ),
            self._metadata(
                div,
                "divergence",
                units="s**-1",
                standard_name="divergence_of_wind",
                long_name="horizontal_divergence",
            ),
        )

    def vorticity(self, truncation: Optional[int] = None) -> DataArray:
        """Return relative vorticity."""
        return self._metadata(
            self._api.vorticity(truncation=truncation),
            "vorticity",
            units="s**-1",
            standard_name="atmosphere_relative_vorticity",
            long_name="relative_vorticity",
        )

    def divergence(self, truncation: Optional[int] = None) -> DataArray:
        """Return horizontal divergence."""
        return self._metadata(
            self._api.divergence(truncation=truncation),
            "divergence",
            units="s**-1",
            standard_name="divergence_of_wind",
            long_name="horizontal_divergence",
        )

    def planetaryvorticity(self, omega: Optional[float] = None) -> DataArray:
        """Return planetary vorticity."""
        return self._metadata(
            self._api.planetaryvorticity(omega=omega),
            "coriolis",
            units="s**-1",
            standard_name="coriolis_parameter",
            long_name="planetary_vorticity",
        )

    def absolutevorticity(
        self, omega: Optional[float] = None, truncation: Optional[int] = None
    ) -> DataArray:
        """Return absolute vorticity."""
        return self._metadata(
            self._api.absolutevorticity(omega=omega, truncation=truncation),
            "absolute_vorticity",
            units="s**-1",
            standard_name="atmosphere_absolute_vorticity",
            long_name="absolute_vorticity",
        )

    def sfvp(self, truncation: Optional[int] = None) -> Tuple[DataArray, DataArray]:
        """Return streamfunction and velocity potential."""
        sf, vp = self._api.sfvp(truncation=truncation)
        return (
            self._metadata(
                sf,
                "streamfunction",
                units="m**2 s**-1",
                standard_name="atmosphere_horizontal_streamfunction",
                long_name="streamfunction",
            ),
            self._metadata(
                vp,
                "velocity_potential",
                units="m**2 s**-1",
                standard_name="atmosphere_horizontal_velocity_potential",
                long_name="velocity potential",
            ),
        )

    def streamfunction(self, truncation: Optional[int] = None) -> DataArray:
        """Return streamfunction."""
        return self._metadata(
            self._api.streamfunction(truncation=truncation),
            "streamfunction",
            units="m**2 s**-1",
            standard_name="atmosphere_horizontal_streamfunction",
            long_name="streamfunction",
        )

    def velocitypotential(self, truncation: Optional[int] = None) -> DataArray:
        """Return velocity potential."""
        return self._metadata(
            self._api.velocitypotential(truncation=truncation),
            "velocity_potential",
            units="m**2 s**-1",
            standard_name="atmosphere_horizontal_velocity_potential",
            long_name="velocity potential",
        )

    def helmholtz(
        self, truncation: Optional[int] = None
    ) -> Tuple[DataArray, DataArray, DataArray, DataArray]:
        """Return irrotational and non-divergent wind components."""
        uchi, vchi, upsi, vpsi = self._api.helmholtz(truncation=truncation)
        return (
            self._metadata(
                uchi,
                "u_chi",
                units="m s**-1",
                long_name="irrotational_eastward_wind",
            ),
            self._metadata(
                vchi,
                "v_chi",
                units="m s**-1",
                long_name="irrotational_northward_wind",
            ),
            self._metadata(
                upsi,
                "u_psi",
                units="m s**-1",
                long_name="non_divergent_eastward_wind",
            ),
            self._metadata(
                vpsi,
                "v_psi",
                units="m s**-1",
                long_name="non_divergent_northward_wind",
            ),
        )

    def irrotationalcomponent(
        self, truncation: Optional[int] = None
    ) -> Tuple[DataArray, DataArray]:
        """Return irrotational wind component."""
        uchi, vchi = self._api.irrotationalcomponent(truncation=truncation)
        return (
            self._metadata(
                uchi,
                "u_chi",
                units="m s**-1",
                long_name="irrotational_eastward_wind",
            ),
            self._metadata(
                vchi,
                "v_chi",
                units="m s**-1",
                long_name="irrotational_northward_wind",
            ),
        )

    def nondivergentcomponent(
        self, truncation: Optional[int] = None
    ) -> Tuple[DataArray, DataArray]:
        """Return non-divergent wind component."""
        upsi, vpsi = self._api.nondivergentcomponent(truncation=truncation)
        return (
            self._metadata(
                upsi,
                "u_psi",
                units="m s**-1",
                long_name="non_divergent_eastward_wind",
            ),
            self._metadata(
                vpsi,
                "v_psi",
                units="m s**-1",
                long_name="non_divergent_northward_wind",
            ),
        )

    def gradient(
        self, chi: DataArray, truncation: Optional[int] = None
    ) -> Tuple[DataArray, DataArray]:
        """
        Return vector gradient of a packed scalar field.

        The results are wrapped with the coordinates stored from ``u`` at
        construction, not with the coordinates of ``chi``.

        Raises
        ------
        TypeError
            If ``chi`` is not an xarray.DataArray
        """
        if not isinstance(chi, xr.DataArray):
            raise TypeError(
                f"Scalar field must be xarray.DataArray, got {type(chi).__name__}"
            )

        name = chi.name or "field"
        u_grad, v_grad = self._api.gradient(chi.values, truncation=truncation)
        return (
            self._metadata(
                u_grad,
                f"zonal_gradient_of_{name}",
                long_name=f"zonal_gradient_of_{name}",
            ),
            self._metadata(
                v_grad,
                f"meridional_gradient_of_{name}",
                long_name=f"meridional_gradient_of_{name}",
            ),
        )

    def truncate(self, field: DataArray, truncation: Optional[int] = None) -> DataArray:
        """
        Apply spectral truncation to a packed scalar field.

        The result keeps the name and attributes of ``field`` but is wrapped
        with the coordinates stored from ``u`` at construction.

        Raises
        ------
        TypeError
            If ``field`` is not an xarray.DataArray
        """
        if not isinstance(field, xr.DataArray):
            raise TypeError(
                f"Field must be xarray.DataArray, got {type(field).__name__}"
            )

        truncated = self._api.truncate(field.values, truncation=truncation)
        result = xr.DataArray(
            np.asarray(truncated).reshape(self._ishape),
            coords=self._coords,
            name=field.name,
            attrs=field.attrs,
        )
        return result.transpose(*self._reorder)

    def rossbywavesource(
        self, truncation: Optional[int] = None, omega: Optional[float] = None
    ) -> DataArray:
        """Return Rossby wave source on the packed reduced grid."""
        return self._metadata(
            self._api.rossbywavesource(truncation=truncation, omega=omega),
            "rossby_wave_source",
            units="s**-2",
            standard_name="rossby_wave_source",
            long_name="rossby_wave_source_term",
            description="Generation term for Rossby wave activity",
        )
# Dimension names accepted as latitude / longitude (exact match).
_LATITUDE_NAMES = ("latitude", "lat", "LAT", "LATITUDE", "Y", "y", "LATS", "YLAT")
_LONGITUDE_NAMES = ("longitude", "lon", "LON", "LONGITUDE", "X", "x", "LONS", "XLONG")

# Every unit spelling the CF conventions accept for latitude / longitude.
_LATITUDE_UNITS = frozenset(
    {
        "degrees_north",
        "degree_north",
        "degree_N",
        "degrees_N",
        "degreeN",
        "degreesN",
    }
)
_LONGITUDE_UNITS = frozenset(
    {
        "degrees_east",
        "degree_east",
        "degree_E",
        "degrees_E",
        "degreeE",
        "degreesE",
    }
)


def _reverse(array: DataArray, dim: int) -> DataArray:
    """
    Reverse an xarray DataArray along a given dimension.

    Parameters
    ----------
    array : DataArray
        Input DataArray to reverse
    dim : int
        Dimension index to reverse

    Returns
    -------
    DataArray
        Array with specified dimension reversed
    """
    slicers = [slice(None)] * array.ndim
    slicers[dim] = slice(None, None, -1)
    return array[tuple(slicers)]


def _find_coord_and_dim(
    array: DataArray, predicate: Callable[[Any], bool], name: str
) -> Tuple[Any, int]:
    """
    Find a dimension coordinate in DataArray that satisfies a predicate.

    Only coordinates belonging to the array's own dimensions are examined.

    Parameters
    ----------
    array : DataArray
        Input DataArray to search
    predicate : callable
        Function that returns True for the desired coordinate
    name : str
        Name of coordinate type for error messages

    Returns
    -------
    coord : coordinate
        Found coordinate that satisfies predicate
    dim : int
        Dimension index of the coordinate

    Raises
    ------
    ValueError
        If no coordinate or multiple coordinates found
    """
    candidates = [
        coord for coord in (array.coords[n] for n in array.dims) if predicate(coord)
    ]
    if not candidates:
        raise ValueError(f"Cannot find a {name} coordinate")
    if len(candidates) > 1:
        raise ValueError(f"Multiple {name} coordinates are not allowed")

    coord = candidates[0]
    return coord, array.dims.index(coord.name)


def _find_latitude_coordinate(array: DataArray) -> Tuple[Any, int]:
    """
    Find latitude dimension coordinate in an xarray DataArray.

    A dimension coordinate counts as latitude when its name is one of the
    recognised latitude names, its ``units`` attribute is a CF latitude unit,
    its ``standard_name`` attribute is ``'latitude'``, or its ``axis``
    attribute is ``'Y'``.

    Parameters
    ----------
    array : DataArray
        Input DataArray to search for latitude coordinate

    Returns
    -------
    lat_coord : coordinate
        Latitude coordinate
    lat_dim : int
        Latitude dimension index

    Raises
    ------
    ValueError
        If latitude coordinate cannot be found or multiple found
    """

    def is_latitude(coord: Any) -> bool:
        """Check if coordinate represents latitude."""
        attrs = coord.attrs
        return (
            coord.name in _LATITUDE_NAMES
            or attrs.get("units") in _LATITUDE_UNITS
            or attrs.get("standard_name") == "latitude"
            or attrs.get("axis") == "Y"
        )

    return _find_coord_and_dim(array, is_latitude, "latitude")


def _find_longitude_coordinate(array: DataArray) -> Tuple[Any, int]:
    """
    Find longitude dimension coordinate in an xarray DataArray.

    A dimension coordinate counts as longitude when its name is one of the
    recognised longitude names, its ``units`` attribute is a CF longitude
    unit, its ``standard_name`` attribute is ``'longitude'``, or its ``axis``
    attribute is ``'X'``.

    Parameters
    ----------
    array : DataArray
        Input DataArray to search for longitude coordinate

    Returns
    -------
    lon_coord : coordinate
        Longitude coordinate
    lon_dim : int
        Longitude dimension index

    Raises
    ------
    ValueError
        If longitude coordinate cannot be found or multiple found
    """

    def is_longitude(coord: Any) -> bool:
        """Check if coordinate represents longitude."""
        attrs = coord.attrs
        return (
            coord.name in _LONGITUDE_NAMES
            or attrs.get("units") in _LONGITUDE_UNITS
            or attrs.get("standard_name") == "longitude"
            or attrs.get("axis") == "X"
        )

    return _find_coord_and_dim(array, is_longitude, "longitude")