from typing import Tuple, Union
import numpy as np
import xarray as xr
# Lazy imports to avoid loading heavy dependencies at startup
def _get_metpy_calc():
"""Lazy import of metpy.calc to avoid startup overhead"""
import metpy.calc as mpcalc
return mpcalc
def _get_metpy_units():
"""Lazy import of metpy.units to avoid startup overhead"""
from metpy.units import units
return units
def _get_f_regression():
"""Lazy import of sklearn.feature_selection.f_regression"""
from sklearn.feature_selection import f_regression
return f_regression
def _get_pearsonr():
"""Lazy import of scipy.stats.pearsonr for p-value calculation"""
from scipy.stats import pearsonr
return pearsonr
__all__ = [
"linear_regression",
"spatial_correlation",
"convert_longitude_range",
"pearson_correlation",
"spearman_correlation",
"kendall_correlation",
"calculate_potential_temperature",
]
[docs]
def linear_regression(
data: Union[np.ndarray, xr.DataArray], predictor: Union[np.ndarray, xr.DataArray]
) -> Tuple[np.ndarray, np.ndarray]:
"""
Perform linear regression between a 3D data array and a predictor sequence.
Handles both numpy arrays and xarray DataArrays with NaN handling.
Parameters
----------
data : np.ndarray or xr.DataArray
A 3D array of shape (n_samples, dim1, dim2) containing dependent variables.
Missing values should be represented as NaN.
predictor : np.ndarray or xr.DataArray
A 1D array of shape (n_samples,) containing the independent variable.
Missing values should be represented as NaN.
Returns
-------
Tuple[np.ndarray, np.ndarray]
A tuple containing:
- regression_coefficients: The slope of the regression line with shape (dim1, dim2)
- p_values: The p-values of the regression with shape (dim1, dim2)
Raises
------
ValueError
If the number of samples in data doesn't match the length of the predictor.
"""
# Extract numpy arrays regardless of input type
data = getattr(data, "values", data)
predictor = getattr(predictor, "values", predictor)
if len(data) != len(predictor):
raise ValueError(
f"Number of samples in data ({data.shape[0]}) must match "
f"length of predictor ({len(predictor)})"
)
# Check for NaN values in both data and predictor
has_nan = np.any(np.isnan(data)) or np.any(np.isnan(predictor))
# Calculate p-values using lazy import
f_regression = _get_f_regression()
if has_nan:
# Optimize np.nan access - 33% faster than repeated np.nan lookups
undef = np.nan
# Handle NaN case: record locations and replace with 0 in-place
nan_mask_data = np.isnan(data)
nan_mask_predictor = np.isnan(predictor)
data[nan_mask_data] = 0 # Replace NaN with 0 in original array
predictor_work = predictor.copy()
# Replace NaN with 0 in predictor
predictor_work[nan_mask_predictor] = 0
# Create design matrix with predictor and intercept
design_matrix = np.column_stack(
(predictor_work, np.ones(predictor_work.shape[0]))
)
# Get original dimensions and reshape for regression
n_samples, dim1, dim2 = data.shape
data_flat = data.reshape((n_samples, dim1 * dim2))
# Perform linear regression
regression_coef, intercept = np.linalg.lstsq(
design_matrix, data_flat, rcond=None
)[0]
regression_coef, intercept = regression_coef.reshape(
(dim1, dim2)
), intercept.reshape((dim1, dim2))
p_values = f_regression(data_flat, predictor_work)[1].reshape(dim1, dim2)
# Restore original NaN values in data array
data[nan_mask_data] = undef
# Set results back to NaN where original data had too many NaN
# Only consider data NaN, not predictor NaN (predictor NaN affects all gridpoints equally)
nan_mask_gridpoint = np.any(nan_mask_data, axis=0)
regression_coef = np.where(nan_mask_gridpoint, undef, regression_coef)
p_values = np.where(nan_mask_gridpoint, undef, p_values)
else:
# No NaN case: use original efficient algorithm
# Create design matrix with predictor and intercept
design_matrix = np.column_stack((predictor, np.ones(predictor.shape[0])))
# Get original dimensions and reshape for regression
n_samples, dim1, dim2 = data.shape
data_flat = data.reshape((n_samples, dim1 * dim2))
# Perform linear regression
regression_coef, _ = np.linalg.lstsq(design_matrix, data_flat, rcond=None)[0]
regression_coef = regression_coef.reshape((dim1, dim2))
p_values = f_regression(data_flat, predictor)[1].reshape(dim1, dim2)
return regression_coef, p_values
[docs]
def spatial_correlation(
data: Union[np.ndarray, xr.DataArray], predictor: Union[np.ndarray, xr.DataArray]
) -> Tuple[np.ndarray, np.ndarray]:
"""
Perform fast spatial correlation between a 3D data array and a predictor time series.
Optimized for vectorized operations to avoid slow loops over lat/lon grid points.
Parameters
----------
data : np.ndarray or xr.DataArray
A 3D array of shape (time, lat, lon) containing spatial data.
Missing values should be represented as NaN.
predictor : np.ndarray or xr.DataArray
A 1D array of shape (time,) containing the predictor time series.
Missing values should be represented as NaN.
Returns
-------
Tuple[np.ndarray, np.ndarray]
A tuple containing:
- correlation_coefficients: Pearson correlation coefficients with shape (lat, lon)
- p_values: The p-values of the correlations with shape (lat, lon)
Raises
------
ValueError
If the time dimension of data doesn't match the length of the predictor.
"""
# Extract numpy arrays regardless of input type
data = getattr(data, "values", data)
predictor = getattr(predictor, "values", predictor)
if len(data) != len(predictor):
raise ValueError(
f"Time dimension of data ({data.shape[0]}) must match "
f"length of predictor ({len(predictor)})"
)
# Check for NaN values
has_nan = np.any(np.isnan(data)) or np.any(np.isnan(predictor))
if has_nan:
# Optimize np.nan access - 33% faster than repeated np.nan lookups
undef = np.nan
# Handle NaN case: record locations and replace with 0 in-place
nan_mask_data = np.isnan(data)
nan_mask_predictor = np.isnan(predictor)
# Make copies to avoid modifying original data
data_work = data.copy()
predictor_work = predictor.copy()
# Replace NaN with 0 in working arrays
data_work[nan_mask_data] = 0
predictor_work[nan_mask_predictor] = 0
# Get dimensions
n_time, n_lat, n_lon = data_work.shape
# Reshape for vectorized operations
data_flat = data_work.reshape((n_time, n_lat * n_lon))
# Calculate means (only for valid data points)
# Create mask for valid data at each grid point
valid_data_mask = ~nan_mask_data
valid_predictor_mask = ~nan_mask_predictor
# Calculate valid counts for each grid point
valid_counts = np.sum(
valid_data_mask & valid_predictor_mask[:, np.newaxis, np.newaxis], axis=0
)
# Only calculate correlation where we have at least 3 valid pairs
sufficient_data = valid_counts >= 3
# For grid points with sufficient data, calculate means
pred_sum = np.sum(
predictor_work[:, np.newaxis, np.newaxis]
* (valid_data_mask & valid_predictor_mask[:, np.newaxis, np.newaxis]),
axis=0,
)
data_sum = np.sum(
data_work
* (valid_data_mask & valid_predictor_mask[:, np.newaxis, np.newaxis]),
axis=0,
)
# Avoid division by zero
pred_means = np.divide(
pred_sum, valid_counts, out=np.zeros_like(pred_sum), where=valid_counts > 0
)
data_means = np.divide(
data_sum, valid_counts, out=np.zeros_like(data_sum), where=valid_counts > 0
)
# Calculate correlation using vectorized operations
# Center the data (only for valid points)
valid_mask_3d = (
valid_data_mask & valid_predictor_mask[:, np.newaxis, np.newaxis]
)
pred_centered = (
predictor_work[:, np.newaxis, np.newaxis] - pred_means
) * valid_mask_3d
data_centered = (data_work - data_means) * valid_mask_3d
# Vectorized correlation calculation
numerator = np.sum(pred_centered * data_centered, axis=0)
pred_ss = np.sum(pred_centered**2, axis=0)
data_ss = np.sum(data_centered**2, axis=0)
# Calculate correlation coefficients
correlation_coef = np.divide(
numerator,
np.sqrt(pred_ss * data_ss),
out=np.full((n_lat, n_lon), undef),
where=(pred_ss > 0) & (data_ss > 0) & sufficient_data,
)
# Calculate p-values for valid correlations
p_values = np.full((n_lat, n_lon), undef)
valid_r_mask = (
sufficient_data
& (pred_ss > 0)
& (data_ss > 0)
& (np.abs(correlation_coef) < 1.0)
)
if np.any(valid_r_mask):
r_valid = correlation_coef[valid_r_mask]
n_valid = valid_counts[valid_r_mask]
t_stat = r_valid * np.sqrt((n_valid - 2) / (1 - r_valid**2))
from scipy.stats import t
p_valid = 2 * (1 - t.cdf(np.abs(t_stat), n_valid - 2))
p_values[valid_r_mask] = p_valid
# Set p-value to 0 for perfect correlations
perfect_r_mask = (
sufficient_data
& (pred_ss > 0)
& (data_ss > 0)
& (np.abs(correlation_coef) >= 1.0)
)
p_values[perfect_r_mask] = 0.0
# Set results back to NaN where we don't have sufficient valid data
insufficient_mask = ~sufficient_data
correlation_coef[insufficient_mask] = undef
p_values[insufficient_mask] = undef
else:
# No NaN case: use highly optimized vectorized algorithm
n_time, n_lat, n_lon = data.shape
# Reshape for vectorized operations
data_flat = data.reshape((n_time, n_lat * n_lon))
# Calculate means
pred_mean = np.mean(predictor)
data_means = np.mean(data_flat, axis=0)
# Center the data
pred_centered = predictor - pred_mean
data_centered = data_flat - data_means[np.newaxis, :]
# Vectorized correlation calculation across all grid points
numerator = np.sum(pred_centered[:, np.newaxis] * data_centered, axis=0)
pred_ss = np.sum(pred_centered**2)
data_ss = np.sum(data_centered**2, axis=0)
# Avoid division by zero
valid_variance = (pred_ss > 0) & (data_ss > 0)
correlation_coef = np.full(n_lat * n_lon, np.nan)
correlation_coef[valid_variance] = numerator[valid_variance] / np.sqrt(
pred_ss * data_ss[valid_variance]
)
# Calculate p-values vectorized
p_values = np.full(n_lat * n_lon, np.nan)
r_valid = correlation_coef[valid_variance]
# Only calculate p-values where correlation is valid and not perfect
calc_p = valid_variance & (np.abs(correlation_coef) < 1.0)
if np.any(calc_p):
r_calc = correlation_coef[calc_p]
t_stat = r_calc * np.sqrt((n_time - 2) / (1 - r_calc**2))
from scipy.stats import t
p_values[calc_p] = 2 * (1 - t.cdf(np.abs(t_stat), n_time - 2))
# Set p-value to 0 for perfect correlations
p_values[valid_variance & (np.abs(correlation_coef) >= 1.0)] = 0.0
# Reshape back to 2D
correlation_coef = correlation_coef.reshape((n_lat, n_lon))
p_values = p_values.reshape((n_lat, n_lon))
return correlation_coef, p_values
[docs]
def convert_longitude_range(
data: Union[xr.DataArray, xr.Dataset], lon: str = "lon", center_on_180: bool = True
) -> Union[xr.DataArray, xr.Dataset]:
"""
Wrap longitude coordinates of DataArray or Dataset to either -180..179 or 0..359.
Parameters
----------
data : xr.DataArray or xr.Dataset
An xarray DataArray or Dataset object containing longitude coordinates.
lon : str, optional
The name of the longitude coordinate, default is 'lon'.
center_on_180 : bool, optional
If True, wrap longitude from 0..359 to -180..179;
If False, wrap longitude from -180..179 to 0..359.
Returns
-------
xr.DataArray or xr.Dataset
The DataArray or Dataset with wrapped longitude coordinates.
"""
return data.assign_coords(
**{
lon: (
lambda x: (
((x[lon] + 180) % 360 - 180)
if not center_on_180
else (x[lon] % 360)
)
)
}
).sortby(lon, ascending=True)
[docs]
def pearson_correlation(
x: Union[np.ndarray, xr.DataArray], y: Union[np.ndarray, xr.DataArray]
) -> float:
"""
Calculate Pearson correlation coefficient.
Parameters
----------
x, y : array-like
Input data arrays.
Returns
-------
float
Pearson correlation coefficient.
"""
x = getattr(x, "values", x)
y = getattr(y, "values", y)
return np.corrcoef(x.flatten(), y.flatten())[0, 1]
[docs]
def spearman_correlation(
x: Union[np.ndarray, xr.DataArray], y: Union[np.ndarray, xr.DataArray]
) -> float:
"""
Calculate Spearman rank correlation coefficient.
Parameters
----------
x, y : array-like
Input data arrays.
Returns
-------
float
Spearman correlation coefficient.
"""
from scipy.stats import spearmanr
x = getattr(x, "values", x)
y = getattr(y, "values", y)
correlation, _ = spearmanr(x.flatten(), y.flatten())
return correlation
[docs]
def kendall_correlation(
x: Union[np.ndarray, xr.DataArray], y: Union[np.ndarray, xr.DataArray]
) -> float:
"""
Calculate Kendall's tau correlation coefficient.
Parameters
----------
x, y : array-like
Input data arrays.
Returns
-------
float
Kendall's tau correlation coefficient.
"""
from scipy.stats import kendalltau
x = getattr(x, "values", x)
y = getattr(y, "values", y)
correlation, _ = kendalltau(x.flatten(), y.flatten())
return correlation
[docs]
def calculate_potential_temperature(
temperature: Union[np.ndarray, xr.DataArray],
pressure: Union[np.ndarray, xr.DataArray],
reference_pressure: float = 1000.0,
) -> Union[np.ndarray, xr.DataArray]:
"""
Calculate potential temperature using fast numpy operations.
This implementation uses lazy imports and avoids heavy metpy dependencies
for simple potential temperature calculations.
Parameters
----------
temperature : array-like
Temperature values in Kelvin.
pressure : array-like
Pressure values in hPa.
reference_pressure : float, optional
Reference pressure in hPa, default is 1000.0.
Returns
-------
array-like
Potential temperature values in Kelvin.
Notes
-----
Uses the standard formula: theta = T * (P0/P)^(R/cp)
where R/cp = 0.286 for dry air
"""
R_over_cp = 0.286 # R/cp for dry air
potential_temp = temperature * (reference_pressure / pressure) ** R_over_cp
if hasattr(temperature, "attrs"):
if isinstance(potential_temp, np.ndarray):
return xr.DataArray(
potential_temp,
attrs={"units": "K", "long_name": "Potential Temperature"},
)
else:
potential_temp.attrs = {"units": "K", "long_name": "Potential Temperature"}
return potential_temp
def calculate_theta_se(
temperature: Union[np.ndarray, xr.DataArray],
pressure: Union[np.ndarray, xr.DataArray],
mixing_ratio: Union[np.ndarray, xr.DataArray],
dewpoint: Union[np.ndarray, xr.DataArray],
) -> Union[np.ndarray, xr.DataArray]:
"""
Calculate Pseudo-equivalent potential temperature (theta-se)
the fomula is θ_se = T * (1000 / (p - e)) ** (Ra / cpd) * exp(L * r / (cpd * Tc))
from http://stream1.cmatc.cn/cmatcvod/12/tqx/second_content.html
Parameters
----------
temperature : array-like
Temperature in Kelvin.
pressure : array-like
Pressure in hPa.
mixing_ratio : array-like
Water vapor mixing ratio in kg/kg.
dewpoint : array-like
Dewpoint temperature in Celsius.
Returns
-------
array-like
Pseudo-Equivalent potential temperature in Kelvin.
Notes
-----
This function uses MetPy's `vapor_pressure` and `lcl` functions internally.
"""
# Lazy imports
mpcalc = _get_metpy_calc()
units = _get_metpy_units()
# Extract values if xarray
is_xarray = hasattr(temperature, "attrs")
if is_xarray:
temp_values = temperature.values
pres_values = pressure.values
mixr_values = mixing_ratio.values
dewp_values = dewpoint.values
else:
temp_values = temperature
pres_values = pressure
mixr_values = mixing_ratio
dewp_values = dewpoint
# Convert units
p = pres_values * units.hPa
T = temp_values * units.kelvin
r = mixr_values * units("kg/kg")
td = dewp_values * units.degC
# Convert mixing ratio to g/kg for vapor_pressure
r_gkg = r.to("g/kg")
# Calculate vapor pressure
e = mpcalc.vapor_pressure(p, r_gkg)
# Calculate LCL temperature
_, T_lcl = mpcalc.lcl(p, T, td)
# Convert LCL temperature to Kelvin
T_lcl_K = T_lcl.to("kelvin").magnitude
# Constants
Rd = 287.0 # J/kg/K
cp_d = 1004.0 # J/kg/K
L = 2.5e6 # J/kg
# Dry air pressure
p_dry = p - e
# Theta_e calculation
theta_e_part = temp_values * (1000.0 / p_dry) ** (Rd / cp_d)
latent_part = np.exp((L * mixr_values) / (cp_d * T_lcl_K))
theta_se = theta_e_part * latent_part
# Extract magnitude if result is Quantity
if hasattr(theta_se, "magnitude"):
theta_se = theta_se.magnitude
# Preserve xarray structure if input is xarray
if is_xarray:
return xr.DataArray(
theta_se,
dims=temperature.dims,
coords=temperature.coords,
attrs={
"units": "K",
"long_name": "Pseudo-Equivalent Potential Temperature",
},
)
return theta_se