# Source code for skyborn.calc.geostrophic.xarray

"""
Geostrophic wind calculation for xarray DataArrays.

This module provides functions to calculate geostrophic wind components from
geopotential height fields using xarray DataArrays. It automatically detects
spatial coordinates and preserves coordinate information and metadata throughout
the computation process.

Main Functions:
    geostrophic_wind : Calculate geostrophic wind components for xarray DataArray
    GeostrophicWind : Class-based interface for xarray DataArrays

Examples:
    >>> import xarray as xr
    >>> import numpy as np
    >>> from skyborn.calc.geostrophic.xarray import geostrophic_wind
    >>>
    >>> # Load geopotential height data
    >>> z = xr.open_dataarray('geopotential_500hPa.nc')
    >>> result = geostrophic_wind(z)  # Auto-detects coordinates
    >>> print(result.ug.attrs)  # long_name, units, standard_name, description
"""

from __future__ import annotations

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

__all__ = ["geostrophic_wind", "GeostrophicWind"]

import numpy as np
import xarray as xr

from . import core as interface

# Type aliases
# Short module-level name for ``xr.DataArray``, used in the annotations below.
DataArray = xr.DataArray


def _detect_spatial_dimensions(
    data_array: DataArray,
) -> Tuple[Optional[int], Optional[int]]:
    """
    Auto-detect longitude and latitude dimension indices in xarray DataArray.

    Dimension names are compared case-insensitively against a fixed list of
    longitude and latitude aliases. Each name is ranked per axis and the
    best-ranked dimension is selected, so that an incidental letter match
    (``"expver"`` contains ``"x"``, ``"year"`` contains ``"y"``) can never
    displace a dimension that is really called ``lon`` or ``latitude``.

    Match ranks, strongest first:

    4. the whole name equals an alias (``"lat"``, ``"X"``)
    3. an underscore/hyphen separated token equals an alias (``"grid_x"``)
    2. an alias of three or more characters is a substring (``"XLAT"``)
    1. any alias, including the single letters, is a substring (``"xc"``)

    When two dimensions reach the same rank for one axis, the later one wins.
    A name matching both axes at the same rank is treated as longitude.

    Parameters
    ----------
    data_array : xr.DataArray
        Input data to analyze; only its ``dims`` attribute is read.

    Returns
    -------
    xdim : int
        Longitude dimension index
    ydim : int
        Latitude dimension index

    Raises
    ------
    ValueError
        If the longitude or the latitude dimension cannot be identified
    """
    dims = data_array.dims

    # Aliases are stored lower-cased because names are lower-cased before
    # comparison.
    lon_aliases = frozenset({"lon", "longitude", "x", "xlon", "lons", "long"})
    lat_aliases = frozenset({"lat", "latitude", "y", "ylat", "lats", "lati"})

    def _match_rank(name: str, aliases: frozenset) -> int:
        """Return how strongly the lower-cased *name* matches *aliases*."""
        if name in aliases:
            return 4
        if any(token in aliases for token in name.replace("-", "_").split("_")):
            return 3
        if any(len(alias) >= 3 and alias in name for alias in aliases):
            return 2
        # Weakest tier: plain substring search, which also lets the
        # single-letter aliases match names such as "xc" or "yh".
        if any(alias in name for alias in aliases):
            return 1
        return 0

    xdim = ydim = None
    best_lon = best_lat = 0

    for index, dim_name in enumerate(dims):
        lowered = str(dim_name).lower()
        lon_rank = _match_rank(lowered, lon_aliases)
        lat_rank = _match_rank(lowered, lat_aliases)

        if lon_rank and lon_rank >= lat_rank:
            # ">=" keeps the "later dimension wins" rule on equal ranks.
            if lon_rank >= best_lon:
                xdim, best_lon = index, lon_rank
        elif lat_rank and lat_rank >= best_lat:
            ydim, best_lat = index, lat_rank

    # Both longitude and latitude are required for geostrophic wind calculation
    if xdim is None or ydim is None:
        raise ValueError(
            f"Could not auto-detect both longitude and latitude dimensions. "
            f"Found dims: {dims}. Expected longitude and latitude coordinates."
        )

    return xdim, ydim


def _extract_coordinates(data_array: DataArray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Extract longitude and latitude coordinate arrays from xarray DataArray.

    Coordinate names are compared case-insensitively against a fixed list of
    longitude and latitude aliases. Each name is ranked per axis and the
    best-ranked coordinate is used, so that an auxiliary coordinate whose name
    merely contains ``"x"`` or ``"y"`` (for example ``"expver"`` or
    ``"dayofyear"``) cannot replace the real ``longitude``/``latitude``
    coordinate.

    Match ranks, strongest first:

    4. the whole name equals an alias (``"lat"``, ``"X"``)
    3. an underscore/hyphen separated token equals an alias (``"grid_x"``)
    2. an alias of three or more characters is a substring (``"XLAT"``)
    1. any alias, including the single letters, is a substring (``"xc"``)

    When two coordinates reach the same rank for one axis, the one iterated
    later wins. A name matching both axes at the same rank is treated as
    longitude.

    Parameters
    ----------
    data_array : xr.DataArray
        Input data containing coordinate information; only its ``coords``
        mapping is read.

    Returns
    -------
    glon : np.ndarray
        Values of the selected longitude coordinate
    glat : np.ndarray
        Values of the selected latitude coordinate

    Raises
    ------
    ValueError
        If longitude or latitude coordinates are not found
    """
    coords = data_array.coords

    # Aliases are stored lower-cased because names are lower-cased before
    # comparison.
    lon_aliases = frozenset({"lon", "longitude", "x", "xlon", "lons", "long"})
    lat_aliases = frozenset({"lat", "latitude", "y", "ylat", "lats", "lati"})

    def _match_rank(name: str, aliases: frozenset) -> int:
        """Return how strongly the lower-cased *name* matches *aliases*."""
        if name in aliases:
            return 4
        if any(token in aliases for token in name.replace("-", "_").split("_")):
            return 3
        if any(len(alias) >= 3 and alias in name for alias in aliases):
            return 2
        # Weakest tier: plain substring search, which also lets the
        # single-letter aliases match names such as "xc" or "yh".
        if any(alias in name for alias in aliases):
            return 1
        return 0

    lon_coord = lat_coord = None
    best_lon = best_lat = 0

    for coord_name in coords:
        lowered = str(coord_name).lower()
        lon_rank = _match_rank(lowered, lon_aliases)
        lat_rank = _match_rank(lowered, lat_aliases)

        if lon_rank and lon_rank >= lat_rank:
            # ">=" keeps the "later coordinate wins" rule on equal ranks.
            if lon_rank >= best_lon:
                lon_coord, best_lon = coord_name, lon_rank
        elif lat_rank and lat_rank >= best_lat:
            lat_coord, best_lat = coord_name, lat_rank

    if lon_coord is None or lat_coord is None:
        raise ValueError(
            f"Could not find longitude and latitude coordinates. "
            f"Available coordinates: {list(coords.keys())}"
        )

    # Extract coordinate values
    glon = coords[lon_coord].values
    glat = coords[lat_coord].values

    return glon, glat


def _create_dim_order_string(data_array: DataArray, xdim: int, ydim: int) -> str:
    """
    Create dimension order string for interface functions.

    The positions ``xdim`` and ``ydim`` are labelled ``"x"`` and ``"y"``.
    Every remaining dimension is labelled ``"z"`` when its name looks like a
    vertical level and ``"t"`` otherwise (``"t"`` is also the fallback for
    unrecognized names).

    Names are classified case-insensitively in two passes. The strict pass
    accepts a whole-name match, an underscore/hyphen separated token match, or
    a substring match of an alias with three or more characters. Only if the
    strict pass finds nothing does the legacy pass run, which accepts any
    alias as a substring. Running the strict pass first is what lets
    ``"height"`` and ``"altitude"`` be recognized as levels even though they
    contain the single-letter time alias ``"t"``. Within each pass the time
    aliases are tested before the level aliases.

    Parameters
    ----------
    data_array : xr.DataArray
        Input data array; only its ``dims`` attribute is read.
    xdim : int
        Longitude dimension index
    ydim : int
        Latitude dimension index

    Returns
    -------
    dim_order : str
        Dimension order string (e.g., 'tzyx', 'yx', etc.)
    """
    dims = list(data_array.dims)
    dim_order = [""] * len(dims)

    # Set longitude and latitude
    dim_order[xdim] = "x"
    dim_order[ydim] = "y"

    # Aliases are stored lower-cased because names are lower-cased before
    # comparison.
    time_aliases = frozenset({"time", "t", "year", "month", "yr", "mn", "season"})
    level_aliases = frozenset(
        {
            "level",
            "lev",
            "plev",
            "pressure",
            "pressure_level",
            "z",
            "pres",
            "plevel",
            "height",
            "altitude",
            "isobaric",
        }
    )

    def _matches(name: str, aliases: frozenset, strict: bool) -> bool:
        """Tell whether the lower-cased *name* matches one of *aliases*."""
        if not strict:
            return any(alias in name for alias in aliases)
        if name in aliases:
            return True
        if any(token in aliases for token in name.replace("-", "_").split("_")):
            return True
        # Short aliases ("t", "z", "yr", "mn") are too ambiguous to be
        # trusted as substrings in the strict pass.
        return any(len(alias) >= 3 and alias in name for alias in aliases)

    # Fill in other dimensions
    for index, dim_name in enumerate(dims):
        if dim_order[index]:
            continue  # already assigned to x or y

        lowered = str(dim_name).lower()
        label = "t"  # Default to 't' for unrecognized dimensions
        for strict in (True, False):
            if _matches(lowered, time_aliases, strict):
                label = "t"
                break
            if _matches(lowered, level_aliases, strict):
                label = "z"
                break
        dim_order[index] = label

    return "".join(dim_order)


def geostrophic_wind(
    z: DataArray,
    missing_value: float = -999.0,
    keep_attrs: bool = True,
) -> xr.Dataset:
    """
    Calculate geostrophic wind components for xarray DataArrays.

    The longitude/latitude dimensions and coordinate arrays of ``z`` are
    located by name, the raw values are passed to
    ``skyborn.calc.geostrophic.core.geostrophic_wind`` together with a
    dimension-order string, and the two returned arrays are wrapped in a
    Dataset that carries the coordinates of the input.

    Parameters
    ----------
    z : xarray.DataArray
        Geopotential height data [gpm]. Can be 2D, 3D, or 4D.
        Must contain longitude and latitude dimensions with coordinate
        information.
    missing_value : float, optional
        Missing value identifier forwarded to the core routine and recorded
        in the output attributes (default: -999.0)
    keep_attrs : bool, optional
        Copy the attributes of ``z`` into the Dataset attributes, each key
        prefixed with ``source_geopotential_`` (default: True)

    Returns
    -------
    xarray.Dataset
        Dataset containing geostrophic wind components:
        - 'ug': Zonal geostrophic wind component [m/s]
        - 'vg': Meridional geostrophic wind component [m/s]

        Both variables use the dimensions and coordinates of ``z``.

    Raises
    ------
    TypeError
        If ``z`` is not an ``xarray.DataArray``
    ValueError
        If the longitude/latitude dimensions or coordinates cannot be
        identified by name

    Examples
    --------
    **2D Geopotential Height Analysis:**

    >>> import xarray as xr
    >>> import numpy as np
    >>> from skyborn.calc.geostrophic.xarray import geostrophic_wind
    >>>
    >>> # Load 500 hPa geopotential height
    >>> z = xr.open_dataarray('z500_era5.nc')  # Shape: (lat, lon)
    >>> result = geostrophic_wind(z)
    >>> print(f"Wind components: ug{result.ug.shape}, vg{result.vg.shape}")

    **3D Time Series Analysis:**

    >>> # Multi-time geopotential height data
    >>> z_3d = xr.open_dataarray('z500_timeseries.nc')  # Shape: (time, lat, lon)
    >>> result = geostrophic_wind(z_3d)
    >>> # Result preserves time dimension: (time, lat, lon)
    >>> monthly_mean = result.ug.groupby('time.month').mean()

    **4D Multi-level Analysis:**

    >>> # Multi-level, multi-time data
    >>> z_4d = xr.open_dataarray('z_multilevel.nc')  # Shape: (time, level, lat, lon)
    >>> result = geostrophic_wind(z_4d)
    >>> # Result shape: (time, level, lat, lon)
    >>> surface_winds = result.sel(level=1000)  # 1000 hPa level

    **Simplified Interface (No coordinate specification needed):**

    >>> # Automatic coordinate detection
    >>> result = geostrophic_wind(z_data)
    >>> print(f"Longitude cyclicity auto-detected: {result.attrs['longitude_cyclic']}")
    >>> print(f"Latitude ordering: {result.attrs['latitude_ordering']}")

    Notes
    -----
    - The ``longitude_cyclic`` attribute is the value returned by the core
      module's ``_is_longitude_cyclic`` for the longitude coordinate.
    - The ``latitude_ordering`` attribute is always written as
      ``'south_to_north'``; this wrapper does not reorder anything itself.
      NOTE(review): presumably the core routine guarantees that ordering --
      confirm against ``core.geostrophic_wind``.
    - ``z.values`` is passed to the core routine unchanged; how missing data
      are treated is decided there.

    See Also
    --------
    skyborn.calc.geostrophic.core.geostrophic_wind : Lower-level function for numpy arrays
    GeostrophicWind : Class-based interface for xarray DataArrays
    """
    # Validate input type
    if not isinstance(z, xr.DataArray):
        raise TypeError(f"z must be xarray.DataArray, got {type(z).__name__}")

    # Locate the horizontal axes and describe the layout for the core routine
    xdim, ydim = _detect_spatial_dimensions(z)
    glon, glat = _extract_coordinates(z)
    dim_order = _create_dim_order_string(z, xdim, ydim)

    # Call the core geostrophic calculation function on the raw values
    ug_data, vg_data = interface.geostrophic_wind(
        z.values, glon, glat, dim_order, missing_value=missing_value
    )

    # Reuse every coordinate of the input (dimension, scalar and auxiliary
    # coordinates alike) so the output can be selected/aligned like ``z``.
    ug = xr.DataArray(
        ug_data,
        dims=z.dims,
        coords=z.coords,
        attrs={
            "long_name": "Zonal geostrophic wind component",
            "units": "m s-1",
            "standard_name": "eastward_geostrophic_wind",
            "description": "Zonal (eastward) component of geostrophic wind calculated from geopotential height",
        },
    )

    vg = xr.DataArray(
        vg_data,
        dims=z.dims,
        coords=z.coords,
        attrs={
            "long_name": "Meridional geostrophic wind component",
            "units": "m s-1",
            "standard_name": "northward_geostrophic_wind",
            "description": "Meridional (northward) component of geostrophic wind calculated from geopotential height",
        },
    )

    # Create Dataset
    ds = xr.Dataset({"ug": ug, "vg": vg})

    # Add global attributes
    ds.attrs = {
        "title": "Geostrophic wind calculation results",
        "description": "Geostrophic wind components calculated from geopotential height",
        "longitude_cyclic": interface._is_longitude_cyclic(glon),
        "latitude_ordering": "south_to_north",
        "missing_value": missing_value,
        "method": "Finite difference approximation with geostrophic balance",
        "software": "skyborn atmospheric calculation package",
        "equations": "ug = -(g/f)*dZ/dy, vg = (g/f)*dZ/dx",
    }

    # Preserve original attributes if requested
    if keep_attrs and z.attrs:
        ds.attrs.update({f"source_geopotential_{k}": v for k, v in z.attrs.items()})

    return ds
class GeostrophicWind:
    """
    Class-based geostrophic wind analysis using xarray DataArrays.

    The wind components are computed once, at construction time, by calling
    ``geostrophic_wind``; the methods and properties then expose the stored
    result, the input field and the detected coordinate arrays.

    Parameters
    ----------
    z : xarray.DataArray
        Geopotential height data [gpm]. Must contain longitude and latitude
        dimensions with appropriate coordinate information.
    missing_value : float, optional
        Missing value identifier (default: -999.0)
    keep_attrs : bool, optional
        Preserve input DataArray attributes in output (default: True)

    Attributes
    ----------
    _z_original : xarray.DataArray
        Original geopotential height data
    _result : xarray.Dataset
        Computed geostrophic wind components
    _glon : np.ndarray
        Longitude coordinates
    _glat : np.ndarray
        Latitude coordinates

    Raises
    ------
    TypeError
        If ``z`` is not an ``xarray.DataArray``
    ValueError
        If longitude/latitude dimensions or coordinates cannot be identified

    Examples
    --------
    >>> import xarray as xr
    >>> from skyborn.calc.geostrophic.xarray import GeostrophicWind
    >>>
    >>> # Load geopotential height
    >>> z = xr.open_dataarray('z500.nc')
    >>>
    >>> # Create GeostrophicWind instance
    >>> gw = GeostrophicWind(z)
    >>>
    >>> # Get wind components with their metadata
    >>> ug, vg = gw.uv_components()
    >>> print(ug.attrs)
    >>>
    >>> # Calculate derived quantities
    >>> speed = gw.speed()
    >>> print(f"Max wind speed: {float(speed.max()):.1f} m/s")
    >>>
    >>> # Access original data
    >>> z_orig = gw.geopotential_height
    """

    def __init__(
        self,
        z: DataArray,
        missing_value: float = -999.0,
        keep_attrs: bool = True,
    ):
        """Initialize GeostrophicWind with xarray DataArray."""
        self._z_original = z
        self._missing_value = missing_value
        self._keep_attrs = keep_attrs

        # Compute first: geostrophic_wind checks the input type, so a
        # non-DataArray argument is reported as TypeError rather than as an
        # AttributeError from the coordinate lookup below.
        self._result = geostrophic_wind(
            z, missing_value=missing_value, keep_attrs=keep_attrs
        )

        # Extract coordinates for later use
        self._glon, self._glat = _extract_coordinates(z)

    @property
    def geopotential_height(self) -> DataArray:
        """Original geopotential height data."""
        return self._z_original

    @property
    def longitude(self) -> np.ndarray:
        """Longitude coordinates."""
        return self._glon

    @property
    def latitude(self) -> np.ndarray:
        """Latitude coordinates."""
        return self._glat

    def uv_components(self) -> Tuple[DataArray, DataArray]:
        """
        Return zonal and meridional wind components.

        Returns
        -------
        ug : xarray.DataArray
            Zonal (eastward) geostrophic wind component [m/s]
        vg : xarray.DataArray
            Meridional (northward) geostrophic wind component [m/s]
        """
        return self._result.ug, self._result.vg

    def speed(self) -> DataArray:
        """
        Calculate geostrophic wind speed.

        Points where either component equals the configured missing value
        are set to that missing value in the result.

        Returns
        -------
        speed : xarray.DataArray
            Geostrophic wind speed [m/s]
        """
        ug, vg = self.uv_components()
        fill = self._missing_value

        # Points flagged as missing in either component
        missing_mask = (ug == fill) | (vg == fill)

        # Zero out the flagged points before combining so the sentinel value
        # itself never enters the hypot computation.
        ug_valid = ug.where(ug != fill, 0)
        vg_valid = vg.where(vg != fill, 0)
        speed = np.hypot(ug_valid, vg_valid)

        # Restore missing values where either component was missing
        speed = speed.where(~missing_mask, fill)

        # Set attributes
        speed.attrs = {
            "long_name": "Geostrophic wind speed",
            "units": "m s-1",
            "standard_name": "geostrophic_wind_speed",
            "description": "Speed of geostrophic wind calculated as sqrt(ug² + vg²)",
        }

        return speed