Source code for numpydantic.validation.dtype

"""
Helper functions for validation of dtype.

For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype`
"""

from types import UnionType
from typing import Any, Union, get_args, get_origin

import numpy as np

from numpydantic.types import DtypeType


def validate_dtype(dtype: Any, target: DtypeType) -> bool:
    """
    Validate a dtype against the target dtype.

    If `dtype` or `target` are `Any`, validation passes trivially.
    The `dtype` may be `Any` when the dtype can't be determined,
    but failure to determine dtype shouldn't be fatal
    (e.g. BaseModel dtypes for empty arrays).

    Compound targets pass if `dtype` validates against *any* of their members:

    - a ``tuple`` of dtypes
    - a union, as detected by :func:`is_union`

    Args:
        dtype: The dtype to validate
        target (:class:`.DtypeType`): The target dtype

    Returns:
        bool: ``True`` if valid, ``False`` otherwise
    """
    if target is Any or dtype is Any:
        return True

    # Compound targets: a tuple or a union matches if any member matches.
    # Generators (rather than lists) let ``any`` stop at the first match.
    if isinstance(target, tuple):
        return any(validate_dtype(dtype, target_dt) for target_dt in target)
    if is_union(target):
        return any(
            validate_dtype(dtype, target_dt) for target_dt in get_args(target)
        )

    if target is np.str_:
        # Accept either an object whose ``.type`` is a string type (e.g. a numpy
        # dtype) or the string classes ``np.str_`` / ``str`` themselves.
        return getattr(dtype, "type", None) in (np.str_, str) or dtype in (
            np.str_,
            str,
        )

    # try to match as any subclass, if target is a class
    try:
        return issubclass(dtype, target)
    except TypeError:
        # error expected if dtype or target is not a class
        # main type check - directly test dtype identity
        return dtype == target or getattr(dtype, "type", None) == target
def is_union(dtype: DtypeType) -> bool:
    """
    Check if a dtype is a union.

    Both spellings of a union are recognized: ``typing.Union[...]`` and the
    PEP 604 ``X | Y`` form, whose origin is :class:`types.UnionType`.

    Args:
        dtype (:class:`.DtypeType`): The dtype to inspect

    Returns:
        bool: ``True`` if ``dtype`` is a union, ``False`` otherwise
    """
    # ``types.UnionType`` is imported unconditionally at module level, so it is
    # always available and both origins can be checked in one membership test.
    return get_origin(dtype) in (Union, UnionType)