Source code for numpydantic.ndarray

"""
Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization

.. note::

    This module should *only* have the :class:`.NDArray` class in it, because the
    type stub ``ndarray.pyi`` is only created for :class:`.NDArray` . Otherwise,
    type checkers will complain about using any helper functions elsewhere -
    those all belong in :mod:`numpydantic.schema` .

    Keeping with nptyping's style, NDArrayMeta is in this module even if it's
    excluded from the type stub.

"""

from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Protocol,
    TypeVar,
    _ProtocolMeta,
    get_args,
    get_origin,
    runtime_checkable,
)

import numpy as np
from pydantic import GetJsonSchemaHandler
from pydantic_core import core_schema

from numpydantic.dtype import DType
from numpydantic.exceptions import InterfaceError
from numpydantic.interface import Interface
from numpydantic.maps import python_to_nptyping
from numpydantic.schema import (
    get_validate_interface,
    make_json_schema,
)
from numpydantic.serialization import jsonize_array
from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation.dtype import is_union
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
from numpydantic.vendor.nptyping.structure import Structure
from numpydantic.vendor.nptyping.structure_expression import check_type_names
from numpydantic.vendor.nptyping.typing_ import (
    dtype_per_name,
)

if TYPE_CHECKING:  # pragma: no cover
    from pydantic._internal._schema_generation_shared import (
        CallbackGetCoreSchemaHandler,
    )

    from numpydantic import Shape


def _is_literal_like(item: Any) -> bool:
    """
    Changes from nptyping:
    - doesn't just ducktype for literal but actually, yno, checks for being literal
    """
    return get_origin(item) is Literal


def _get_shape(dtype_candidate: Any) -> Shape:
    """
    Override of base method to use our local definition of shape
    """
    from numpydantic.validation.shape import Shape

    if dtype_candidate is Any or dtype_candidate is Shape:
        shape = Any
    elif issubclass(dtype_candidate, Shape):
        shape = dtype_candidate
    elif _is_literal_like(dtype_candidate):
        shape_expression = dtype_candidate.__args__[0]
        shape = Shape[shape_expression]
    elif get_origin(dtype_candidate) is tuple:
        shape = Shape(*get_args(dtype_candidate))
    else:
        raise InvalidArgumentsError(
            f"Unexpected argument '{dtype_candidate}', expecting"
            " Shape[<ShapeExpression>]"
            " or Literal[<ShapeExpression>]"
            " or typing.Any."
        )
    return shape


def _get_dtype(dtype_candidate: Any) -> DType:
    """
    Override of base _get_dtype method to allow for compound tuple types
    """
    if dtype_candidate in python_to_nptyping:
        dtype_candidate = python_to_nptyping[dtype_candidate]
    is_dtype = isinstance(dtype_candidate, type) and issubclass(
        dtype_candidate, np.generic
    )

    if dtype_candidate is Any:
        dtype = Any
    elif is_dtype or is_union(dtype_candidate):
        dtype = dtype_candidate
    elif issubclass(dtype_candidate, Structure):  # pragma: no cover
        dtype = dtype_candidate
        check_type_names(dtype, dtype_per_name)
    elif _is_literal_like(dtype_candidate):  # pragma: no cover
        structure_expression = dtype_candidate.__args__[0]
        dtype = Structure[structure_expression]
        check_type_names(dtype, dtype_per_name)
    elif isinstance(dtype_candidate, tuple):  # pragma: no cover
        dtype = tuple([_get_dtype(dt) for dt in dtype_candidate])
    else:
        # arbitrary dtype - allow failure elsewhere :)
        dtype = dtype_candidate

    return dtype


TShape = TypeVar("TShape")
TDType = TypeVar("TDType")


[docs] class NDArrayMeta(_ProtocolMeta): """ Metaclass to provide class-level methods to NDArray protocol without suggesting they are part of the protocol definition. """ __args__: tuple[ShapeType, DtypeType] = (Any, Any)
[docs] def __call__(cls, val: NDArrayType) -> NDArrayType: """Call ndarray as a validator function""" return get_validate_interface(cls.__args__[0], cls.__args__[1])(val)
[docs] def __instancecheck__(self, instance: Any): """ Extended type checking that determines whether 1) the ``type`` of the given instance is one of those in :meth:`.Interface.input_types` but also 2) it satisfies the constraints set on the :class:`.NDArray` annotation Args: instance (:class:`typing.Any`): Thing to check! Returns: bool: ``True`` if matches constraints, ``False`` otherwise. """ shape, dtype = self.__args__ try: interface_cls = Interface.match(instance, fast=True) interface = interface_cls(shape, dtype) _ = interface.validate(instance) return True except InterfaceError: return False
def _dtype_to_str(cls, dtype: Any) -> str: if dtype is Any: result = "Any" elif issubclass(dtype, Structure): result = str(dtype) elif isinstance(dtype, tuple): result = ", ".join([str(dt) for dt in dtype]) else: result = str(dtype) return result def __getitem__(cls, args: type[Any] | tuple[type[Any], type[Any]]): if not isinstance(args, tuple) or (isinstance(args, tuple) and len(args) == 1): # just shape passed shape = args if not isinstance(args, TypeVar) else Any dtype = Any else: shape = args[0] if not isinstance(args[0], TypeVar) else Any dtype = args[1] if not isinstance(args[0], TypeVar) else Any shape = _get_shape(shape) dtype = _get_dtype(dtype) return type(cls.__name__, (cls,), {**cls.__dict__, "__args__": (shape, dtype)}) def __get_pydantic_core_schema__( cls, _source_type: NDArray | Any, _handler: CallbackGetCoreSchemaHandler, ) -> core_schema.CoreSchema: shape, dtype = cls.__args__ shape: ShapeType dtype: DtypeType serialization = core_schema.plain_serializer_function_ser_schema( jsonize_array, when_used="json", info_arg=True ) # make core schema for json schema, store it and any model definitions # so that we can use them when rendering json schema json_schema = make_json_schema(shape, dtype, _handler) if ( not hasattr(_source_type, "__class__") or _source_type.__class__ is not NDArrayMeta ): if hasattr(_source_type, "proxy_for"): interface: type[Interface] = _source_type.proxy_for() isinstance_schema = core_schema.union_schema( [ core_schema.is_instance_schema(itype) for itype in list(interface.input_types) + [_source_type] ] ) else: isinstance_schema = core_schema.is_instance_schema(_source_type) return core_schema.chain_schema( [ isinstance_schema, core_schema.with_info_plain_validator_function( get_validate_interface(shape, dtype) ), ], serialization=serialization, metadata=json_schema, ) else: return core_schema.with_info_plain_validator_function( get_validate_interface(shape, dtype), serialization=serialization, metadata=json_schema, ) def __get_pydantic_json_schema__( cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler ) -> core_schema.JsonSchema: shape, dtype = cls.__args__ json_schema = handler(schema["metadata"]) json_schema = handler.resolve_ref_schema(json_schema) if ( not isinstance(dtype, tuple) and dtype.__module__ not in ( "builtins", "typing", "types", ) and hasattr(dtype, "__name__") ): json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__]) return json_schema
[docs] @runtime_checkable class NDArray(Protocol[TShape, TDType], metaclass=NDArrayMeta): """ Constrained array type allowing npytyping syntax for dtype and shape validation and serialization. This class is not intended to be instantiable, and support for static type checking is limited, it implements the ``__get_pydantic_core_schema__`` method to invoke the relevant :ref:`interface <Interfaces>` for validation and serialization. It is callable, however, which validates and attempts to coerce input to a supported array type. There is no such thing as an "NDArray instance," but one can think of it as a validating passthrough callable. References: - https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types """ shape = np.ndarray.shape dtype = np.ndarray.dtype def __getitem__(self: Any, key: Any) -> Any: ... def __setitem__(self: Any, key: Any, value: Any) -> Any: ...