Source code for skyborn.calc.geostrophic.core

"""
High-level Python interface for geostrophic wind calculations.

This module provides user-friendly interfaces for calculating geostrophic winds
from geopotential height fields, with support for multi-dimensional data using
the windspharm data preparation utilities.

The interface handles automatic data reshaping, dimension reordering, and
integration with the optimized Fortran backend.
"""

import sys
import warnings
from importlib import import_module
from types import ModuleType
from typing import Optional, Tuple, Union

import numpy as np

from skyborn.windspharm import tools as _windspharm_tools


def _load_geostrophic_backend() -> ModuleType:
    """Load the compiled backend without re-entering the package namespace."""
    qualified_name = f"{__package__}.geostrophicwind"
    try:
        backend = import_module(qualified_name)
    except ImportError:
        legacy_backend = sys.modules.get("geostrophicwind")
        if legacy_backend is None:
            raise
        return legacy_backend

    # Preserve the legacy top-level alias used by existing tests and callers.
    sys.modules.setdefault("geostrophicwind", backend)
    return backend


_geostrophic_module = _load_geostrophic_backend()

z2geouv = _geostrophic_module.z2geouv
z2geouv_3d = _geostrophic_module.z2geouv_3d


def _active_geostrophic_function(name: str):
    """Return the active backend function while tolerating unrelated test shims."""
    backend = sys.modules.get("geostrophicwind")
    if backend is not None and name in vars(backend):
        return getattr(backend, name)
    return getattr(_geostrophic_module, name)


def _is_longitude_cyclic(glon: np.ndarray, tolerance: float = 1.0) -> bool:
    """
    Determine if longitude data is cyclic by checking if it spans 360 degrees.

    Works with different grid resolutions by using adaptive tolerance based
    on grid spacing.

    Parameters
    ----------
    glon : ndarray
        Longitude coordinates in degrees
    tolerance : float
        Base tolerance for cyclicity check (default: 1.0 degrees)

    Returns
    -------
    bool
        True if longitude appears to be cyclic (spans ~360°)
    """
    if len(glon) < 3:  # Need at least 3 points for meaningful cyclicity
        return False

    lon_range = glon[-1] - glon[0]
    dlon = np.mean(np.diff(glon))

    # Adaptive tolerance based on grid spacing
    # For coarse grids (large dlon), use larger tolerance
    adaptive_tolerance = tolerance

    # Check if the range plus one grid spacing is approximately 360°
    expected_range = lon_range + dlon
    is_cyclic_360 = abs(expected_range - 360.0) < adaptive_tolerance
    if is_cyclic_360 and abs(dlon) <= 1.0 and tolerance < abs(dlon):
        is_cyclic_360 = False

    # Also check if range is already ~360° (for grids that include both 0 and 360)
    is_already_360 = abs(lon_range - 360.0) < adaptive_tolerance

    return is_cyclic_360 or is_already_360


def _ensure_south_to_north(
    z: np.ndarray, glat: np.ndarray, dim_order: str
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Ensure latitude dimension is ordered south-to-north as required by Fortran code.

    Parameters
    ----------
    z : ndarray
        Geopotential height data
    glat : ndarray
        Latitude coordinates
    dim_order : str
        Dimension order string

    Returns
    -------
    z_ordered : ndarray
        Data with latitude ordered south-to-north
    glat_ordered : ndarray
        Latitude coordinates ordered south-to-north
    """
    # Check if latitude needs to be reversed (north-to-south -> south-to-north)
    if glat[0] > glat[-1]:  # Currently north-to-south, need to reverse
        glat_ordered = glat[::-1].copy()

        # Find latitude axis in the data
        lat_axis = dim_order.lower().find("y")
        if lat_axis == -1:
            raise ValueError("Latitude dimension 'y' not found in dim_order")

        # Reverse the latitude dimension in the data
        z_ordered = np.flip(z, axis=lat_axis)
    else:
        # Already south-to-north or same latitude
        z_ordered = z
        glat_ordered = glat

    return z_ordered, glat_ordered


[docs] def geostrophic_wind( z: np.ndarray, glon: np.ndarray, glat: np.ndarray, dim_order: str, missing_value: float = -999.0, ) -> Tuple[np.ndarray, np.ndarray]: """ Calculate geostrophic wind components from geopotential height. This function can handle various input shapes by using windspharm's prep_data and recover_data utilities to reshape data for batch processing. Parameters ---------- z : ndarray Geopotential height data [gpm]. Can be 2D, 3D, or 4D. Must contain latitude ('y') and longitude ('x') dimensions. glon : ndarray, shape (nlon,) Longitude coordinates in degrees glat : ndarray, shape (nlat,) Latitude coordinates in degrees (south to north) dim_order : str String specifying dimension order using: - 'x' for longitude - 'y' for latitude - 't' for time - 'z' for level Example: 'tzyx' for (time, level, lat, lon) missing_value : float, optional Missing value identifier (default: -999.0) Returns ------- ug : ndarray Zonal geostrophic wind component [m/s] (same shape as input z) vg : ndarray Meridional geostrophic wind component [m/s] (same shape as input z) Examples -------- # 2D case: single time/level >>> z2d = np.random.randn(73, 144) # (lat, lon) >>> ug, vg = geostrophic_wind(z2d, glon, glat, 'yx') # 3D case: multiple times >>> z3d = np.random.randn(73, 144, 12) # (lat, lon, time) >>> ug, vg = geostrophic_wind(z3d, glon, glat, 'yxt') # 4D case: multiple levels and times >>> z4d = np.random.randn(12, 17, 73, 144) # (time, level, lat, lon) >>> ug, vg = geostrophic_wind(z4d, glon, glat, 'tzyx') # Alternative 4D ordering >>> z4d_alt = np.random.randn(73, 144, 17, 12) # (lat, lon, level, time) >>> ug, vg = geostrophic_wind(z4d_alt, glon, glat, 'yxzt') """ # Auto-detect longitude cyclicity cyclic = _is_longitude_cyclic(glon) iopt = 1 if cyclic else 0 # Ensure latitude is ordered south-to-north as required by Fortran code z, glat = _ensure_south_to_north(z, glat, dim_order) # Handle multi-dimensional data using windspharm tools if len(z.shape) > 2: return _geostrophic_wind_multidim(z, glon, glat, dim_order, missing_value, iopt) else: # Direct 2D geostrophic wind calculation (ug, vg components) return _calc_geostrophic_2d(z, glon, glat, missing_value, iopt)
def _geostrophic_wind_multidim( z: np.ndarray, glon: np.ndarray, glat: np.ndarray, dim_order: str, missing_value: float, iopt: int, ) -> Tuple[np.ndarray, np.ndarray]: """ Internal function for multi-dimensional geostrophic wind calculation. Uses windspharm prep_data/recover_data for dimension handling. """ # Step 1: Prepare data (reshape to (nlat, nlon, combined_other_dims)) prepared_z, recovery_info = _windspharm_tools.prep_data(z, dim_order) nlat, nlon, n_combined = prepared_z.shape # Ensure coordinate arrays are compatible if len(glat) != nlat: raise ValueError( f"Latitude array length ({len(glat)}) doesn't match data ({nlat})" ) if len(glon) != nlon: raise ValueError( f"Longitude array length ({len(glon)}) doesn't match data ({nlon})" ) # Step 2: Use 3D function for all cases (handles any n_combined size) prepared_z = np.asarray(prepared_z, dtype=np.float32) z2geouv_3d_func = _active_geostrophic_function("z2geouv_3d") ug_prepared, vg_prepared = z2geouv_3d_func( z=prepared_z, zmsg=missing_value, glon=glon, glat=glat, iopt=iopt ) # Step 3: Recover original shape and dimension order ug_final = _windspharm_tools.recover_data(ug_prepared, recovery_info) vg_final = _windspharm_tools.recover_data(vg_prepared, recovery_info) return ug_final, vg_final def _calc_geostrophic_2d(z, glon, glat, missing_value, iopt): """Calculate 2D geostrophic wind components using Fortran backend.""" z2geouv_func = _active_geostrophic_function("z2geouv") return z2geouv_func( np.asarray(z, dtype=np.float32), zmsg=missing_value, glon=glon, glat=glat, iopt=iopt, )
[docs] class GeostrophicWind: """ Class-based interface for geostrophic wind calculations. This class provides a high-level interface similar to windspharm's VectorWind, allowing for easy calculation of various geostrophic wind quantities. Parameters ---------- z : ndarray Geopotential height data [gpm] glon : ndarray Longitude coordinates [degrees] glat : ndarray Latitude coordinates [degrees] (south to north) dim_order : str Dimension ordering specification missing_value : float, optional Missing value identifier (default: -999.0) Examples -------- >>> # Create GeostrophicWind instance (longitude cyclicity auto-detected) >>> gw = GeostrophicWind(z, glon, glat, 'tzyx') >>> >>> # Get wind components >>> ug, vg = gw.uv_components() >>> >>> # Calculate derived quantities >>> speed = gw.speed() >>> >>> # Access original data >>> z_orig = gw.geopotential_height """
[docs] def __init__( self, z: np.ndarray, glon: np.ndarray, glat: np.ndarray, dim_order: str, missing_value: float = -999.0, ): self._z_original = np.asarray(z) self._glon = np.asarray(glon) self._glat = np.asarray(glat) self._dim_order = dim_order self._missing_value = missing_value # Calculate winds on initialization self._ug, self._vg = geostrophic_wind( z, glon, glat, dim_order, missing_value=missing_value )
@property def geopotential_height(self) -> np.ndarray: """Original geopotential height data.""" return self._z_original @property def longitude(self) -> np.ndarray: """Longitude coordinates.""" return self._glon @property def latitude(self) -> np.ndarray: """Latitude coordinates.""" return self._glat
[docs] def uv_components(self) -> Tuple[np.ndarray, np.ndarray]: """ Return zonal and meridional wind components. Returns ------- ug : ndarray Zonal (eastward) geostrophic wind component [m/s] vg : ndarray Meridional (northward) geostrophic wind component [m/s] """ return self._ug, self._vg
[docs] def speed(self) -> np.ndarray: """ Calculate geostrophic wind speed. Returns ------- speed : ndarray Geostrophic wind speed [m/s] """ # Handle missing values properly ug_valid = np.where(self._ug == self._missing_value, 0, self._ug) vg_valid = np.where(self._vg == self._missing_value, 0, self._vg) speed = np.sqrt(ug_valid**2 + vg_valid**2) # Restore missing values where either component was missing missing_mask = (self._ug == self._missing_value) | ( self._vg == self._missing_value ) speed = np.where(missing_mask, self._missing_value, speed) return speed
# Convenience functions matching windspharm naming conventions def geostrophic_uv(z, glon, glat, dim_order, **kwargs): """ Calculate geostrophic wind components directly. This function calculates geostrophic wind components (ug, vg) from geopotential height fields. Uses the same implementation as GeostrophicWind class for consistency. Parameters ---------- z : ndarray Geopotential height data [gpm]. Can be 2D, 3D, or 4D. Must contain latitude ('y') and longitude ('x') dimensions. glon : ndarray, shape (nlon,) Longitude coordinates in degrees glat : ndarray, shape (nlat,) Latitude coordinates in degrees (automatically ordered south-to-north) dim_order : str String specifying dimension order using: - 'x' for longitude - 'y' for latitude - 't' for time - 'z' for level Example: 'tzyx' for (time, level, lat, lon) missing_value : float, optional Missing value identifier (default: -999.0) Returns ------- ug : ndarray Zonal geostrophic wind component [m/s] (same shape as input z) vg : ndarray Meridional geostrophic wind component [m/s] (same shape as input z) Notes ----- - Longitude cyclicity is automatically detected - Latitude ordering is automatically ensured to be south-to-north - Uses optimized Fortran backend with SIMD optimization """ gw = GeostrophicWind(z, glon, glat, dim_order, **kwargs) return gw.uv_components() def geostrophic_speed(z, glon, glat, dim_order, **kwargs): """ Calculate geostrophic wind speed directly. This function calculates geostrophic wind speed from geopotential height fields. Uses the GeostrophicWind class internally for consistent results. Parameters ---------- z : ndarray Geopotential height data [gpm]. Can be 2D, 3D, or 4D. Must contain latitude ('y') and longitude ('x') dimensions. glon : ndarray, shape (nlon,) Longitude coordinates in degrees glat : ndarray, shape (nlat,) Latitude coordinates in degrees (automatically ordered south-to-north) dim_order : str String specifying dimension order using: - 'x' for longitude - 'y' for latitude - 't' for time - 'z' for level Example: 'tzyx' for (time, level, lat, lon) missing_value : float, optional Missing value identifier (default: -999.0) Returns ------- speed : ndarray Geostrophic wind speed [m/s] (same shape as input z) Notes ----- - Longitude cyclicity is automatically detected - Latitude ordering is automatically ensured to be south-to-north - Speed calculated as sqrt(ug² + vg²) - Missing values are properly handled """ gw = GeostrophicWind(z, glon, glat, dim_order, **kwargs) return gw.speed()