Source code for numpydantic.ndarray
"""
Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
.. note::
This module should *only* have the :class:`.NDArray` class in it, because the
type stub ``ndarray.pyi`` is only created for :class:`.NDArray` . Otherwise,
type checkers will complain about using any helper functions elsewhere -
those all belong in :mod:`numpydantic.schema` .
Keeping with nptyping's style, NDArrayMeta is in this module even if it's
excluded from the type stub.
"""
from typing import TYPE_CHECKING, Any, Literal, Union, get_origin
import numpy as np
from pydantic import GetJsonSchemaHandler
from pydantic_core import core_schema
from numpydantic.dtype import DType
from numpydantic.exceptions import InterfaceError
from numpydantic.interface import Interface
from numpydantic.maps import python_to_nptyping
from numpydantic.schema import (
get_validate_interface,
make_json_schema,
)
from numpydantic.serialization import jsonize_array
from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation.dtype import is_union
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
from numpydantic.vendor.nptyping.nptyping_type import NPTypingType
from numpydantic.vendor.nptyping.structure import Structure
from numpydantic.vendor.nptyping.structure_expression import check_type_names
from numpydantic.vendor.nptyping.typing_ import (
dtype_per_name,
)
if TYPE_CHECKING: # pragma: no cover
from pydantic._internal._schema_generation_shared import (
CallbackGetCoreSchemaHandler,
)
from numpydantic import Shape
from numpydantic.vendor.nptyping.base_meta_classes import SubscriptableMeta
[docs]
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
    """
    Hooking into nptyping's array metaclass to override methods pending
    completion of the transition away from nptyping
    """

    if TYPE_CHECKING:  # pragma: no cover
        __getitem__ = SubscriptableMeta.__getitem__

    def __call__(cls, val: NDArrayType) -> NDArrayType:
        """
        Call ndarray as a validator function.

        Args:
            val (:class:`.NDArrayType`): The value to validate.

        Returns:
            Whatever the validator built by :func:`.get_validate_interface`
            for this class's ``(shape, dtype)`` constraints returns.
        """
        shape, dtype = cls.__args__
        return get_validate_interface(shape, dtype)(val)

    def __instancecheck__(self, instance: Any) -> bool:
        """
        Extended type checking that determines whether

        1) the ``type`` of the given instance is one of those in
           :meth:`.Interface.input_types`

        but also

        2) it satisfies the constraints set on the :class:`.NDArray` annotation

        Args:
            instance (:class:`typing.Any`): Thing to check!

        Returns:
            bool: ``True`` if matches constraints, ``False`` otherwise.
        """
        shape, dtype = self.__args__
        try:
            interface_cls = Interface.match(instance, fast=True)
            interface_cls(shape, dtype).validate(instance)
        except InterfaceError:
            # any InterfaceError while matching or validating means
            # "not an instance" rather than an error for the caller
            return False
        return True

    def _is_literal_like(cls, item: Any) -> bool:
        """
        Changes from nptyping:

        - instead of just duck-typing for a literal, actually check that
          ``item`` is a ``typing.Literal[...]``
        """
        return get_origin(item) is Literal

    def _get_shape(cls, dtype_candidate: Any) -> "Shape":
        """
        Override of base method to use our local definition of shape.

        Args:
            dtype_candidate: The candidate *shape* argument (despite the
                parameter name).

        Returns:
            ``typing.Any`` for ``Any`` or a bare ``Shape``; otherwise a
            (parameterized) ``Shape``.

        Raises:
            InvalidArgumentsError: if the argument is not ``typing.Any``, a
                ``Shape`` class, or a ``Literal[<ShapeExpression>]``.
        """
        from numpydantic.validation.shape import Shape

        if dtype_candidate is Any or dtype_candidate is Shape:
            shape = Any
        elif isinstance(dtype_candidate, type) and issubclass(
            dtype_candidate, Shape
        ):
            # ``issubclass`` is only meaningful for classes; the guard lets
            # non-class arguments (``Literal[...]`` aliases, stray values)
            # reach the branches below instead of raising a bare TypeError.
            shape = dtype_candidate
        elif cls._is_literal_like(dtype_candidate):
            shape_expression = dtype_candidate.__args__[0]
            shape = Shape[shape_expression]
        else:
            raise InvalidArgumentsError(
                f"Unexpected argument '{dtype_candidate}', expecting"
                " Shape[<ShapeExpression>]"
                " or Literal[<ShapeExpression>]"
                " or typing.Any."
            )
        return shape

    def _get_dtype(cls, dtype_candidate: Any) -> DType:
        """
        Override of base _get_dtype method to allow for compound tuple types.

        Candidates found in ``python_to_nptyping`` are first translated through
        it. Then:

        - ``Any``, numpy scalar classes and unions are returned as-is
        - ``Structure`` classes and ``Literal`` structure expressions are
          checked with ``check_type_names``
        - tuples are resolved element-wise
        - anything else is passed through unchanged, so validation can fail
          later
        """
        if dtype_candidate in python_to_nptyping:
            dtype_candidate = python_to_nptyping[dtype_candidate]
        is_dtype = isinstance(dtype_candidate, type) and issubclass(
            dtype_candidate, np.generic
        )

        if dtype_candidate is Any:
            dtype = Any
        elif is_dtype or is_union(dtype_candidate):
            dtype = dtype_candidate
        elif isinstance(dtype_candidate, type) and issubclass(
            dtype_candidate, Structure
        ):  # pragma: no cover
            # guarded like ``is_dtype`` above: ``issubclass`` on a non-class
            # (a tuple, a Literal alias, an arbitrary object) may raise
            # TypeError and hide the branches below
            dtype = dtype_candidate
            check_type_names(dtype, dtype_per_name)
        elif cls._is_literal_like(dtype_candidate):  # pragma: no cover
            structure_expression = dtype_candidate.__args__[0]
            dtype = Structure[structure_expression]
            check_type_names(dtype, dtype_per_name)
        elif isinstance(dtype_candidate, tuple):  # pragma: no cover
            dtype = tuple(cls._get_dtype(dt) for dt in dtype_candidate)
        else:
            # arbitrary dtype - allow failure elsewhere :)
            dtype = dtype_candidate
        return dtype

    def _dtype_to_str(cls, dtype: Any) -> str:
        """
        Render a dtype as a string.

        Returns ``"Any"`` for ``typing.Any``, a comma-separated list for a
        tuple of dtypes, and ``str(dtype)`` for everything else.
        """
        if dtype is Any:
            return "Any"
        if isinstance(dtype, tuple):
            return ", ".join(str(dt) for dt in dtype)
        # Structure classes, numpy types, unions and arbitrary dtypes all
        # render via ``str``; no ``issubclass`` check is needed, so non-class
        # dtypes cannot raise here.
        return str(dtype)
[docs]
class NDArray(NPTypingType, metaclass=NDArrayMeta):
    """
    Constrained array type allowing nptyping syntax for dtype and shape validation
    and serialization.

    This class is not intended to be instantiable, and support for static type
    checking is limited. Instead, it implements the
    ``__get_pydantic_core_schema__`` method to invoke the relevant
    :ref:`interface <Interfaces>` for validation and serialization.

    It is callable, however, which validates and attempts to coerce input to a
    supported array type.

    There is no such thing as an "NDArray instance," but one can think of it
    as a validating passthrough callable.

    References:
        - https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
    """

    # (shape, dtype) constraints; a bare ``NDArray`` is unconstrained
    __args__: tuple[ShapeType, DtypeType] = (Any, Any)

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Union["NDArray", Any],
        _handler: "CallbackGetCoreSchemaHandler",
    ) -> core_schema.CoreSchema:
        """
        Build the pydantic core schema for this array annotation.

        - Validation uses the function returned by ``get_validate_interface``
          for this class's shape and dtype.
        - JSON-mode serialization uses ``jsonize_array``.
        - The schema from ``make_json_schema`` is stored in ``metadata`` for
          :meth:`__get_pydantic_json_schema__` to render.

        When ``_source_type``'s class is not :class:`NDArrayMeta`, values must
        first pass an ``is_instance`` check against ``_source_type``. If it
        defines ``proxy_for``, the check also accepts the ``input_types`` of
        the interface it returns.
        """
        shape: ShapeType
        dtype: DtypeType
        shape, dtype = cls.__args__

        serialization = core_schema.plain_serializer_function_ser_schema(
            jsonize_array, when_used="json", info_arg=True
        )
        # make core schema for json schema, store it and any model definitions
        # so that we can use them when rendering json schema
        json_schema = make_json_schema(shape, dtype, _handler)

        if getattr(_source_type, "__class__", None) is NDArrayMeta:
            # annotated directly with an NDArray[...] type: validate only
            return core_schema.with_info_plain_validator_function(
                get_validate_interface(shape, dtype),
                serialization=serialization,
                metadata=json_schema,
            )

        if hasattr(_source_type, "proxy_for"):
            interface: type[Interface] = _source_type.proxy_for()
            isinstance_schema = core_schema.union_schema(
                [
                    core_schema.is_instance_schema(itype)
                    for itype in [*interface.input_types, _source_type]
                ]
            )
        else:
            isinstance_schema = core_schema.is_instance_schema(_source_type)

        return core_schema.chain_schema(
            [
                isinstance_schema,
                core_schema.with_info_plain_validator_function(
                    get_validate_interface(shape, dtype)
                ),
            ],
            serialization=serialization,
            metadata=json_schema,
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> core_schema.JsonSchema:
        """
        Render the JSON schema from the core schema stored in ``metadata``.

        If the dtype is a single named object whose module is not ``builtins``,
        ``typing`` or ``types`` (e.g. a numpy scalar class such as
        ``np.float32``), its qualified ``module.name`` is also recorded under
        the ``"dtype"`` key.
        """
        _, dtype = cls.__args__
        json_schema = handler(schema["metadata"])
        json_schema = handler.resolve_ref_schema(json_schema)

        # Read both attributes defensively: arbitrary (non-class) dtypes may
        # lack ``__module__`` and/or ``__name__``, and tuples of dtypes have
        # no single qualified name.
        module = getattr(dtype, "__module__", None)
        name = getattr(dtype, "__name__", None)
        if (
            not isinstance(dtype, tuple)
            and isinstance(module, str)
            and name is not None
            and module not in ("builtins", "typing", "types")
        ):
            json_schema["dtype"] = ".".join([module, name])
        return json_schema