# Source code for skyborn.interp.regridding

"""
This script contains functions that perform nearest-neighbor, bilinear, and conservative
interpolation on xarray Datasets. It is adapted from the WeatherBench2 regridding module.

Qianye Su
suqianye2000@gmail.com

Reference
 - WeatherBench2 regridding:
     https://github.com/google-research/weatherbench2/blob/main/weatherbench2/regridding.py
"""

from __future__ import annotations

import dataclasses
from typing import Optional, Tuple, Union

import numpy as np
import xarray
from sklearn import neighbors

from . import regrid as _native_regrid

# Compiled helpers from the native extension. Each one is looked up with a
# ``None`` default so that a build lacking a particular helper falls back to
# the NumPy / scikit-learn implementation instead of failing at import time;
# the call sites in this module check ``is not None`` before using a helper.
_native_nearest_neighbor_indices = getattr(
    _native_regrid, "nearest_neighbor_indices", None
)
_native_nearest_regrid_apply = getattr(_native_regrid, "nearest_regrid_apply", None)
_native_bilinear_regrid = getattr(_native_regrid, "bilinear_regrid", None)
_native_bilinear_regrid_nd = getattr(_native_regrid, "bilinear_regrid_nd", None)
_native_conservative_latitude_weights = getattr(
    _native_regrid, "conservative_latitude_weights", None
)
_native_conservative_longitude_weights = getattr(
    _native_regrid, "conservative_longitude_weights", None
)
_native_conservative_regrid = getattr(_native_regrid, "conservative_regrid", None)


# Keep BallTree as the default nearest-neighbor backend until the compiled helper
# matches sklearn's haversine tie-breaking on midpoint-heavy regular grids.
_ENABLE_NATIVE_NEAREST = False

__all__ = [
    "Grid",
    "Regridder",
    "NearestRegridder",
    "BilinearRegridder",
    "ConservativeRegridder",
    "nearest_neighbor_indices",
    "regrid_dataset",
]

# Array type accepted by the regridding helpers (NumPy arrays only).
Array = Union[np.ndarray]


def _is_strictly_increasing_1d(values: np.ndarray) -> bool:
    """Return True when a 1D coordinate array is strictly increasing."""
    return bool(np.all(np.diff(values) > 0))


def _supports_native_regular_grid(source: Grid, target: Grid) -> bool:
    """Decide whether the compiled regular-grid helpers can handle this grid pair.

    The source grid needs at least two points on each axis, the target grid
    at least one, and all four coordinate arrays must be strictly increasing.
    """
    if len(source.lon) < 2 or len(source.lat) < 2:
        return False
    if len(target.lon) == 0 or len(target.lat) == 0:
        return False
    axes = (source.lon, source.lat, target.lon, target.lat)
    return all(_is_strictly_increasing_1d(axis) for axis in axes)


def _detect_coordinate_names(dataset: xarray.Dataset) -> Tuple[str, str]:
    """
    Detect latitude and longitude coordinate names in the dataset.

    Args:
        dataset: xarray Dataset

    Returns:
        Tuple of (longitude_name, latitude_name)

    Raises:
        ValueError: If coordinate names cannot be detected
    """
    # Common variations of coordinate names
    lon_names = ["longitude", "lon", "long", "x"]
    lat_names = ["latitude", "lat", "y"]

    # Find longitude coordinate
    lon_coord = None
    for name in lon_names:
        if name in dataset.dims:
            lon_coord = name
            break

    # Find latitude coordinate
    lat_coord = None
    for name in lat_names:
        if name in dataset.dims:
            lat_coord = name
            break

    if lon_coord is None or lat_coord is None:
        available_dims = list(dataset.sizes.keys())
        raise ValueError(
            f"Could not detect longitude/latitude coordinates. "
            f"Available dimensions: {available_dims}. "
            f"Expected one of: lon={lon_names}, lat={lat_names}"
        )

    return lon_coord, lat_coord


def _grid_allclose(left: Grid, right: Grid) -> bool:
    """Return True when two grids describe the same 1D lon/lat coordinates."""
    return bool(
        left.lon.shape == right.lon.shape
        and left.lat.shape == right.lat.shape
        and np.allclose(left.lon, right.lon)
        and np.allclose(left.lat, right.lat)
    )


def _has_partial_spatial_dims(
    variable: xarray.DataArray, lon_dim: str, lat_dim: str
) -> bool:
    """Return True when a variable depends on only one horizontal dimension."""
    has_lon = lon_dim in variable.dims
    has_lat = lat_dim in variable.dims
    return has_lon != has_lat


def _validate_no_partial_spatial_variables(
    dataset: xarray.Dataset, lon_dim: str, lat_dim: str
) -> None:
    """Raise if any data variable or coordinate spans only one horizontal dimension.

    Such entries could not be aligned with the output once the lon/lat axes
    are replaced by the target grid. The lon/lat coordinates themselves are
    exempt. Data variables are checked before coordinates.

    Raises:
        ValueError: naming the first offending variable or coordinate.
    """
    for var_name, data_var in dataset.data_vars.items():
        if not _has_partial_spatial_dims(data_var, lon_dim, lat_dim):
            continue
        raise ValueError(
            f"Variable {var_name!r} has only one horizontal dimension. "
            f"Variables must contain both {lon_dim!r} and {lat_dim!r}, or neither."
        )

    horizontal_names = {lon_dim, lat_dim}
    for coord_name, coord in dataset.coords.items():
        if coord_name in horizontal_names:
            continue
        if not _has_partial_spatial_dims(coord, lon_dim, lat_dim):
            continue
        raise ValueError(
            f"Coordinate {coord_name!r} has only one horizontal dimension. "
            f"Coordinates must contain both {lon_dim!r} and {lat_dim!r}, or neither."
        )


def _select_regridder(
    source_grid: Grid,
    target_grid: Grid,
    method: str,
) -> Regridder:
    """Instantiate the regridder class registered under ``method``.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    registry = (
        ("nearest", NearestRegridder),
        ("bilinear", BilinearRegridder),
        ("conservative", ConservativeRegridder),
    )
    for method_name, regridder_cls in registry:
        if method == method_name:
            return regridder_cls(source_grid, target_grid)
    raise ValueError(
        f"Unknown method: {method}. Choose from 'nearest', 'bilinear', 'conservative'"
    )


@dataclasses.dataclass(frozen=True)
class Grid:
    """Representation of a rectilinear grid given by 1D lon/lat arrays.

    Attributes:
        lon: Longitude coordinates; ``from_degrees`` stores them in radians.
        lat: Latitude coordinates; ``from_degrees`` stores them in radians.
    """

    lon: np.ndarray
    lat: np.ndarray

    @classmethod
    def from_degrees(cls, lon: np.ndarray, lat: np.ndarray) -> Grid:
        """Build a grid from coordinates in degrees (converted to radians)."""
        lon_rad = np.deg2rad(lon)
        lat_rad = np.deg2rad(lat)
        return cls(lon_rad, lat_rad)

    @classmethod
    def from_dataset(cls, dataset: xarray.Dataset) -> Grid:
        """Build a grid from a Dataset whose lon/lat dimension names are auto-detected."""
        lon_key, lat_key = _detect_coordinate_names(dataset)
        return cls.from_degrees(dataset[lon_key].values, dataset[lat_key].values)

    @property
    def shape(self) -> tuple[int, int]:
        """Grid size as (number of longitudes, number of latitudes)."""
        return len(self.lon), len(self.lat)

    def _to_tuple(self) -> tuple[tuple[float, ...], tuple[float, ...]]:
        # Plain-Python snapshot of the coordinates, used as equality/hash key
        # because ndarray fields are neither hashable nor scalar-comparable.
        lon_key = tuple(self.lon.tolist())
        lat_key = tuple(self.lat.tolist())
        return lon_key, lat_key

    def __eq__(self, other):
        if not isinstance(other, Grid):
            return False
        return self._to_tuple() == other._to_tuple()

    def __hash__(self):
        return hash(self._to_tuple())


@dataclasses.dataclass(frozen=True)
class Regridder:
    """Base class for regridding from ``source`` to ``target``.

    Attributes:
        source: Grid the input data lives on.
        target: Grid the output is produced on.

    Subclasses implement :meth:`regrid_array`; :meth:`regrid_dataset` wraps it
    with the xarray bookkeeping (name detection, latitude ordering,
    coordinate handling, dimension order).
    """

    source: Grid
    target: Grid

    def regrid_array(self, field: Array) -> np.ndarray:
        """Regrid an array with dimensions (..., lon, lat) from source to target."""
        raise NotImplementedError

    def regrid_dataset(
        self,
        dataset: xarray.Dataset,
        lon_dim: Optional[str] = None,
        lat_dim: Optional[str] = None,
    ) -> xarray.Dataset:
        """
        Regrid an xarray.Dataset from source to target.

        Data variables and non-index coordinates that span both horizontal
        dimensions are regridded; those spanning neither are copied through
        unchanged. A strictly decreasing latitude axis is reversed first, and
        if ``self.source`` matches only with reversed latitude, a regridder
        of the same class is built for the reversed layout.

        Args:
            dataset: Input xarray Dataset
            lon_dim: Name of longitude dimension (auto-detected if None)
            lat_dim: Name of latitude dimension (auto-detected if None)

        Returns:
            Regridded xarray Dataset. Each regridded variable keeps its
            original dimension order; the lon/lat coordinates are the target
            grid converted to degrees.

        Raises:
            ValueError: If coordinate names cannot be detected, latitude is
                not strictly monotonic, a variable or coordinate spans only
                one horizontal dimension, or the dataset coordinates do not
                match ``self.source`` in either latitude order.
        """
        # Auto-detect coordinate names if not provided
        if lon_dim is None or lat_dim is None:
            detected_lon, detected_lat = _detect_coordinate_names(dataset)
            lon_dim = lon_dim or detected_lon
            lat_dim = lat_dim or detected_lat
        horizontal_dims = {lon_dim, lat_dim}

        # Ensure latitude is in ascending order
        lat_diff = dataset[lat_dim].diff(lat_dim)
        if not (lat_diff > 0).all():
            if not (lat_diff < 0).all():
                raise ValueError(
                    f"Latitude coordinate {lat_dim!r} must be strictly monotonic"
                )
            dataset = dataset.isel({lat_dim: slice(None, None, -1)})  # Reverse
        assert (dataset[lat_dim].diff(lat_dim) > 0).all()

        _validate_no_partial_spatial_variables(dataset, lon_dim, lat_dim)

        dataset_source_grid = Grid.from_degrees(
            dataset[lon_dim].values,
            dataset[lat_dim].values,
        )
        active_regridder = self
        if not _grid_allclose(self.source, dataset_source_grid):
            # self.source may describe the same grid with descending latitude
            # (e.g. built before the reversal above); accept that case and use
            # a regridder for the ascending layout.
            reversed_source_grid = Grid(self.source.lon, self.source.lat[::-1])
            if not _grid_allclose(reversed_source_grid, dataset_source_grid):
                raise ValueError(
                    "Regridder source grid does not match dataset coordinates"
                )
            active_regridder = self.__class__(dataset_source_grid, self.target)

        # Create target grid coordinates
        target_lon_deg = np.rad2deg(self.target.lon)
        target_lat_deg = np.rad2deg(self.target.lat)

        def regrid_field(field: xarray.DataArray) -> xarray.DataArray:
            # Regrid one DataArray spanning both horizontal dims and restore
            # its original dimension order (apply_ufunc moves core dims last).
            original_order = list(field.dims)
            regridded = xarray.apply_ufunc(
                active_regridder.regrid_array,
                field,
                input_core_dims=[[lon_dim, lat_dim]],
                output_core_dims=[[lon_dim, lat_dim]],
                exclude_dims={lon_dim, lat_dim},
                vectorize=True,
                dask="allowed",
                output_dtypes=[field.dtype],
                keep_attrs=True,
            )
            regridded = regridded.assign_coords(
                {lon_dim: target_lon_deg, lat_dim: target_lat_deg}
            )
            if list(regridded.dims) != original_order:
                regridded = regridded.transpose(*original_order)
            return regridded

        regridded_vars = {}
        for var_name, var in dataset.data_vars.items():
            if horizontal_dims.issubset(var.dims):
                regridded_vars[var_name] = regrid_field(var)
            else:
                # Variables without spatial dimensions remain unchanged
                regridded_vars[var_name] = var

        # Auxiliary coordinates on the horizontal grid must be regridded too:
        # passing their source-grid values through would conflict with the
        # new lon/lat sizes when the output Dataset is assembled.
        carried_coords = {}
        for coord_name, coord in dataset.coords.items():
            if coord_name in horizontal_dims:
                continue
            if horizontal_dims.issubset(coord.dims):
                carried_coords[coord_name] = regrid_field(coord)
            else:
                carried_coords[coord_name] = coord

        return xarray.Dataset(
            regridded_vars,
            coords={
                **carried_coords,
                lon_dim: target_lon_deg,
                lat_dim: target_lat_deg,
            },
            attrs=dataset.attrs,
        )


def nearest_neighbor_indices(source_grid: Grid, target_grid: Grid) -> np.ndarray:
    """Returns Haversine nearest neighbor indices from source_grid to target_grid.

    The result holds one entry per target cell, in the row-major order of the
    target's (lon, lat) mesh. Each entry is a flat index into the raveled
    (lon, lat) source mesh.
    """
    if _ENABLE_NATIVE_NEAREST and _native_nearest_neighbor_indices is not None:
        return _native_nearest_neighbor_indices(
            source_grid.lon,
            source_grid.lat,
            target_grid.lon,
            target_grid.lat,
        )

    def as_lat_lon_points(grid):
        # Flatten the (lon, lat) mesh to rows of (lat, lon), the column order
        # used for the haversine metric.
        lon_mesh, lat_mesh = np.meshgrid(grid.lon, grid.lat, indexing="ij")
        return np.column_stack([lat_mesh.ravel(), lon_mesh.ravel()])

    ball_tree = neighbors.BallTree(as_lat_lon_points(source_grid), metric="haversine")
    nearest = ball_tree.query(as_lat_lon_points(target_grid), return_distance=False)
    return nearest.squeeze(axis=-1)


def _gather_flat_spatial( field_arr: np.ndarray, indices: np.ndarray, source_shape: tuple[int, int], target_shape: tuple[int, int], ) -> np.ndarray: """Apply flat spatial indices across arbitrary leading dimensions.""" src_size = source_shape[0] * source_shape[1] flat = field_arr.reshape(-1, src_size) gathered = np.take(flat, indices, axis=1) return gathered.reshape(field_arr.shape[:-2] + target_shape)
class NearestRegridder(Regridder):
    """Regrid by copying the value of the nearest source cell (haversine distance)."""

    def __init__(self, source: Grid, target: Grid):
        super().__init__(source, target)
        # Lazily-filled cache for ``indices``.
        # NOTE(review): the base class is a frozen dataclass; this assignment
        # relies on dataclasses blocking only field names on subclasses.
        self._indices = None

    @property
    def indices(self):
        """Flat source index of each target cell's nearest neighbour (cached)."""
        cached = self._indices
        if cached is None:
            cached = nearest_neighbor_indices(self.source, self.target)
            self._indices = cached
        return cached

    def _nearest_neighbor_2d(self, array: Array) -> np.ndarray:
        """Regrid a single (lon, lat) array by flat-index lookup."""
        if array.shape != self.source.shape:
            raise ValueError(
                f"Expected array.shape={array.shape} to match source.shape={self.source.shape}"
            )
        return array.ravel()[self.indices].reshape(self.target.shape)

    def _nearest_neighbor_nd(self, field: Array) -> np.ndarray:
        """Regrid a (..., lon, lat) array, reusing the cached indices for every slice."""
        values = np.asarray(field)
        spatial_shape = values.shape[-2:]
        if spatial_shape != self.source.shape:
            raise ValueError(
                f"Expected field shape {self.source.shape}, got {spatial_shape}"
            )
        if _native_nearest_regrid_apply is None:
            return _gather_flat_spatial(
                values,
                self.indices,
                self.source.shape,
                self.target.shape,
            )
        out_lon, out_lat = self.target.shape
        return _native_nearest_regrid_apply(values, self.indices, out_lon, out_lat)

    def regrid_array(self, field: Array) -> np.ndarray:
        """Regrid an array with dimensions (..., lon, lat) from source to target."""
        return self._nearest_neighbor_nd(field)


class BilinearRegridder(Regridder):
    """Regrid with bilinear interpolation.

    The interpolation is separable: first along latitude for every source
    longitude, then along longitude for every target latitude.
    """

    def _bilinear_2d(self, field: Array) -> np.ndarray:
        """Interpolate one (lon, lat) field onto the target grid."""
        src = self.source
        dst = self.target
        expected_shape = (len(src.lon), len(src.lat))
        if field.shape != expected_shape:
            raise ValueError(
                f"Expected field shape {expected_shape}, "
                f"got {field.shape}"
            )

        native_ok = _native_bilinear_regrid is not None and _supports_native_regular_grid(
            src, dst
        )
        if native_ok:
            return _native_bilinear_regrid(
                np.asarray(field, dtype=np.float64),
                src.lon,
                src.lat,
                dst.lon,
                dst.lat,
            )

        # Python fallback built on np.interp, which expects increasing sample
        # points and holds the edge value outside their range.
        along_lat = np.zeros((len(src.lon), len(dst.lat)))
        for row, lon_slice in enumerate(field):
            along_lat[row] = np.interp(dst.lat, src.lat, lon_slice)

        out = np.zeros((len(dst.lon), len(dst.lat)))
        for col in range(len(dst.lat)):
            out[:, col] = np.interp(dst.lon, src.lon, along_lat[:, col])
        return out

    def _bilinear_nd(self, field: Array) -> np.ndarray:
        """Interpolate a (..., lon, lat) array, slice by slice when no native batch helper applies."""
        data = np.asarray(field)
        if data.shape[-2:] != self.source.shape:
            raise ValueError(
                f"Expected field shape {self.source.shape}, got {data.shape[-2:]}"
            )
        if data.ndim == 2:
            return self._bilinear_2d(data)

        native_ok = _native_bilinear_regrid_nd is not None and _supports_native_regular_grid(
            self.source, self.target
        )
        if native_ok:
            return _native_bilinear_regrid_nd(
                np.asarray(data, dtype=np.float64),
                self.source.lon,
                self.source.lat,
                self.target.lon,
                self.target.lat,
            )

        batch_shape = data.shape[:-2]
        out = np.empty(batch_shape + self.target.shape, dtype=np.float64)
        for idx in np.ndindex(batch_shape):
            out[idx] = self._bilinear_2d(data[idx])
        return out

    def regrid_array(self, field: Array) -> np.ndarray:
        """Regrid an array with dimensions (..., lon, lat) from source to target."""
        return self._bilinear_nd(field)


def _assert_increasing(x: np.ndarray) -> None: if not (np.diff(x) > 0).all(): raise ValueError(f"Array is not increasing: {x}") def _latitude_cell_bounds(x: Array) -> np.ndarray: pi_over_2 = np.array([np.pi / 2], dtype=x.dtype) return np.concatenate((-pi_over_2, (x[:-1] + x[1:]) / 2, pi_over_2)) def _latitude_overlap( source_points: Array, target_points: Array, ) -> np.ndarray: """Calculate the area overlap as a function of latitude.""" source_bounds = _latitude_cell_bounds(source_points) target_bounds = _latitude_cell_bounds(target_points) upper = np.minimum(target_bounds[1:, np.newaxis], source_bounds[np.newaxis, 1:]) lower = np.maximum(target_bounds[:-1, np.newaxis], source_bounds[np.newaxis, :-1]) # Normalized cell area: integral from lower to upper of cos(latitude) overlap = (upper > lower) * (np.sin(upper) - np.sin(lower)) return overlap def _conservative_latitude_weights( source_points: Array, target_points: Array ) -> np.ndarray: """Create a weight matrix for conservative regridding along latitude. Args: source_points: 1D latitude coordinates in radians for centers of source cells. target_points: 1D latitude coordinates in radians for centers of target cells. Returns: NumPy array with shape (target_size, source_size). Rows sum to 1. 
""" _assert_increasing(source_points) _assert_increasing(target_points) if _native_conservative_latitude_weights is not None: return _native_conservative_latitude_weights(source_points, target_points) weights = _latitude_overlap(source_points, target_points) # Handle zero-sum rows to avoid division by zero row_sums = np.sum(weights, axis=1, keepdims=True) # Avoid in-place division which causes broadcasting issues result = np.copy(weights) for i in range(result.shape[0]): if row_sums[i, 0] > 1e-15: result[i, :] /= row_sums[i, 0] else: # For zero-sum rows, distribute weight equally result[i, :] = 1.0 / result.shape[1] return result def _align_phase_with(x, target, period): """Align the phase of a periodic number to match another.""" shift_down = x > target + period / 2 shift_up = x < target - period / 2 return x + period * shift_up - period * shift_down def _periodic_upper_bounds(x, period): x_plus = _align_phase_with(np.roll(x, -1), x, period) return (x + x_plus) / 2 def _periodic_lower_bounds(x, period): x_minus = _align_phase_with(np.roll(x, +1), x, period) return (x_minus + x) / 2 def _periodic_overlap(x0, x1, y0, y1, period): """Calculate the overlap between two intervals considering periodicity.""" y0 = _align_phase_with(y0, x0, period) y1 = _align_phase_with(y1, x0, period) upper = np.minimum(x1, y1) lower = np.maximum(x0, y0) return np.maximum(upper - lower, 0) def _longitude_overlap( first_points: Array, second_points: Array, period: float = 2 * np.pi, ) -> np.ndarray: """Calculate the area overlap as a function of longitude.""" first_points = first_points % period first_upper = _periodic_upper_bounds(first_points, period) first_lower = _periodic_lower_bounds(first_points, period) second_points = second_points % period second_upper = _periodic_upper_bounds(second_points, period) second_lower = _periodic_lower_bounds(second_points, period) x0 = first_lower[:, np.newaxis] x1 = first_upper[:, np.newaxis] y0 = second_lower[np.newaxis, :] y1 = 
second_upper[np.newaxis, :] overlap_func = np.vectorize(_periodic_overlap, excluded=["period"]) overlap = overlap_func(x0, x1, y0, y1, period=period) return overlap def _conservative_longitude_weights( source_points: np.ndarray, target_points: np.ndarray ) -> np.ndarray: """Create a weight matrix for conservative regridding along longitude. Args: source_points: 1D longitude coordinates in radians for centers of source cells. target_points: 1D longitude coordinates in radians for centers of target cells. Returns: NumPy array with shape (target_size, source_size). Rows sum to 1. """ _assert_increasing(source_points) _assert_increasing(target_points) if _native_conservative_longitude_weights is not None: return _native_conservative_longitude_weights(source_points, target_points) weights = _longitude_overlap(target_points, source_points) # Handle zero-sum rows to avoid division by zero row_sums = np.sum(weights, axis=1, keepdims=True) nonzero_mask = row_sums > 1e-15 # Avoid in-place division which causes broadcasting issues result = np.copy(weights) for i in range(result.shape[0]): if nonzero_mask[i, 0]: result[i, :] /= row_sums[i, 0] else: # For zero-sum rows, distribute weight equally result[i, :] = 1.0 / result.shape[1] return result
class ConservativeRegridder(Regridder):
    """Regrid with linear conservative regridding.

    Each target cell is the overlap-weighted mean of the source cells it
    intersects. The separable longitude and latitude weight matrices are
    built on first use and cached on the instance.
    """

    def __init__(self, source: Grid, target: Grid):
        super().__init__(source, target)
        # Filled lazily by the ``lon_weights`` / ``lat_weights`` properties.
        self._lon_weights = None
        self._lat_weights = None

    @property
    def lon_weights(self):
        """Longitude weight matrix (target lon x source lon), computed once."""
        weights = self._lon_weights
        if weights is None:
            weights = _conservative_longitude_weights(self.source.lon, self.target.lon)
            self._lon_weights = weights
        return weights

    @property
    def lat_weights(self):
        """Latitude weight matrix (target lat x source lat), computed once."""
        weights = self._lat_weights
        if weights is None:
            weights = _conservative_latitude_weights(self.source.lat, self.target.lat)
            self._lat_weights = weights
        return weights

    def _mean(self, field: Array) -> np.ndarray:
        """Contract a (..., src lon, src lat) field with both weight matrices."""
        return np.einsum(
            "ac,bd,...cd->...ab",
            self.lon_weights,
            self.lat_weights,
            field,
            optimize=True,
        )

    def _python_nanmean(self, field: Array) -> np.ndarray:
        """NaN-skipping cell averages: weighted sum of valid values over valid weight."""
        missing = np.isnan(field)
        weighted_sum = self._mean(np.where(missing, 0, field))
        valid_weight = self._mean(~missing)
        with np.errstate(divide="ignore", invalid="ignore"):
            averaged = np.true_divide(weighted_sum, valid_weight)
        # Target cells that saw no valid source data are undefined.
        averaged[valid_weight == 0] = np.nan
        return averaged

    def _conservative_2d(self, field: Array) -> np.ndarray:
        """Regrid a single (lon, lat) field."""
        values = np.asarray(field)
        if values.shape != self.source.shape:
            raise ValueError(
                f"Expected field shape {self.source.shape}, got {values.shape}"
            )
        if _native_conservative_regrid is None:
            return self._python_nanmean(values)
        return _native_conservative_regrid(
            np.asarray(values, dtype=np.float64),
            self.lon_weights,
            self.lat_weights,
        )

    def _conservative_nd(self, field: Array) -> np.ndarray:
        """Regrid a (..., lon, lat) array."""
        values = np.asarray(field)
        if values.shape[-2:] != self.source.shape:
            raise ValueError(
                f"Expected field shape {self.source.shape}, got {values.shape[-2:]}"
            )
        if values.ndim == 2:
            return self._conservative_2d(values)
        if _native_conservative_regrid is None:
            return self._python_nanmean(values)
        return _native_conservative_regrid(
            np.asarray(values, dtype=np.float64),
            self.lon_weights,
            self.lat_weights,
        )

    def regrid_array(self, field: Array) -> np.ndarray:
        """Regrid an array with dimensions (..., lon, lat) from source to target."""
        return self._conservative_nd(field)


# Convenience function for easy regridding
def regrid_dataset(
    dataset: xarray.Dataset,
    target_grid: Grid,
    method: str = "bilinear",
    lon_dim: Optional[str] = None,
    lat_dim: Optional[str] = None,
) -> xarray.Dataset:
    """
    Regrid a dataset onto ``target_grid`` with the chosen method.

    Args:
        dataset: Input xarray Dataset
        target_grid: Target grid for regridding
        method: Interpolation method ('nearest', 'bilinear', 'conservative')
        lon_dim: Name of longitude dimension (auto-detected if None)
        lat_dim: Name of latitude dimension (auto-detected if None)

    Returns:
        Regridded xarray Dataset with preserved dimension order

    Raises:
        ValueError: If coordinate names cannot be detected or ``method`` is
            not supported (plus any error raised while regridding).
    """
    # Resolve the horizontal dimension names before building the source grid,
    # so explicitly passed names are honoured. Auto-detection alone would
    # reject non-standard names, or pick a different coordinate than the one
    # requested.
    if lon_dim is None or lat_dim is None:
        detected_lon, detected_lat = _detect_coordinate_names(dataset)
        lon_dim = lon_dim or detected_lon
        lat_dim = lat_dim or detected_lat

    source_grid = Grid.from_degrees(dataset[lon_dim].values, dataset[lat_dim].values)
    return _select_regridder(source_grid, target_grid, method).regrid_dataset(
        dataset,
        lon_dim=lon_dim,
        lat_dim=lat_dim,
    )