# Source code for skyborn.gridfill.xarray

"""
Gridfill procedures for missing value interpolation with xarray interface.

This module provides functions to fill missing values in xarray DataArrays using
iterative relaxation methods to solve Poisson's equation. It preserves coordinate
information and metadata throughout the computation process.

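Main Functions:
    fill : Fill missing values in xarray DataArray
    fill_multiple : Fill several DataArrays using the same parameters
    validate_grid_coverage : Report data coverage and grid regularity before filling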

Examples:
    >>> import xarray as xr
    >>> import numpy as np
    >>> from skyborn.gridfill.xarray import fill
    >>>
    >>> # Load data with missing values
    >>> data = xr.open_dataarray('temperature_with_gaps.nc')
    >>>
    >>> # Fill missing values preserving metadata
    >>> filled_data = fill(data, eps=1e-4)
    >>> print(filled_data.attrs)  # Original attributes preserved
"""

from __future__ import annotations

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

__all__ = ["fill", "fill_multiple", "validate_grid_coverage"]

import numpy as np
import numpy.ma as ma
import xarray as xr

from . import gridfill

# Type alias: lets signatures in this module write ``DataArray`` instead of ``xr.DataArray``
DataArray = xr.DataArray


def _find_spatial_coordinates(
    data: DataArray,
) -> Tuple[str, str, int, int]:
    """
    Find spatial coordinate dimensions in xarray DataArray.

    This function automatically detects latitude/longitude or y/x coordinate
    dimensions in the DataArray using common naming conventions and coordinate
    attributes.

    Parameters
    ----------
    data : xarray.DataArray
        Input DataArray to analyze. Only ``data.dims`` and ``data.coords``
        are read.

    Returns
    -------
    y_name : str
        Name of the y-coordinate dimension (latitude)
    x_name : str
        Name of the x-coordinate dimension (longitude)
    y_dim : int
        Index of y-coordinate dimension in ``data.dims``
    x_dim : int
        Index of x-coordinate dimension in ``data.dims``

    Raises
    ------
    ValueError
        If spatial coordinates cannot be identified

    Notes
    -----
    Detection priority:
    1. Standard names: 'latitude'/'longitude'
    2. Axis attributes: axis='Y'/'X'
    3. Common dimension names: 'lat'/'lon', 'y'/'x'
    4. Unit attributes: 'degrees_north'/'degrees_east'

    Metadata-based stages (1, 2 and 4) only look at coordinates that are
    themselves dimensions of ``data``. A coordinate that is not a dimension
    (for example a 2-D auxiliary latitude on a curvilinear grid) can never be
    returned as a dimension name, so it is skipped instead of shadowing a
    dimension that a later stage could still identify.
    """
    dims = tuple(data.dims)

    # The result is looked up in ``dims``, so only coordinates named after a
    # dimension are valid candidates.
    dim_coords = [
        (name, coord) for name, coord in data.coords.items() if name in dims
    ]

    x_name = None
    y_name = None

    # 1. CF standard_name (a later match overrides an earlier one)
    for name, coord in dim_coords:
        standard_name = getattr(coord, "standard_name", None)
        if standard_name == "latitude":
            y_name = name
        elif standard_name == "longitude":
            x_name = name

    # 2. CF axis attribute (first match wins)
    if x_name is None or y_name is None:
        for name, coord in dim_coords:
            axis = getattr(coord, "axis", None)
            if axis == "Y" and y_name is None:
                y_name = name
            elif axis == "X" and x_name is None:
                x_name = name

    # 3. Common dimension names; str() because dimension names may be any
    #    hashable, not only strings
    if x_name is None or y_name is None:
        for name in dims:
            name_lower = str(name).lower()
            if y_name is None and name_lower in ("lat", "latitude", "y"):
                y_name = name
            elif x_name is None and name_lower in ("lon", "lng", "longitude", "x"):
                x_name = name

    # 4. Units attribute
    if x_name is None or y_name is None:
        for name, coord in dim_coords:
            units = getattr(coord, "units", None)
            if units is None:
                continue
            if y_name is None and "degrees_north" in str(units):
                y_name = name
            elif x_name is None and "degrees_east" in str(units):
                x_name = name

    # Validate that we found both coordinates
    if x_name is None or y_name is None:
        available_dims = list(data.dims)
        available_coords = list(data.coords.keys())
        raise ValueError(
            f"Could not identify spatial coordinates automatically. "
            f"Available dimensions: {available_dims}, "
            f"Available coordinates: {available_coords}. "
            f"Please ensure your data has recognizable latitude/longitude "
            f"coordinates with appropriate metadata (standard_name, axis, or units attributes)."
        )

    # Every candidate came from ``dims``, so these lookups cannot fail.
    return y_name, x_name, dims.index(y_name), dims.index(x_name)


def _detect_cyclic_longitude(lon_coord: xr.DataArray) -> bool:
    """
    Detect if longitude coordinate is cyclic (wraps around).

    Parameters
    ----------
    lon_coord : xarray.DataArray
        Longitude coordinate to analyze. Only its ``values`` and an optional
        ``circular`` attribute are read.

    Returns
    -------
    bool
        True if longitude appears to be cyclic (global coverage)

    Notes
    -----
    Detection criteria, checked in order:
    1. A truthy 'circular' attribute (iris convention)
    2. The span (max - min) is within 10 degrees of 360
    3. Values run from approximately 0 to 360 or -180 to 180
    4. The span plus one grid spacing equals 360 degrees. This catches
       global grids that omit the duplicate end point and are coarser than
       the fixed 10 degree tolerance (e.g. 0, 15, ..., 345).

    An empty coordinate is reported as non-cyclic.
    """
    # Explicit flag takes precedence over any numeric heuristic
    if getattr(lon_coord, "circular", False):
        return True

    lon_vals = np.asarray(lon_coord.values)
    if lon_vals.size == 0:
        # np.min/np.max raise on empty input; nothing can wrap here anyway
        return False

    tolerance = 10.0  # degrees
    lon_min, lon_max = np.min(lon_vals), np.max(lon_vals)
    lon_span = lon_max - lon_min

    # Span covers ~360 degrees
    if np.abs(lon_span - 360.0) < tolerance:
        return True

    # 0 to 360 range
    if np.abs(lon_min) < tolerance and np.abs(lon_max - 360.0) < tolerance:
        return True

    # -180 to 180 range
    if np.abs(lon_min + 180.0) < tolerance and np.abs(lon_max - 180.0) < tolerance:
        return True

    # Spacing-aware check: a global grid without a repeated end point spans
    # exactly 360 minus one grid step, whatever the resolution.
    if lon_vals.size > 1:
        spacing = float(np.median(np.abs(np.diff(lon_vals))))
        if spacing > 0.0 and abs(float(lon_span) + spacing - 360.0) <= 0.1 * spacing:
            return True

    return False


def fill(
    data: DataArray,
    eps: float,
    x_dim: Optional[str] = None,
    y_dim: Optional[str] = None,
    relax: float = 0.6,
    itermax: int = 100,
    initzonal: bool = False,
    initzonal_linear: bool = False,
    cyclic: Optional[bool] = None,
    initial_value: float = 0.0,
    verbose: bool = False,
    keep_attrs: bool = True,
) -> DataArray:
    """
    Fill missing values in xarray DataArray using Poisson equation solver.

    This function fills missing values (NaN or masked values) in gridded data by
    solving Poisson's equation (∇²φ = 0) using an iterative relaxation scheme.
    The method provides smooth interpolation while preserving coordinate
    information and metadata from the input DataArray.

    Parameters
    ----------
    data : xarray.DataArray
        Input DataArray containing data with missing values to fill.
        Missing values can be NaN or masked values.
    eps : float
        Convergence tolerance. Iteration stops when the maximum residual
        falls below this threshold.
    x_dim : str, optional
        Name of the x-coordinate dimension (longitude). If None, will be
        detected automatically using coordinate metadata. May be given on its
        own; only the missing name is auto-detected.
    y_dim : str, optional
        Name of the y-coordinate dimension (latitude). If None, will be
        detected automatically using coordinate metadata. May be given on its
        own; only the missing name is auto-detected.
    relax : float, default 0.6
        Relaxation parameter for the iterative scheme. Must be in range (0, 1).
        Values between 0.45-0.6 typically work well.
    itermax : int, default 100
        Maximum number of iterations.
    initzonal : bool, default False
        Initialization method for missing values:
        - False: Initialize with zeros or initial_value
        - True: Initialize with zonal (x-direction) mean
    initzonal_linear : bool, default False
        Use linear interpolation for zonal initialization:
        - False: Use constant zonal mean (if initzonal=True)
        - True: Use linear interpolation between valid points in each latitude band
        This provides better initial conditions by connecting valid data points
        with linear interpolation rather than using a constant mean value.
        Can be used with both cyclic and non-cyclic data.
    cyclic : bool, optional
        Whether the x-coordinate is cyclic (e.g., longitude wrapping).
        If None, will be detected automatically from the values of the
        x coordinate; a dimension without a coordinate variable is treated
        as non-cyclic.
    initial_value : float, default 0.0
        Initial value to use for missing grid points when initzonal=False.
        This provides a custom starting guess for the iterative solver.
        When initzonal=True, this value may still be used in combination
        with the zonal mean for enhanced initialization.
    verbose : bool, default False
        Print convergence information for each slice.
    keep_attrs : bool, default True
        Preserve input DataArray attributes in output. When False the result
        carries no attributes, including when there was nothing to fill.

    Returns
    -------
    filled_data : xarray.DataArray
        DataArray with missing values filled, preserving coordinates and
        optionally attributes from input. Always a new object; the input is
        never returned directly.

    Raises
    ------
    ValueError
        If spatial coordinates cannot be identified, a given dimension name
        is not a dimension of ``data``, or both names refer to the same
        dimension
    TypeError
        If input is not an xarray DataArray

    Warnings
    --------
    Issues warning if algorithm fails to converge on any slices, or if the
    input contains no missing values.

    Notes
    -----
    The algorithm solves: ∇²φ = (∂²φ/∂x²) + (∂²φ/∂y²) = 0
    using a finite difference relaxation scheme. The method automatically:

    - Detects spatial coordinates using metadata
    - Handles cyclic longitude boundaries for global data
    - Preserves all coordinate information and attributes
    - Works with multi-dimensional data (time series, levels, etc.)

    For missing value detection, both NaN values and xarray/numpy masked
    arrays are supported.

    Examples
    --------
    Basic usage with automatic coordinate detection:

    >>> import xarray as xr
    >>> import numpy as np
    >>> import pandas as pd
    >>> from skyborn.gridfill.xarray import fill
    >>>
    >>> # Load data with missing values
    >>> data = xr.open_dataarray('sst_with_gaps.nc')
    >>>
    >>> # Fill missing values
    >>> filled = fill(data, eps=1e-4)
    >>> print(f"Original shape: {data.shape}")
    >>> print(f"Filled shape: {filled.shape}")

    Advanced usage with explicit parameters:

    >>> # Create test data with gaps
    >>> lons = np.linspace(0, 360, 72, endpoint=False)
    >>> lats = np.linspace(-90, 90, 36)
    >>> time = pd.date_range('2020-01-01', periods=12, freq='MS')
    >>>
    >>> # Create DataArray with metadata
    >>> data = xr.DataArray(
    ...     np.random.rand(12, 36, 72),
    ...     coords={'time': time, 'lat': lats, 'lon': lons},
    ...     dims=['time', 'lat', 'lon'],
    ...     attrs={'units': 'K', 'long_name': 'temperature'}
    ... )
    >>>
    >>> # Add some missing values
    >>> data = data.where(np.random.rand(*data.shape) > 0.1)
    >>>
    >>> # Fill with custom settings
    >>> filled = fill(
    ...     data,
    ...     eps=1e-5,
    ...     relax=0.55,
    ...     initzonal=True,
    ...     verbose=True
    ... )

    Working with specific coordinate dimensions:

    >>> # Explicitly specify coordinate dimensions
    >>> filled = fill(data, eps=1e-4, x_dim='longitude', y_dim='latitude')

    See Also
    --------
    skyborn.gridfill.fill : Lower-level function for numpy arrays
    """
    # Validate input type
    if not isinstance(data, xr.DataArray):
        raise TypeError(f"data must be xarray.DataArray, got {type(data).__name__}")

    # Auto-detect only the names the caller did not supply
    if x_dim is None or y_dim is None:
        detected_y, detected_x, _, _ = _find_spatial_coordinates(data)
        if x_dim is None:
            x_dim = detected_x
        if y_dim is None:
            y_dim = detected_y

    # Validate the final names and derive the axis indices from them, so an
    # explicitly given name is honoured even when the other one was detected.
    if x_dim not in data.dims:
        raise ValueError(
            f"x_dim '{x_dim}' not found in data dimensions: {list(data.dims)}"
        )
    if y_dim not in data.dims:
        raise ValueError(
            f"y_dim '{y_dim}' not found in data dimensions: {list(data.dims)}"
        )
    if x_dim == y_dim:
        raise ValueError(
            f"x_dim and y_dim must be different dimensions, got '{x_dim}' for both"
        )
    x_dim_idx = data.dims.index(x_dim)
    y_dim_idx = data.dims.index(y_dim)

    # Detect cyclic boundary if not specified
    if cyclic is None:
        if x_dim in data.coords:
            cyclic = _detect_cyclic_longitude(data.coords[x_dim])
        else:
            # No coordinate values to inspect for this dimension
            cyclic = False
        if verbose:
            print(f"Auto-detected cyclic={cyclic} for coordinate '{x_dim}'")

    # Convert to masked array for gridfill processing
    data_values = data.values

    if hasattr(data_values, "mask"):
        # Already a masked array
        masked_data = data_values
    else:
        # Create mask from NaN values
        mask = np.isnan(data_values)
        if not np.any(mask):
            # Nothing to fill: return a copy that still honours keep_attrs
            warnings.warn("No missing values found in input data", stacklevel=2)
            unchanged = data.copy()
            if not keep_attrs:
                unchanged.attrs = {}
            return unchanged
        masked_data = ma.array(data_values, mask=mask)

    # Call the core gridfill function
    filled_values, converged = gridfill.fill(
        masked_data,
        xdim=x_dim_idx,
        ydim=y_dim_idx,
        eps=eps,
        relax=relax,
        itermax=itermax,
        initzonal=initzonal,
        initzonal_linear=initzonal_linear,
        cyclic=cyclic,
        initial_value=initial_value,
        verbose=verbose,
    )

    # Check convergence and issue warnings
    not_converged = np.logical_not(np.asarray(converged))
    if np.any(not_converged):
        warnings.warn(
            f"gridfill did not converge on {not_converged.sum()} out of "
            f"{not_converged.size} slices. Consider increasing itermax or "
            f"relaxing eps tolerance.",
            stacklevel=2,
        )

    # Create output DataArray preserving coordinates
    filled_data = xr.DataArray(
        filled_values,
        coords=data.coords,
        dims=data.dims,
        name=data.name,
    )

    # Preserve attributes if requested
    if keep_attrs:
        filled_data.attrs.update(data.attrs)

        # Add processing history without a leading separator when there is
        # no earlier history text.
        # NOTE(review): history is recorded only when attributes are kept —
        # confirm this matches the intended behaviour for keep_attrs=False.
        history_entry = f"Filled missing values using gridfill (eps={eps})"
        previous_history = filled_data.attrs.get("history", "")
        if previous_history:
            filled_data.attrs["history"] = f"{previous_history}; {history_entry}"
        else:
            filled_data.attrs["history"] = history_entry

    return filled_data
def fill_multiple(
    datasets: List[DataArray],
    eps: float,
    x_dim: Optional[str] = None,
    y_dim: Optional[str] = None,
    **kwargs,
) -> List[DataArray]:
    """
    Apply :func:`fill` to several DataArrays with one shared set of options.

    Each input is filled independently; the same tolerance, dimension names
    and extra keyword arguments are forwarded for every one of them.

    Parameters
    ----------
    datasets : list of xarray.DataArray
        DataArrays to process
    eps : float
        Convergence tolerance used for every dataset
    x_dim : str, optional
        X-coordinate dimension name (applied to all)
    y_dim : str, optional
        Y-coordinate dimension name (applied to all)
    **kwargs
        Additional parameters passed unchanged to fill()

    Returns
    -------
    list of xarray.DataArray
        Filled DataArrays, in the same order as ``datasets``

    Examples
    --------
    >>> from skyborn.gridfill.xarray import fill_multiple
    >>>
    >>> # Fill multiple related variables
    >>> temp_filled, humid_filled = fill_multiple(
    ...     [temperature_data, humidity_data],
    ...     eps=1e-4,
    ...     verbose=True
    ... )
    """
    filled_arrays = []
    for dataset in datasets:
        filled_arrays.append(
            fill(dataset, eps=eps, x_dim=x_dim, y_dim=y_dim, **kwargs)
        )
    return filled_arrays
def validate_grid_coverage(
    data: DataArray,
    x_dim: Optional[str] = None,
    y_dim: Optional[str] = None,
    min_coverage: float = 0.1,
) -> Dict[str, Any]:
    """
    Validate grid data coverage and suitability for gridfill.

    This function analyzes the input data to determine if it's suitable
    for gap filling and provides diagnostic information.

    Parameters
    ----------
    data : xarray.DataArray
        Input data to analyze
    x_dim : str, optional
        X-coordinate dimension name. Detected automatically when None.
    y_dim : str, optional
        Y-coordinate dimension name. Detected automatically when None.
    min_coverage : float, default 0.1
        Minimum fraction of valid data required (0.0 to 1.0)

    Returns
    -------
    dict
        Dictionary containing validation results:
        - 'valid': bool, False when the overall coverage is below
          ``min_coverage``
        - 'coverage': float, fraction of valid data points
        - 'total_points': int, total number of grid points
        - 'missing_points': int, number of missing points
        - 'messages': list, diagnostic messages

    Raises
    ------
    ValueError
        If spatial coordinates cannot be identified, a given dimension name
        is not a dimension of ``data``, or both names refer to the same
        dimension

    Notes
    -----
    For data with more than two dimensions, coverage is also computed for
    every individual (y, x) slice across all remaining dimensions. A slice
    below ``min_coverage`` adds a diagnostic message but does not by itself
    change ``'valid'``.

    Coordinate spacing is only checked for dimensions that have coordinate
    values attached.

    Examples
    --------
    >>> from skyborn.gridfill.xarray import validate_grid_coverage
    >>>
    >>> # Check data quality before filling
    >>> validation = validate_grid_coverage(data, min_coverage=0.2)
    >>> if validation['valid']:
    ...     filled = fill(data, eps=1e-4)
    ... else:
    ...     print("Data quality issues:", validation['messages'])
    """
    # Find spatial coordinates if not provided
    if x_dim is None or y_dim is None:
        detected_y, detected_x, _, _ = _find_spatial_coordinates(data)
        if x_dim is None:
            x_dim = detected_x
        if y_dim is None:
            y_dim = detected_y

    dims = tuple(data.dims)
    for label, dim_name in (("x_dim", x_dim), ("y_dim", y_dim)):
        if dim_name not in dims:
            raise ValueError(
                f"{label} '{dim_name}' not found in data dimensions: {list(dims)}"
            )
    if x_dim == y_dim:
        raise ValueError(
            f"x_dim and y_dim must be different dimensions, got '{x_dim}' for both"
        )
    x_axis = dims.index(x_dim)
    y_axis = dims.index(y_dim)

    messages = []

    # Calculate coverage statistics
    data_values = data.values
    if hasattr(data_values, "mask"):
        # getmaskarray always yields a full-shape boolean array
        missing_mask = np.ma.getmaskarray(data_values)
    else:
        missing_mask = np.isnan(data_values)

    # Plain Python numbers, as documented, rather than numpy scalars
    total_points = int(missing_mask.size)
    missing_points = int(np.sum(missing_mask))
    valid_points = total_points - missing_points
    coverage = valid_points / total_points if total_points > 0 else 0.0

    # Validate coverage
    valid = True
    if coverage < min_coverage:
        valid = False
        messages.append(
            f"Insufficient data coverage: {coverage:.1%} < {min_coverage:.1%}"
        )

    # Check every 2D (y, x) slice: averaging the mask over the two spatial
    # axes leaves one missing-fraction per combination of the other dims.
    if data.ndim > 2 and total_points > 0:
        slice_coverages = 1.0 - missing_mask.mean(axis=(y_axis, x_axis))
        min_slice_coverage = float(np.min(slice_coverages))
        if min_slice_coverage < min_coverage:
            messages.append(
                f"Some slices have insufficient coverage (min: {min_slice_coverage:.1%})"
            )

    # Check coordinate regularity
    for axis_label, dim_name in (("X", x_dim), ("Y", y_dim)):
        if dim_name not in data.coords:
            # Dimension has no coordinate values; nothing to check
            continue
        coord_values = np.asarray(data.coords[dim_name].values)
        if coord_values.size > 1:
            steps = np.diff(coord_values)
            if not np.allclose(steps, steps[0], rtol=1e-3):
                messages.append(f"{axis_label}-coordinate spacing is not regular")

    return {
        "valid": valid,
        "coverage": coverage,
        "total_points": total_points,
        "missing_points": missing_points,
        "messages": messages,
    }