# Source code for skyborn.plot.vector

"""Unified curly-vector API and engine surface for Skyborn plots.

Author: Qianye Su <suqianye2000@gmail.com>
Copyright (c) 2025-2026 Qianye Su
Created: 2026-03-01 14:58:56
"""

from __future__ import annotations

from collections.abc import Hashable
from functools import partial
from typing import TYPE_CHECKING, Any

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from ._adapters import cartopy_vector as _cartopy_adapter
from ._adapters import curly_vector_entry as _entry_adapter
from ._adapters import dataset_vector as _dataset_adapter
from ._adapters import grid_prepare as _grid_prepare_adapter
from ._artists import vector_artists as _artist_helpers
from ._artists.vector_key_artist import CurlyVectorKey
from ._core import geometry as _geometry
from ._core import native as _native_helpers
from ._core import thinning as _thinning
from ._core import vector_engine as _vector_engine
from ._core.result import CurlyVectorPlotSet
from ._shared.axes import _is_cartopy_crs_like, _looks_like_axes
from ._shared.coords import (
    _axis_coordinate_1d,
    _axis_is_uniform,
    _coerce_matching_plot_field,
    _extract_meshgrid_axes,
    _filled_float_array,
    _normalize_regular_grid_orientation,
)
from ._shared.style import (
    _collect_named_kwargs,
    _normalize_artist_alpha,
    _normalize_supported_arrowstyle,
    _resolve_curly_anchor_alias,
    _resolve_curly_style_aliases,
)

if TYPE_CHECKING:  # pragma: no cover
    from matplotlib.axes import Axes  # pragma: no cover

__all__ = [
    "CurlyVectorKey",
    "CurlyVectorPlotSet",
    "curly_vector",
    "curly_vector_key",
]

# Short module-level names for helpers that live in the private ``_core``
# submodules. ``_display_points_to_data``, ``_NCLNativeTraceContext``,
# ``_prepare_ncl_display_sampler`` and ``Grid`` are handed to the vector engine
# by ``_curly_vector_ncl`` / ``_prepare_ncl_native_trace_context`` below.
# NOTE(review): ``_local_display_jacobian`` is not referenced in the visible
# part of this module; presumably it is re-exported for other importers --
# confirm before removing.
_display_points_to_data = _geometry._display_points_to_data
_local_display_jacobian = _geometry._local_display_jacobian
_NCLNativeTraceContext = _thinning._NCLNativeTraceContext
_prepare_ncl_display_sampler = _thinning._prepare_ncl_display_sampler
Grid = _vector_engine.Grid

# Keyword names handed to the dataset entry adapter as
# ``array_curly_vector_kwarg_names`` (see ``_curly_vector_from_dataset``),
# alongside ``_array_curly_vector`` itself. Every name is a keyword parameter
# of ``_array_curly_vector``. The list keeps the quiver-style ``pivot`` alias
# and does not contain ``allow_non_uniform_grid``.
_ARRAY_CURLY_VECTOR_KWARG_NAMES = (
    "density",
    "linewidth",
    "color",
    "vmin",
    "vmax",
    "cmap",
    "norm",
    "alpha",
    "facecolor",
    "edgecolor",
    "rasterized",
    "arrowsize",
    "arrowstyle",
    "transform",
    "zorder",
    "start_points",
    "integration_direction",
    "grains",
    "broken_streamlines",
    "anchor",
    "pivot",
    "ref_magnitude",
    "ref_length",
    "min_frac_length",
    "min_distance",
    "ncl_preset",
)
# Names of the ``_array_curly_vector`` locals that are forwarded as keyword
# arguments to ``_curly_vector_ncl``. Each entry matches a keyword parameter of
# ``_curly_vector_ncl``. ``pivot`` is absent because it has already been folded
# into ``anchor`` by the time the names are collected, and
# ``allow_non_uniform_grid`` is included so the resolved flag is passed on.
_CURLY_VECTOR_NCL_KWARG_NAMES = (
    "density",
    "linewidth",
    "color",
    "vmin",
    "vmax",
    "cmap",
    "norm",
    "alpha",
    "facecolor",
    "edgecolor",
    "rasterized",
    "arrowsize",
    "arrowstyle",
    "transform",
    "zorder",
    "start_points",
    "integration_direction",
    "grains",
    "broken_streamlines",
    "anchor",
    "ref_magnitude",
    "ref_length",
    "min_frac_length",
    "min_distance",
    "allow_non_uniform_grid",
    "ncl_preset",
)

# Dataset-adapter helpers; both are injected into
# ``_entry_adapter._curly_vector_from_dataset_impl`` by
# ``_curly_vector_from_dataset``.
_extract_curly_vector_dataset_source = (
    _dataset_adapter._extract_curly_vector_dataset_source
)
_prepare_dataset_style_field = _dataset_adapter._prepare_dataset_style_field

# Cartopy and grid-preparation helpers. Except for ``_normalize_regrid_shape``'s
# siblings being grouped here for convenience, these are the callables that
# ``_prepare_curly_vector_dataset_inputs`` passes to the entry adapter as
# ``*_fn`` keyword arguments.
_default_cartopy_target_extent = _cartopy_adapter._default_cartopy_target_extent
_normalize_regrid_shape = _grid_prepare_adapter._normalize_regrid_shape
_is_curvilinear_grid = _grid_prepare_adapter._is_curvilinear_grid
_default_curvilinear_regrid_shape = (
    _grid_prepare_adapter._default_curvilinear_regrid_shape
)
_maybe_as_scalar_field = _grid_prepare_adapter._maybe_as_scalar_field
_regrid_curvilinear_vectors = _grid_prepare_adapter._regrid_curvilinear_vectors
_extract_regular_grid_from_regridded_vectors = (
    _cartopy_adapter._extract_regular_grid_from_regridded_vectors
)
_regrid_cartopy_vectors = _cartopy_adapter._regrid_cartopy_vectors

from .nclcurly_native import (
    build_filled_arrow_polygons as _build_filled_arrow_polygons_native,
)
from .nclcurly_native import (
    build_open_arrow_segments as _build_open_arrow_segments_native,
)
from .nclcurly_native import sample_grid_field as _sample_grid_field_native
from .nclcurly_native import sample_grid_field_array as _sample_grid_field_array_native
from .nclcurly_native import (
    thin_display_candidates as _thin_ncl_display_candidates_native,
)
from .nclcurly_native import (
    thin_mapped_candidates as _thin_ncl_mapped_candidates_native,
)
from .nclcurly_native import trace_ncl_direction as _trace_ncl_direction_native
from .nclcurly_native import validate_display_curve as _validate_display_curve_native


def _normalize_ncl_preset(ncl_preset):
    """Normalize supported NCL-like preset names."""
    if ncl_preset is None:
        return None

    preset = str(ncl_preset).strip().lower()
    if preset in {"profile", "vertical_profile", "vertical-profile", "lat_pressure"}:
        return "profile"

    raise ValueError(f"Unsupported ncl_preset {ncl_preset!r}")


def _resolve_default_ncl_preset(x, y, allow_non_uniform_grid, ncl_preset):
    """Resolve the effective ``(allow_non_uniform_grid, preset)`` pair.

    An explicit preset always wins; ``"profile"`` additionally forces the
    non-uniform flag on. Without an explicit preset the 1D ``x``/``y`` axes are
    inspected: if either one is not uniformly spaced, the ``"profile"`` preset
    is selected automatically. When the axes cannot be reduced to 1D
    coordinates, the caller's flag is returned with no preset.
    """
    preset = _normalize_ncl_preset(ncl_preset)
    if preset is not None:
        if preset == "profile":
            return True, preset
        return allow_non_uniform_grid, preset

    x_axis = _axis_coordinate_1d(x, "x")
    y_axis = _axis_coordinate_1d(y, "y")
    if x_axis is None or y_axis is None:
        return allow_non_uniform_grid, None

    # The y axis is only examined when the x axis turned out to be uniform.
    if _axis_is_uniform(x_axis) and _axis_is_uniform(y_axis):
        return allow_non_uniform_grid, None
    return True, "profile"


def _infer_profile_ncl_ref_magnitude(u, v, percentile=97.0):
    """Estimate a conservative reference magnitude for lat-pressure sections."""
    magnitude = np.hypot(np.asarray(u, dtype=float), np.asarray(v, dtype=float))
    valid = magnitude[np.isfinite(magnitude)]
    valid = valid[valid > 0.0]

    if valid.size == 0:
        return None

    ref_magnitude = float(np.nanpercentile(valid, float(percentile)))
    if not np.isfinite(ref_magnitude) or ref_magnitude <= 0.0:
        ref_magnitude = float(np.nanmax(valid))

    return ref_magnitude if np.isfinite(ref_magnitude) and ref_magnitude > 0.0 else None


def _apply_ncl_preset_defaults(
    ncl_preset,
    allow_non_uniform_grid,
    ref_magnitude,
    ref_length,
    min_distance,
    u,
    v,
):
    """Fill in preset-specific defaults, leaving explicit values untouched.

    Only the ``"profile"`` preset changes anything: it switches the
    non-uniform-grid flag on, estimates ``ref_magnitude`` from ``u``/``v`` when
    none was given, and defaults ``ref_length`` to ``0.06``. ``min_distance``
    is always returned as received.

    Returns
    -------
    tuple
        ``(allow_non_uniform_grid, ref_magnitude, ref_length, min_distance,
        preset)`` where ``preset`` is the normalized preset name.
    """
    preset = _normalize_ncl_preset(ncl_preset)
    if preset == "profile":
        allow_non_uniform_grid = True
        ref_magnitude = (
            _infer_profile_ncl_ref_magnitude(u, v)
            if ref_magnitude is None
            else ref_magnitude
        )
        ref_length = 0.06 if ref_length is None else ref_length

    return allow_non_uniform_grid, ref_magnitude, ref_length, min_distance, preset


def _regrid_non_uniform_vectors_to_uniform(
    x: Any, y: Any, u: Any, v: Any, *scalars: Any
) -> tuple[np.ndarray, ...]:
    """Resample vector (and optional scalar) fields onto evenly spaced axes.

    The source axes are extracted from ``x``/``y``, reordered into increasing
    order where necessary, and every field is linearly interpolated onto
    ``linspace`` axes that span the same range with the same number of points.
    Targets that fall outside the source grid are filled with NaN.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x_uniform, y_uniform, u_uniform, v_uniform, *scalars_uniform)`` with
        the scalar fields in the order they were passed.

    Raises
    ------
    ImportError
        If scipy is not installed.
    ValueError
        If a field does not have shape ``(len(y), len(x))`` or an axis is not
        strictly monotonic even after sorting (for example, repeated values).
    """
    try:
        from scipy.interpolate import RegularGridInterpolator
    except ImportError as err:
        raise ImportError(
            "scipy is required for non-uniform grid support. Please install scipy."
        ) from err

    def _not_strictly_increasing(axis):
        return axis.size > 1 and np.any(np.diff(axis) <= 0)

    x_axis, y_axis = _extract_meshgrid_axes(x, y)
    vector_fields = [_filled_float_array(u), _filled_float_array(v)]
    style_fields = [_filled_float_array(field) for field in scalars]

    expected_shape = (y_axis.size, x_axis.size)
    if any(field.shape != expected_shape for field in vector_fields):
        raise ValueError(
            f"u and v must match the non-uniform grid shape {expected_shape}"
        )
    if any(field.shape != expected_shape for field in style_fields):
        raise ValueError(
            "Non-uniform scalar style fields must match the source vector-grid "
            f"shape {expected_shape}"
        )

    # Work on copies; u and v lead the list, scalar fields follow in call order.
    x_coords = x_axis.copy()
    y_coords = y_axis.copy()
    fields = [field.copy() for field in (*vector_fields, *style_fields)]

    # Reorder columns / rows so that each axis is increasing.
    if _not_strictly_increasing(x_coords):
        column_order = np.argsort(x_coords)
        x_coords = x_coords[column_order]
        fields = [field[:, column_order] for field in fields]
    if _not_strictly_increasing(y_coords):
        row_order = np.argsort(y_coords)
        y_coords = y_coords[row_order]
        fields = [field[row_order, :] for field in fields]

    # Sorting cannot repair repeated coordinates, so check once more.
    if _not_strictly_increasing(x_coords):
        raise ValueError("x coordinates must be strictly monotonic after sorting")
    if _not_strictly_increasing(y_coords):
        raise ValueError("y coordinates must be strictly monotonic after sorting")

    x_uniform = np.linspace(
        float(np.nanmin(x_coords)), float(np.nanmax(x_coords)), x_coords.size
    )
    y_uniform = np.linspace(
        float(np.nanmin(y_coords)), float(np.nanmax(y_coords)), y_coords.size
    )
    grid_x, grid_y = np.meshgrid(x_uniform, y_uniform, indexing="xy")
    # The interpolators are built on (y, x), so query points are (y, x) pairs.
    targets = np.column_stack([grid_y.ravel(), grid_x.ravel()])

    # Build every interpolator first, then evaluate them in the same order.
    interpolators = [
        RegularGridInterpolator(
            (y_coords, x_coords),
            field,
            method="linear",
            bounds_error=False,
            fill_value=np.nan,
        )
        for field in fields
    ]
    resampled = [
        interpolate(targets).reshape(grid_y.shape) for interpolate in interpolators
    ]
    return (x_uniform, y_uniform, *resampled)


def _array_curly_vector(
    axes: Any,
    x: Any,
    y: Any,
    u: Any,
    v: Any,
    density: Any = 1,
    linewidth: Any = None,
    linewidths: Any = None,
    color: Any = None,
    c: Any = None,
    cmap: Any = None,
    norm: Any = None,
    vmin: float | None = None,
    vmax: float | None = None,
    alpha: float | None = None,
    facecolor: Any = None,
    facecolors: Any = None,
    edgecolor: Any = None,
    edgecolors: Any = None,
    rasterized: bool | None = None,
    arrowsize: float = 1,
    arrowstyle: str = "->",
    transform: Any = None,
    zorder: float | None = None,
    start_points: Any = None,
    integration_direction: str = "both",
    grains: Any = 15,
    broken_streamlines: bool = True,
    allow_non_uniform_grid: bool = False,
    anchor: str | None = None,
    pivot: str | None = None,
    ref_magnitude: float | None = None,
    ref_length: float | None = None,
    min_frac_length: float = 0.0,
    min_distance: float | None = None,
    ncl_preset: str | None = None,
) -> CurlyVectorPlotSet:
    """
    Draw NCL-like curved vector glyphs for a 2D vector flow.

    The function resolves style aliases and presets, brings the input grid
    into the form the renderer expects (orientation normalization for regular
    grids, resampling onto uniform axes for non-uniform grids) and then
    delegates to ``_curly_vector_ncl``.

    Parameters
    ----------
    x, y : 1D/2D arrays
        Evenly spaced strictly increasing arrays to make a grid.  If 2D, all
        rows of *x* must be equal and all columns of *y* must be equal; i.e.,
        they must be as if generated by ``np.meshgrid(x_1d, y_1d)``.
        For non-uniform grids (e.g., vertical profiles), set allow_non_uniform_grid=True.
    u, v : 2D arrays
        *x* and *y*-velocities. The number of rows and columns must match
        the length of *y* and *x*, respectively.
    density : float or (float, float)
        Controls the closeness of streamlines. When ``density = 1``, the domain
        is divided into a 30x30 grid. *density* linearly scales this grid.
        Each cell in the grid can have, at most, one traversing streamline.
        For different densities in each direction, use a tuple
        (density_x, density_y).
    linewidth : float or 2D array
        The width of the streamlines. With a 2D array the line width can be
        varied across the grid. The array must have the same shape as *u*
        and *v*.
    linewidths : float or 2D array, optional
        Matplotlib ``quiver``-style alias for ``linewidth``.
    color : color or 2D array
        The streamline color. If given an array, its values are converted to
        colors using *cmap* and *norm*.  The array must have the same shape
        as *u* and *v*.
    c : color or 2D array, optional
        Matplotlib-style alias for ``color``.
    cmap, norm
        Data normalization and colormapping parameters for *color*; only used
        if *color* is an array of floats. See `~.Axes.imshow` for a detailed
        description.
    vmin, vmax : float, optional
        Lower and upper normalization bounds used when ``color``/``c`` is a
        scalar field and ``norm`` is omitted.
    alpha : float, optional
        Matplotlib artist alpha applied to both the curved shafts and the
        arrow heads.
    facecolor, edgecolor : color-like, optional
        Explicit arrow-head fill and edge colors, similar to
        ``matplotlib.pyplot.quiver``. These mainly affect the filled
        ``arrowstyle="-|>"`` head. When omitted, the resolved shaft color is
        reused. Open ``"->"`` heads remain line-based and therefore ignore
        ``facecolor``.
    facecolors, edgecolors : color-like, optional
        Matplotlib-style aliases for ``facecolor`` and ``edgecolor``.
    rasterized : bool, optional
        Whether to rasterize the generated curly-vector artists when exporting
        to vector formats such as PDF or SVG. This changes output rendering,
        not the underlying curly-vector algorithm.
    arrowsize : float
        Scaling factor for the arrow size.
    arrowstyle : str
        Supported arrow-head style. Use ``"->"`` for the open NCL-like line
        head or ``"-|>"`` for a filled triangular head.
    transform : Transform, optional
        Coordinate transformation for the plot. Defaults to axes.transData.
    zorder : float
        The zorder of the streamlines and arrows.
        Artists with lower zorder values are drawn first.
    start_points : (N, 2) array
        Coordinates of starting points for the streamlines in data coordinates
        (the same coordinates as the *x* and *y* arrays).
    integration_direction : {'forward', 'backward', 'both'}, default: 'both'
        Integrate the streamline in forward, backward or both directions.
    grains : int, default: 15
        Number of grains used in streamline integration.
    broken_streamlines : boolean, default: True
        If False, forces streamlines to continue until they
        leave the plot domain.  If True, they may be terminated if they
        come too close to another streamline.
    allow_non_uniform_grid : boolean, default: False
        If True, allows non-uniform grids like vertical profiles. The function
        will attempt to create a uniform interpolation grid for streamline calculation.
        The flag is switched on automatically whenever the ``'profile'``
        preset is requested or auto-selected.
    anchor : {'tail', 'center', 'head'} or None, default: None
        Anchor point for the NCL-like curved-glyph renderer. If omitted, the anchor is
        inferred from ``integration_direction``: ``'forward'`` -> ``'tail'``,
        ``'backward'`` -> ``'head'``, ``'both'`` -> ``'center'``.
    pivot : {'tail', 'mid', 'middle', 'tip'} or None, default: None
        Matplotlib ``quiver``-style alias for ``anchor``. ``'mid'`` and
        ``'middle'`` map to ``'center'`` and ``'tip'`` maps to ``'head'``.
    ref_magnitude : float or None, default: None
        Reference magnitude used when mapping a
        physical vector magnitude to a display-space glyph length. If omitted,
        the maximum field magnitude is used. Under the ``'profile'`` preset an
        omitted value is instead estimated from a high percentile of the
        positive, finite magnitudes of the input field.
    ref_length : float or None, default: None
        Reference glyph length as a fraction of the axes width for
        the NCL-like curved-glyph renderer. If omitted, a NCL-like default scaled by
        ``arrowsize`` is used. Under the ``'profile'`` preset an omitted value
        defaults to ``0.06``.
    min_frac_length : float, default: 0.0
        Minimum glyph length as a fraction of the reference length for
        the NCL-like curved-glyph renderer.
    min_distance : float or None, default: None
        Minimum glyph-center spacing as a fraction of the axes width for
        the NCL-like curved-glyph renderer. If omitted, it is inferred from ``density``.
    ncl_preset : {None, 'profile'}, default: None
        Optional preset override for NCL-like glyph tuning. In most cases you
        can leave this as ``None``: regular lat-lon map grids keep the default
        map-style tuning, while non-uniform/profile-like grids are
        automatically promoted to the conservative ``'profile'`` preset.
        ``'vertical_profile'``, ``'vertical-profile'`` and ``'lat_pressure'``
        are accepted as spellings of ``'profile'`` (case-insensitive).

    Returns
    -------
    CurlyVectorPlotSet
        Container object with attributes

        - ``lines``: `.LineCollection` of streamlines

        - ``arrows``: tuple of the actual filled arrow-head patches added to
          the axes. Open arrow styles use line segments only and therefore
          return an empty tuple.

    Raises
    ------
    ValueError
        If ``ncl_preset`` is not a supported name, or -- on the non-uniform
        path -- if ``color`` or ``linewidth`` is treated as a field but could
        not be matched to the shape of ``u``.
    ImportError
        If the non-uniform path is taken and scipy is not installed.
    """
    # Collapse the Matplotlib-style aliases (c, linewidths, facecolors,
    # edgecolors) into the canonical names used for the rest of the function.
    color, linewidth, facecolor, edgecolor, vmin, vmax = _resolve_curly_style_aliases(
        color=color,
        c=c,
        linewidth=linewidth,
        linewidths=linewidths,
        facecolor=facecolor,
        facecolors=facecolors,
        edgecolor=edgecolor,
        edgecolors=edgecolors,
        norm=norm,
        vmin=vmin,
        vmax=vmax,
    )

    # Normalize an explicit preset, or choose one from the spacing of the
    # x/y axes when none was given.
    allow_non_uniform_grid, ncl_preset = _resolve_default_ncl_preset(
        x=x,
        y=y,
        allow_non_uniform_grid=allow_non_uniform_grid,
        ncl_preset=ncl_preset,
    )

    # Fill in preset-specific defaults; only the 'profile' preset alters
    # anything, and explicit user values are kept.
    allow_non_uniform_grid, ref_magnitude, ref_length, min_distance, ncl_preset = (
        _apply_ncl_preset_defaults(
            ncl_preset=ncl_preset,
            allow_non_uniform_grid=allow_non_uniform_grid,
            ref_magnitude=ref_magnitude,
            ref_length=ref_length,
            min_distance=min_distance,
            u=u,
            v=v,
        )
    )
    arrowstyle = _normalize_supported_arrowstyle(arrowstyle)
    alpha = _normalize_artist_alpha(alpha)
    # ``pivot`` is merged into ``anchor`` here and is not forwarded further.
    anchor = _resolve_curly_anchor_alias(anchor, pivot)

    # Regular-grid path: normalize the grid orientation, carrying the color
    # and linewidth inputs along with the vector components.
    if not allow_non_uniform_grid:
        x, y, u, v, color, linewidth = _normalize_regular_grid_orientation(
            x,
            y,
            u,
            v,
            color=color,
            linewidth=linewidth,
        )

    # Handle non-uniform grids by creating a uniform interpolation grid
    if allow_non_uniform_grid:
        expected_shape = np.shape(u)
        # Each helper call yields (field_or_None, is_field); a value flagged as
        # a field that could not be coerced is rejected below.
        color_field, color_is_field = _coerce_matching_plot_field(color, expected_shape)
        linewidth_field, linewidth_is_field = _coerce_matching_plot_field(
            linewidth, expected_shape
        )
        if color_field is None and color_is_field:
            raise ValueError(
                "If 'color' is given, it must match the shape of the (x, y) grid"
            )
        if linewidth_field is None and linewidth_is_field:
            raise ValueError(
                "If 'linewidth' is given, it must match the shape of the (x, y) grid"
            )

        # Only array-valued style fields are resampled, in the fixed order
        # (color, linewidth); scalar styles pass through untouched.
        regridded = _regrid_non_uniform_vectors_to_uniform(
            x,
            y,
            u,
            v,
            *([field for field in (color_field, linewidth_field) if field is not None]),
        )
        x, y, u, v, *scalar_fields = regridded
        # Hand the resampled style fields back in the same (color, linewidth)
        # order they were submitted in.
        scalar_iter = iter(scalar_fields)
        if color_field is not None:
            color = next(scalar_iter)
        if linewidth_field is not None:
            linewidth = next(scalar_iter)

    # NOTE: the forwarded keyword arguments are gathered from ``locals()`` using
    # the names in _CURLY_VECTOR_NCL_KWARG_NAMES, so the local variables above
    # must keep exactly those names (presumably a renamed local would be
    # dropped or fail -- confirm against ``_collect_named_kwargs``).
    return _curly_vector_ncl(
        axes,
        x,
        y,
        u,
        v,
        **_collect_named_kwargs(locals(), _CURLY_VECTOR_NCL_KWARG_NAMES),
    )


def _curly_vector_ncl(
    axes,
    x,
    y,
    u,
    v,
    density=1,
    linewidth=None,
    color=None,
    vmin=None,
    vmax=None,
    cmap=None,
    norm=None,
    alpha=None,
    facecolor=None,
    edgecolor=None,
    rasterized=None,
    arrowsize=1,
    arrowstyle="->",
    transform=None,
    zorder=None,
    start_points=None,
    integration_direction="both",
    grains=15,
    broken_streamlines=True,
    anchor=None,
    ref_magnitude=None,
    ref_length=None,
    min_frac_length=0.0,
    min_distance=None,
    allow_non_uniform_grid=False,
    ncl_preset=None,
):
    """Wire the native kernels into the vector engine and run the renderer.

    Each compiled ``nclcurly_native`` routine is bound to its
    ``_native_helpers`` call wrapper, the bindings are assembled into the hook
    callables taken by ``_vector_engine._curly_vector_ncl_impl``, and all plot
    options are passed to that implementation unchanged.

    Returns whatever ``_curly_vector_ncl_impl`` returns; ``CurlyVectorPlotSet``
    is supplied to it as ``result_cls``.
    """

    def sample_grid_field_array(grid, field, points):
        # A 1-D input is promoted to a single-row batch; an empty batch is
        # answered without calling into the native sampler.
        batch = np.asarray(points, dtype=float)
        if batch.ndim == 1:
            batch = batch[np.newaxis, :]
        n_points = len(batch)
        if n_points == 0:
            return np.empty(0, dtype=float)
        return _native_helpers._call_native_sample_grid_field_array(
            _sample_grid_field_array_native,
            grid,
            field,
            batch,
            (n_points,),
        )

    def evaluate_ncl_display_curve(
        curve, transform, viewport=None, display_curve=None
    ):
        # Keep the traced display-space curve only if it exists, the data
        # curve has at least two vertices and the native validator accepts it.
        # ``transform`` is accepted but not consulted.
        rejected = (None, False)
        if display_curve is None or len(curve) < 2:
            return rejected
        if not _validate_display_curve(display_curve, viewport):
            return rejected
        return display_curve, False

    # Glyph-center selection: native thinning plus the batched sampler above.
    select_ncl_centers_fn = partial(
        _vector_engine._select_ncl_centers,
        sample_grid_field_array=sample_grid_field_array,
        thin_ncl_mapped_candidates=partial(
            _native_helpers._call_native_thin_ncl_mapped_candidates,
            _thin_ncl_mapped_candidates_native,
        ),
        thin_ncl_display_candidates=partial(
            _native_helpers._call_native_thin_ncl_display_candidates,
            _thin_ncl_display_candidates_native,
        ),
    )

    # Curve construction: trace with display coordinates, then validate.
    build_ncl_curve_fn = partial(
        _vector_engine._build_ncl_curve,
        trace_ncl_curve_fn=_trace_ncl_curve_with_display,
        evaluate_ncl_display_curve_fn=evaluate_ncl_display_curve,
    )

    # Arrow-head builders for the open ("->") and filled ("-|>") styles.
    open_arrow_batch_fn = partial(
        _artist_helpers._build_open_arrow_segments_batch,
        build_open_arrow_segments_batch_fn=partial(
            _native_helpers._call_native_build_open_arrow_segments,
            _build_open_arrow_segments_native,
        ),
    )
    filled_arrow_batch_fn = partial(
        _artist_helpers._build_filled_arrow_polygons_batch,
        build_filled_arrow_polygons_batch_fn=partial(
            _native_helpers._call_native_build_filled_arrow_polygons,
            _build_filled_arrow_polygons_native,
        ),
        display_points_to_data_fn=_display_points_to_data,
    )

    plot_options = {
        "axes": axes,
        "x": x,
        "y": y,
        "u": u,
        "v": v,
        "density": density,
        "linewidth": linewidth,
        "color": color,
        "vmin": vmin,
        "vmax": vmax,
        "cmap": cmap,
        "norm": norm,
        "alpha": alpha,
        "facecolor": facecolor,
        "edgecolor": edgecolor,
        "rasterized": rasterized,
        "arrowsize": arrowsize,
        "arrowstyle": arrowstyle,
        "transform": transform,
        "zorder": zorder,
        "start_points": start_points,
        "integration_direction": integration_direction,
        "grains": grains,
        "broken_streamlines": broken_streamlines,
        "anchor": anchor,
        "ref_magnitude": ref_magnitude,
        "ref_length": ref_length,
        "min_frac_length": min_frac_length,
        "min_distance": min_distance,
        "allow_non_uniform_grid": allow_non_uniform_grid,
        "ncl_preset": ncl_preset,
    }
    engine_hooks = {
        "grid_cls": Grid,
        "prepare_ncl_display_sampler_fn": _prepare_ncl_display_sampler,
        "prepare_ncl_native_trace_context_fn": _prepare_ncl_native_trace_context,
        "select_ncl_centers_fn": select_ncl_centers_fn,
        "build_ncl_curve_fn": build_ncl_curve_fn,
        "sample_grid_field_fn": partial(
            _native_helpers._call_native_sample_grid_field,
            _sample_grid_field_native,
        ),
        "build_open_arrow_segments_batch_fn": open_arrow_batch_fn,
        "build_filled_arrow_polygons_batch_fn": filled_arrow_batch_fn,
        "display_points_to_data_fn": _display_points_to_data,
        "result_cls": CurlyVectorPlotSet,
    }
    return _vector_engine._curly_vector_ncl_impl(**plot_options, **engine_hooks)


def _prepare_ncl_native_trace_context(grid, u, v, viewport, display_sampler):
    if display_sampler is None:
        return None
    return _NCLNativeTraceContext(
        grid=grid,
        u=u,
        v=v,
        viewport=viewport,
        display_sampler=display_sampler,
    )


def _trace_ncl_curve_with_display(
    start_point,
    total_length_px,
    anchor,
    grid,
    u,
    v,
    transform,
    step_px,
    speed_scale,
    viewport,
    display_sampler=None,
    native_trace_context=None,
):
    del transform
    if total_length_px <= 0:
        return None

    if anchor == "center":
        backward = _trace_ncl_direction_with_display(
            start_point,
            total_length_px / 2.0,
            -1.0,
            grid,
            u,
            v,
            step_px,
            speed_scale,
            viewport,
            display_sampler=display_sampler,
            native_trace_context=native_trace_context,
        )
        forward = _trace_ncl_direction_with_display(
            start_point,
            total_length_px / 2.0,
            1.0,
            grid,
            u,
            v,
            step_px,
            speed_scale,
            viewport,
            display_sampler=display_sampler,
            native_trace_context=native_trace_context,
        )
        if backward is None and forward is None:
            return None
        if backward is None:
            return forward
        if forward is None:
            curve, display_curve = backward
            return curve[::-1], display_curve[::-1]

        backward_curve, backward_display = backward
        forward_curve, forward_display = forward
        return (
            np.vstack([backward_curve[::-1], forward_curve[1:]]),
            np.vstack([backward_display[::-1], forward_display[1:]]),
        )

    if anchor == "tail":
        return _trace_ncl_direction_with_display(
            start_point,
            total_length_px,
            1.0,
            grid,
            u,
            v,
            step_px,
            speed_scale,
            viewport,
            display_sampler=display_sampler,
            native_trace_context=native_trace_context,
        )

    backward = _trace_ncl_direction_with_display(
        start_point,
        total_length_px,
        -1.0,
        grid,
        u,
        v,
        step_px,
        speed_scale,
        viewport,
        display_sampler=display_sampler,
        native_trace_context=native_trace_context,
    )
    if backward is None:
        return None
    curve, display_curve = backward
    return curve[::-1], display_curve[::-1]


def _trace_ncl_direction_with_display(
    start_point,
    max_length_px,
    direction_sign,
    grid,
    u,
    v,
    step_px,
    speed_scale,
    viewport,
    display_sampler=None,
    native_trace_context=None,
):
    """Trace a single direction from *start_point* through the native tracer.

    When the caller supplies no trace context but a display sampler is
    available, a context is built from ``grid``, ``u``, ``v``, ``viewport`` and
    the sampler. The start point is converted to a float array before the
    native call wrapper is invoked.
    """
    origin = np.asarray(start_point, dtype=float)

    context = native_trace_context
    must_build_context = context is None and display_sampler is not None
    if must_build_context:
        context = _prepare_ncl_native_trace_context(
            grid, u, v, viewport, display_sampler
        )

    return _native_helpers._call_native_trace_ncl_direction_with_display(
        _trace_ncl_direction_native,
        context,
        origin,
        max_length_px,
        direction_sign,
        step_px,
        speed_scale,
    )


def _validate_display_curve(display_curve, viewport):
    """Run the native display-curve validator on *display_curve*.

    The result of ``_native_helpers._call_native_validate_display_curve`` is
    returned unchanged.
    """
    verdict = _native_helpers._call_native_validate_display_curve(
        _validate_display_curve_native, display_curve, viewport
    )
    return verdict


def _prepare_curly_vector_dataset_inputs(
    ax,
    x,
    y,
    u,
    v,
    transform,
    regrid_shape,
    curvilinear_regrid_shape,
    target_extent,
    color,
    linewidth,
    density,
):
    """Delegate input preparation to the entry adapter.

    The twelve arguments are forwarded positionally, in order, to
    ``_entry_adapter._prepare_curly_vector_dataset_inputs_impl`` together with
    this module's grid and Cartopy helpers as ``*_fn`` keyword arguments. The
    adapter's return value is passed back unchanged.
    """
    plot_inputs = (
        ax,
        x,
        y,
        u,
        v,
        transform,
        regrid_shape,
        curvilinear_regrid_shape,
        target_extent,
        color,
        linewidth,
        density,
    )
    grid_helpers = {
        "is_cartopy_crs_like_fn": _is_cartopy_crs_like,
        "maybe_as_scalar_field_fn": _maybe_as_scalar_field,
        "is_curvilinear_grid_fn": _is_curvilinear_grid,
        "normalize_regrid_shape_fn": _normalize_regrid_shape,
        "default_curvilinear_regrid_shape_fn": _default_curvilinear_regrid_shape,
        "regrid_curvilinear_vectors_fn": _regrid_curvilinear_vectors,
        "regrid_cartopy_vectors_fn": _regrid_cartopy_vectors,
        "default_cartopy_target_extent_fn": _default_cartopy_target_extent,
        "extract_regular_grid_from_regridded_vectors_fn": (
            _extract_regular_grid_from_regridded_vectors
        ),
    }
    return _entry_adapter._prepare_curly_vector_dataset_inputs_impl(
        *plot_inputs, **grid_helpers
    )


def _curly_vector_from_dataset(
    ds: xr.Dataset,
    x: Hashable,
    y: Hashable,
    u: Hashable,
    v: Hashable,
    ax: Axes | None = None,
    density: Any = 1,
    linewidth: Any = None,
    linewidths: Any = None,
    color: Any = None,
    c: Any = None,
    cmap: Any = None,
    norm: Any = None,
    vmin: float | None = None,
    vmax: float | None = None,
    alpha: float | None = None,
    facecolor: Any = None,
    facecolors: Any = None,
    edgecolor: Any = None,
    edgecolors: Any = None,
    rasterized: bool | None = None,
    arrowsize=1,
    arrowstyle="->",
    transform: Any = None,
    zorder: float | None = None,
    start_points: Any = None,
    integration_direction="both",
    grains=15,
    broken_streamlines=True,
    anchor: str | None = None,
    pivot: str | None = None,
    ref_magnitude: float | None = None,
    ref_length: float | None = None,
    min_frac_length=0.0,
    min_distance: float | None = None,
    ncl_preset: str | None = None,
    regrid_shape: Any = None,
    curvilinear_regrid_shape: Any = None,
    target_extent: Any = None,
    isel: Any = None,
) -> CurlyVectorPlotSet:
    """Plot NCL-like curly vectors from an xarray dataset.

    This is a thin delegation layer: the dataset, the four variable names and
    every plot option are forwarded to
    ``_entry_adapter._curly_vector_from_dataset_impl`` together with the
    helper callables defined in this module (alias resolution, dataset
    extraction, input preparation and the array-based plotting routine).
    """
    # Options taken from the caller, forwarded under their own names.
    plot_options = {
        "ax": ax,
        "density": density,
        "linewidth": linewidth,
        "linewidths": linewidths,
        "color": color,
        "c": c,
        "cmap": cmap,
        "norm": norm,
        "vmin": vmin,
        "vmax": vmax,
        "alpha": alpha,
        "facecolor": facecolor,
        "facecolors": facecolors,
        "edgecolor": edgecolor,
        "edgecolors": edgecolors,
        "rasterized": rasterized,
        "arrowsize": arrowsize,
        "arrowstyle": arrowstyle,
        "transform": transform,
        "zorder": zorder,
        "start_points": start_points,
        "integration_direction": integration_direction,
        "grains": grains,
        "broken_streamlines": broken_streamlines,
        "anchor": anchor,
        "pivot": pivot,
        "ref_magnitude": ref_magnitude,
        "ref_length": ref_length,
        "min_frac_length": min_frac_length,
        "min_distance": min_distance,
        "ncl_preset": ncl_preset,
        "regrid_shape": regrid_shape,
        "curvilinear_regrid_shape": curvilinear_regrid_shape,
        "target_extent": target_extent,
        "isel": isel,
    }
    # Module-level collaborators injected into the adapter implementation.
    adapter_hooks = {
        "resolve_curly_style_aliases_fn": _resolve_curly_style_aliases,
        "extract_curly_vector_dataset_source_fn": _extract_curly_vector_dataset_source,
        "prepare_dataset_style_field_fn": _prepare_dataset_style_field,
        "gca_fn": plt.gca,
        "prepare_curly_vector_dataset_inputs_fn": _prepare_curly_vector_dataset_inputs,
        "collect_named_kwargs_fn": _collect_named_kwargs,
        "array_curly_vector_kwarg_names": _ARRAY_CURLY_VECTOR_KWARG_NAMES,
        "array_curly_vector_fn": _array_curly_vector,
    }
    return _entry_adapter._curly_vector_from_dataset_impl(
        ds, x, y, u, v, **plot_options, **adapter_hooks
    )


def _curly_vector_from_arrays(
    ax: Axes,
    x: Any,
    y: Any,
    u: Any,
    v: Any,
    density: Any = 1,
    linewidth: Any = None,
    linewidths: Any = None,
    color: Any = None,
    c: Any = None,
    cmap: Any = None,
    norm: Any = None,
    vmin: float | None = None,
    vmax: float | None = None,
    alpha: float | None = None,
    facecolor: Any = None,
    facecolors: Any = None,
    edgecolor: Any = None,
    edgecolors: Any = None,
    rasterized: bool | None = None,
    arrowsize=1,
    arrowstyle="->",
    transform: Any = None,
    zorder: float | None = None,
    start_points: Any = None,
    integration_direction="both",
    grains=15,
    broken_streamlines=True,
    anchor: str | None = None,
    pivot: str | None = None,
    ref_magnitude: float | None = None,
    ref_length: float | None = None,
    min_frac_length=0.0,
    min_distance: float | None = None,
    ncl_preset: str | None = None,
    regrid_shape: Any = None,
    curvilinear_regrid_shape: Any = None,
    target_extent: Any = None,
) -> CurlyVectorPlotSet:
    """Delegate an array-style curly-vector call to the entry adapter.

    This function contains no plotting logic of its own. It hands ``ax``,
    ``x``, ``y``, ``u`` and ``v`` positionally to
    ``_entry_adapter._curly_vector_from_arrays_impl`` and forwards every
    remaining option as a keyword argument of the same name. The
    module-level helpers the implementation needs are passed alongside as
    ``*_fn`` / ``array_curly_vector_kwarg_names`` keywords.

    Returns
    -------
    CurlyVectorPlotSet
        Whatever the adapter implementation returns.
    """
    # Caller-supplied plotting options, forwarded untouched and in signature
    # order so the adapter receives exactly what the caller passed.
    plot_options = {
        "density": density,
        "linewidth": linewidth,
        "linewidths": linewidths,
        "color": color,
        "c": c,
        "cmap": cmap,
        "norm": norm,
        "vmin": vmin,
        "vmax": vmax,
        "alpha": alpha,
        "facecolor": facecolor,
        "facecolors": facecolors,
        "edgecolor": edgecolor,
        "edgecolors": edgecolors,
        "rasterized": rasterized,
        "arrowsize": arrowsize,
        "arrowstyle": arrowstyle,
        "transform": transform,
        "zorder": zorder,
        "start_points": start_points,
        "integration_direction": integration_direction,
        "grains": grains,
        "broken_streamlines": broken_streamlines,
        "anchor": anchor,
        "pivot": pivot,
        "ref_magnitude": ref_magnitude,
        "ref_length": ref_length,
        "min_frac_length": min_frac_length,
        "min_distance": min_distance,
        "ncl_preset": ncl_preset,
        "regrid_shape": regrid_shape,
        "curvilinear_regrid_shape": curvilinear_regrid_shape,
        "target_extent": target_extent,
    }

    # Collaborators resolved from this module's globals at call time and
    # injected into the adapter.
    # NOTE(review): presumably injected so the adapter need not import this
    # module back and so the helpers stay patchable here -- confirm.
    injected_helpers = {
        "resolve_curly_style_aliases_fn": _resolve_curly_style_aliases,
        "asarray_fn": np.asarray,
        "prepare_curly_vector_dataset_inputs_fn": (
            _prepare_curly_vector_dataset_inputs
        ),
        "collect_named_kwargs_fn": _collect_named_kwargs,
        "array_curly_vector_kwarg_names": _ARRAY_CURLY_VECTOR_KWARG_NAMES,
        "array_curly_vector_fn": _array_curly_vector,
    }

    return _entry_adapter._curly_vector_from_arrays_impl(
        ax,
        x,
        y,
        u,
        v,
        **plot_options,
        **injected_helpers,
    )


def curly_vector(*args: Any, **kwargs: Any) -> CurlyVectorPlotSet:
    """Plot NCL-like curly vectors from arrays or an xarray dataset.

    Parameters
    ----------
    ds : xarray.Dataset, optional
        Dataset source for the dataset-style call form. When ``ds`` is used
        as the first positional argument, ``x``, ``y``, ``u``, and ``v`` must
        be the corresponding coordinate and variable names inside that
        dataset.
    ax : matplotlib.axes.Axes, optional
        Target axes for the vector plot. If omitted,
        ``matplotlib.pyplot.gca()`` is used.
    x, y : array-like or hashable
        Coordinate definition for the vector field. For array-style calls,
        these may be 1D or 2D coordinate arrays. If 2D, they should describe
        a meshgrid-like layout matching the vector field. For dataset-style
        calls, they are the coordinate names inside ``ds``.
    u, v : array-like or hashable
        Vector components. For array-style calls, these are 2D numeric
        arrays aligned with the supplied coordinates. For dataset-style
        calls, they are variable names inside ``ds``.
    density : float or tuple of float, optional
        Controls the closeness of the rendered curly vectors. As in the
        low-level engine, ``density=1`` corresponds to the default NCL-like
        sampling density, and a tuple ``(density_x, density_y)`` can be used
        for anisotropic spacing.
    linewidth : float or 2D array, optional
        Width of the curved vector shafts. A field array must match the
        vector grid shape.
    linewidths : float or 2D array, optional
        Matplotlib ``quiver``-style alias for ``linewidth``.
    color : color-like or 2D array, optional
        Shaft color. If a 2D scalar field is supplied, Skyborn maps it
        through ``cmap`` and ``norm``.
    c : color-like or 2D array, optional
        Matplotlib-style alias for ``color``.
    cmap, norm
        Colormap and normalization controls used when ``color``/``c`` is a
        scalar field.
    vmin, vmax : float, optional
        Lower and upper normalization bounds used when ``color``/``c`` is a
        scalar field and ``norm`` is omitted.
    alpha : float, optional
        Matplotlib artist alpha applied to both the curved shafts and the
        arrow heads.
    facecolor, edgecolor : color-like, optional
        Explicit arrow-head fill and edge colors. When omitted, the resolved
        shaft color is reused.
    facecolors, edgecolors : color-like, optional
        Matplotlib-style aliases for ``facecolor`` and ``edgecolor``.
    rasterized : bool, optional
        Whether to rasterize the generated curly-vector artists when
        exporting to vector formats such as PDF or SVG.
    arrowsize : float, optional
        Scaling factor for the arrow size.
    arrowstyle : str, optional
        Supported arrow-head style. Use ``"->"`` for the open NCL-like line
        head or ``"-|>"`` for a filled triangular head.
    transform : optional
        Coordinate transformation for the plot. Standard Matplotlib
        transforms are forwarded directly. Cartopy CRS-like objects are
        normalized internally.
    zorder : float, optional
        Z-order of the curved shafts and arrow heads.
    start_points : (N, 2) array-like, optional
        Explicit seed points for the curly-vector glyphs in data
        coordinates.
    integration_direction : {'forward', 'backward', 'both'}, optional
        Integrate the glyph shape in the forward, backward, or both
        directions.
    grains : int, optional
        Number of grains used during streamline-style integration.
    broken_streamlines : bool, optional
        If ``False``, forces traces to continue until they leave the domain.
        If ``True``, traces may terminate early when they come too close to
        another glyph.
    anchor : {'tail', 'center', 'head'} or None, optional
        Anchor point for the NCL-like curved-glyph renderer. If omitted, the
        anchor is inferred from ``integration_direction``.
    pivot : {'tail', 'mid', 'middle', 'tip'} or None, optional
        Matplotlib ``quiver``-style alias for ``anchor``.
    ref_magnitude : float, optional
        Reference magnitude used when mapping a physical vector magnitude to
        a display-space glyph length. If omitted, the maximum field
        magnitude is used.
    ref_length : float, optional
        Reference glyph length as a fraction of the axes width. If omitted,
        a NCL-like default scaled by ``arrowsize`` is used.
    min_frac_length : float, optional
        Minimum glyph length as a fraction of the reference length.
    min_distance : float, optional
        Minimum glyph-center spacing as a fraction of the axes width. If
        omitted, it is inferred from ``density``.
    ncl_preset : {None, 'profile'}, optional
        Optional preset override for NCL-like glyph tuning. In most cases
        you can leave this as ``None`` and let Skyborn choose the default
        behavior.
    regrid_shape : int or (int, int), optional
        Target shape for projection-aware Cartopy regridding. This requires
        a Cartopy CRS-like ``transform`` and a GeoAxes with a projection.
    curvilinear_regrid_shape : int or (int, int), optional
        Target shape used when the source coordinates describe a
        curvilinear grid. If omitted, Skyborn infers a conservative default
        from the source grid and ``density``.
    target_extent : tuple of float, optional
        Explicit ``(xmin, xmax, ymin, ymax)`` target extent used during
        projection-aware regridding.
    isel : mapping or tuple, optional
        Dataset-only selection forwarded before extracting ``x``, ``y``,
        ``u``, and ``v`` from ``ds``.
    **kwargs
        Additional style and artist options supported by the public
        wrapper.

    Returns
    -------
    CurlyVectorPlotSet
        Container object with attributes

        - ``lines``: `.LineCollection` of the curved vector shafts.
        - ``arrows``: tuple of arrow-head artists added to the axes. Open
          arrow styles use line segments only and may therefore return an
          empty tuple.

    Raises
    ------
    TypeError
        If no positional argument is given, or if the positional arguments
        match none of the supported call styles.

    Notes
    -----
    Supported call styles

    - ``curly_vector(ax, x, y, u, v, ...)``
    - ``curly_vector(x, y, u, v, ..., ax=ax)``
    - ``curly_vector(x, y, u, v, ...)``
    - ``curly_vector(ds, x="lon", y="lat", u="u", v="v", ax=ax, ...)``
    """
    # Dispatch is driven entirely by the first positional argument, so a
    # call without positionals cannot be classified.
    if not args:
        raise TypeError(
            "curly_vector() expects either (ax, x, y, u, v, ...) or "
            "(ds, x='...', y='...', u='...', v='...', ...)"
        )

    leading = args[0]

    # Axes-first form: the axes already occupies the ``ax`` slot.
    if _looks_like_axes(leading):
        return _curly_vector_from_arrays(*args, **kwargs)

    # Dataset-first form: x/y/u/v are names resolved inside the dataset.
    if isinstance(leading, xr.Dataset):
        return _curly_vector_from_dataset(*args, **kwargs)

    # Anything else must be the bare-array form, which needs x, y, u and v
    # as positionals.
    if len(args) < 4:
        raise TypeError(
            "Unsupported arguments for curly_vector(). Expected either "
            "(ax, x, y, u, v, ...), (x, y, u, v, ...), or "
            "(ds, x='...', y='...', u='...', v='...', ...)."
        )

    # ``ax`` is removed from kwargs because it is re-supplied positionally;
    # an explicit ``ax=None`` is treated the same as an omitted one.
    target_axes = kwargs.pop("ax", None)
    if target_axes is None:
        target_axes = plt.gca()
    return _curly_vector_from_arrays(target_axes, *args, **kwargs)


def curly_vector_key(
    *args: Any,
    **kwargs: Any,
) -> CurlyVectorKey:
    """Add an NCL-like reference-vector annotation to axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Target axes for the reference key. If omitted,
        ``matplotlib.pyplot.gca()`` is used.
    curly_vector_set : CurlyVectorPlotSet
        Plot result returned by :func:`curly_vector`. The key reads glyph
        scale and reference-length information from this object.
    U : float, optional
        Reference vector magnitude to display in the key. Defaults to
        ``2.0``.
    units : str, optional
        Units label displayed with the reference magnitude. Defaults to
        ``"m/s"``.
    label : str, optional
        Optional custom label text for the annotation.
    description : str, optional
        Optional secondary description line shown with the reference key.
    width, height : float, optional
        Size of the reference-key box in axes coordinates.
    loc : {"lower left", "lower right", "upper left", "upper right"}, optional
        Preset box location used when explicit ``x``/``y`` axes coordinates
        are not provided. Defaults to ``"lower right"``.
    x, y : float, optional
        Explicit box anchor position in axes coordinates. These two
        parameters must be provided together and override ``loc``.
    labelpos : {"N", "S", "E", "W"}, optional
        Relative label placement around the reference arrow. Defaults to
        ``"N"``.
    max_arrow_length : float, optional
        Maximum arrow length used by the fallback reference-arrow geometry.
    arrow_props, patch_props, text_props : dict, optional
        Styling dictionaries forwarded to the low-level arrow, frame, and
        text artists.
    padding : float, optional
        Internal padding between the frame and the reference-arrow/text
        layout.
    margin : float, optional
        Margin from the axes edge when ``loc`` is used.
    reference_speed : float, optional
        Fallback reference speed used when the plot set cannot provide a
        usable scale mapping.
    center_label : bool, optional
        Whether to center the main label rather than using the default
        north/south/east/west placement logic.
    frameon : bool, optional
        Whether to draw the surrounding reference-key frame.
    show_description : bool, optional
        Whether to render the description text when one is available.
    **kwargs
        Additional keyword arguments forwarded to :class:`CurlyVectorKey`.

    Returns
    -------
    CurlyVectorKey
        The reference-key artist added to the axes.

    Raises
    ------
    TypeError
        If ``curly_vector_set`` is supplied neither positionally nor by
        keyword, or if more positional arguments are given than
        ``(ax, curly_vector_set, U, units)``.

    Notes
    -----
    Supported call styles

    - ``curly_vector_key(ax, curly_vector_set, U=2.0, ...)``
    - ``curly_vector_key(curly_vector_set, U=2.0, ax=ax, ...)``
    - ``curly_vector_key(curly_vector_set, ...)``
    """
    if not args and "curly_vector_set" not in kwargs:
        raise TypeError(
            "curly_vector_key() expects either (ax, curly_vector_set, ...) or "
            "(curly_vector_set, ...)"
        )

    # Pull the explicitly forwarded keywords out of ``kwargs`` first so the
    # final ``**kwargs`` splat cannot pass them a second time.
    ax_kwarg = kwargs.pop("ax", None)
    units = kwargs.pop("units", "m/s")
    label = kwargs.pop("label", None)
    loc = kwargs.pop("loc", "lower right")
    labelpos = kwargs.pop("labelpos", "N")

    # A leading axes-like positional takes precedence over ``ax=``; only
    # without one do we fall back to the keyword and then the current axes.
    if args and _looks_like_axes(args[0]):
        ax = args[0]
        positional = args[1:]
    else:
        positional = args
        if ax_kwarg is not None:
            ax = ax_kwarg
        else:
            ax = plt.gca()

    # ``positional`` is now (curly_vector_set[, U[, units]]).
    if positional:
        curly_vector_set = positional[0]
    elif "curly_vector_set" in kwargs:
        curly_vector_set = kwargs.pop("curly_vector_set")
    else:
        raise TypeError("curly_vector_key() missing required curly_vector_set argument")

    # The ``U`` keyword is consumed only when no positional magnitude exists.
    if len(positional) > 1:
        reference_magnitude = float(positional[1])
    else:
        reference_magnitude = float(kwargs.pop("U", 2.0))

    # A positional units value overrides the keyword/default picked up above.
    if len(positional) > 2:
        units = str(positional[2])

    # Checked last so conversion errors above surface before this one.
    if len(positional) > 3:
        raise TypeError("curly_vector_key() received too many positional arguments")

    return CurlyVectorKey(
        ax=ax,
        curly_vector_set=curly_vector_set,
        U=reference_magnitude,
        units=units,
        label=label,
        loc=loc,
        labelpos=labelpos,
        **kwargs,
    )