from abc import ABC, abstractmethod
from collections.abc import Sequence
from functools import reduce
from itertools import product
from operator import ior
from pathlib import Path
from typing import (
TYPE_CHECKING,
Literal,
Union,
)
import numpy as np
from pydantic import BaseModel, ConfigDict, Field, ValidationError, computed_field
from numpydantic import NDArray, Shape
from numpydantic.dtype import Float
from numpydantic.interface import Interface
from numpydantic.types import DtypeType, NDArrayType
if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator
[docs]
class InterfaceCase(ABC):
    """
    Test helper that lets one interface generate and validate arrays
    in one of the formats it supports.

    A single "interface test case" is one realization of an interface:
    an interface with several formats (eg. zarr's different `store` s)
    should get one helper per format.
    """

    @property
    @abstractmethod
    def interface(self) -> Interface:
        """The interface this helper belongs to"""

    @classmethod
    def array_from_case(
        cls, case: "ValidationCase", path: Path | None = None
    ) -> NDArrayType | None:
        """
        Build an array matching the ``shape`` and ``dtype`` of a validation case.

        Delegates to :meth:`.make_array`.

        Args:
            case: The validation case whose ``shape`` and ``dtype`` are used.
            path: Path forwarded to :meth:`.make_array`, if any.

        Returns:
            Whatever :meth:`.make_array` returns;
            ``None`` means no array can be generated for this case.
        """
        requested = {"shape": case.shape, "dtype": case.dtype, "path": path}
        return cls.make_array(**requested)

    @classmethod
    @abstractmethod
    def make_array(
        cls,
        shape: tuple[int, ...] = (10, 10),
        dtype: DtypeType = float,
        path: Path | None = None,
        array: NDArrayType | None = None,
    ) -> NDArrayType | None:
        """
        Create an array in this helper's format.

        Args:
            shape: shape of the array
            dtype: dtype of the array
            path: Path, if needed to generate on disk
            array: Rather than passing shape and dtype, pass a literal arraylike thing

        Returns:
            The generated array, or ``None`` if one can't be generated.
        """

    @classmethod
    def validate_case(cls, case: "ValidationCase", path: Path) -> bool:
        """
        Check a generated array against the model of the validation case.

        Lives on the InterfaceCase so that an interface with special needs
        can override it, but is typically left as is.

        If no array can be generated for the case (``array_from_case`` gives
        ``None``), the running test is skipped with ``pytest.skip()``
        rather than failed.

        An exception is raised when validation fails for a case that should
        pass, or succeeds for a case that should fail.

        Args:
            case (ValidationCase): The validation case to validate.
            path (Path): Path to generate arrays into, if any.

        Returns:
            ``True`` if array is valid and was supposed to be,
            or invalid and wasn't supposed to be
        """
        import pytest

        generated = cls.array_from_case(case, path)
        if generated is None:
            pytest.skip()

        if not case.passes:
            # a failing case must be rejected by the model
            with pytest.raises(ValidationError):
                case.model(array=generated)
            return True

        # a passing case must instantiate the model without error
        case.model(array=generated)
        return True

    @classmethod
    def skip(cls, shape: tuple[int, ...], dtype: DtypeType) -> bool:
        """
        Whether this interface case should be skipped for the given shape/dtype.

        The base implementation never skips.
        """
        return False
# Alias for a shape specification: a tuple whose items are each an int
# or one of the literal strings "*" / "...".
# NOTE(review): not referenced anywhere else in the visible source — confirm it is still needed.
_a_shape_type = tuple[int | Literal["*"] | Literal["..."], ...]
[docs]
class ValidationCase(BaseModel):
    """
    Test case for validating an array.

    Contains both the validating model and the parameterization for an array to
    test in a given interface
    """

    id: str | None = None
    """
    String identifying the validation case
    """
    annotation_shape: tuple[int | str, ...] | tuple[tuple[int | str, ...], ...] = (
        10,
        10,
        "*",
        "*",
    )
    """
    Shape to use in computed annotation used to validate against
    """
    annotation_dtype: DtypeType | Sequence[DtypeType] = Float
    """
    Dtype to use in computed annotation used to validate against
    """
    shape: tuple[int, ...] = (10, 10, 2, 2)
    """Shape of the array to validate"""
    dtype: type | np.dtype = float
    """Dtype of the array to validate"""
    passes: bool = False
    """Whether the validation should pass or not"""
    interface: type[InterfaceCase] | None = None
    """The interface test case to generate and validate the array with"""
    path: Path | None = None
    """The path to generate arrays into, if any."""
    marks: set[str] = Field(default_factory=set)
    """pytest marks to set for this test case"""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @computed_field()
    def annotation(self) -> NDArray:
        """
        Annotation used in the model we validate against

        When ``annotation_shape`` is a sequence of shapes and/or
        ``annotation_dtype`` is a sequence of sequences, a ``Union`` of one
        ``NDArray`` annotation per (shape, dtype) combination is returned.
        Otherwise a single ``NDArray`` annotation is returned.
        """
        # make a union type if we need to
        shape_union = all(
            isinstance(s, Sequence) and not isinstance(s, str)
            for s in self.annotation_shape
        )
        # A plain string dtype is itself a Sequence of (single-character) strings,
        # so it must be excluded explicitly or it would be split into characters.
        dtype_union = (
            isinstance(self.annotation_dtype, Sequence)
            and not isinstance(self.annotation_dtype, (str, bytes))
            and all(isinstance(s, Sequence) for s in self.annotation_dtype)
        )
        if shape_union or dtype_union:
            # wrap the non-union side in a list so ``product`` sees one item for it
            shape_iter = (
                self.annotation_shape if shape_union else [self.annotation_shape]
            )
            dtype_iter = (
                self.annotation_dtype if dtype_union else [self.annotation_dtype]
            )
            annotations: list[type] = []
            for shape, dtype in product(shape_iter, dtype_iter):
                shape_str = ", ".join([str(i) for i in shape])
                annotations.append(NDArray[Shape[shape_str], dtype])
            return Union[tuple(annotations)]  # noqa: UP007
        else:
            shape_str = ", ".join([str(i) for i in self.annotation_shape])
            return NDArray[Shape[shape_str], self.annotation_dtype]

    @computed_field()
    def model(self) -> type[BaseModel]:
        """A model with a field ``array`` with the given annotation"""
        annotation = self.annotation

        class Model(BaseModel):
            array: annotation

        return Model

    @property
    def pytest_marks(self) -> list["MarkDecorator"]:
        """
        Instantiated pytest marks from :attr:`.ValidationCase.marks`
        plus the interface name.
        """
        import pytest

        # copy so that adding the interface name does not mutate ``self.marks``
        marks = self.marks.copy()
        if self.interface is not None:
            marks.add(self.interface.interface.name)
        return [getattr(pytest.mark, m) for m in marks]

    def validate_case(self, path: Path | None = None) -> bool:
        """
        Whether the generated array correctly validated against the annotation,
        given the interface

        Args:
            path (:class:`pathlib.Path`): Directory to generate array into, if on disk.
                Falls back to :attr:`.ValidationCase.path` when not given.

        Returns:
            The result of the interface case's ``validate_case``.

        Raises:
            ValueError: if an ``interface`` is missing,
                or if neither ``path`` nor :attr:`.ValidationCase.path` is set
        """
        if self.interface is None:  # pragma: no cover
            raise ValueError("Missing an interface")
        if path is None:
            if self.path:
                path = self.path
            else:  # pragma: no cover
                raise ValueError("Missing a path to generate arrays into")

        return self.interface.validate_case(self, path)

    def array(self, path: Path | None = None) -> NDArrayType | None:
        """
        Generate an array for the validation case if we have an interface to do so

        Args:
            path (:class:`pathlib.Path`): Directory to generate array into, if on disk.
                Falls back to :attr:`.ValidationCase.path` when not given.

        Returns:
            The result of the interface case's ``array_from_case``
            (which may be ``None`` if no array can be generated).

        Raises:
            ValueError: if an ``interface`` is missing,
                or if neither ``path`` nor :attr:`.ValidationCase.path` is set
        """
        if self.interface is None:  # pragma: no cover
            raise ValueError("Missing an interface")
        if path is None:  # pragma: no cover
            if self.path:
                path = self.path
            else:
                raise ValueError("Missing a path to generate arrays into")

        return self.interface.array_from_case(self, path)

    def merge(
        self, other: Union["ValidationCase", Sequence["ValidationCase"]]
    ) -> "ValidationCase":
        """
        Merge this validation case with one or several others

        Dump all of them, excluding any unset fields, and merge, preferring `other`.

        ``passes`` is ``True`` if and only if it is ``True`` in every case
        that explicitly sets it.

        Args:
            other: A single case, or a sequence of cases merged in order.
        """
        if isinstance(other, Sequence):
            return merge_cases(self, *other)
        else:
            return merge_cases(self, other)

    def skip(self) -> bool:
        """
        Whether this case should be skipped
        (eg. due to the interface case being incompatible
        with the requested dtype or shape)

        Always ``False`` when no ``interface`` is set.
        """
        return bool(
            self.interface is not None and self.interface.skip(self.shape, self.dtype)
        )
[docs]
def merge_cases(*args: ValidationCase) -> ValidationCase:
    """
    Merge multiple validation cases into one.

    Each case is dumped with only its explicitly-set fields, and the dumps
    are layered in order so that later cases override earlier ones.
    Three fields are combined rather than overridden:

    * ``passes`` - true only if every case that set it is truthy
      (true when no case set it)
    * ``id`` - the ids that were set, joined with ``-``
    * ``marks`` - the union of all marks

    A single case is returned unchanged.
    """
    if len(args) == 1:  # pragma: no cover
        return args[0]

    not_dumped = {"model", "annotation", "pytest_marks"}
    dumps = [
        case.model_dump(exclude_unset=True, exclude=not_dumped) for case in args
    ]

    merged: dict = {}
    id_parts: list[str] = []
    all_marks: set = set()
    every_pass = True
    for dump in dumps:
        # later dumps win for ordinary fields
        merged |= dump
        # only the cases that explicitly set `passes` get a vote
        if "passes" in dump:
            every_pass = every_pass and bool(dump["passes"])
        if "id" in dump:
            id_parts.append(str(dump["id"]))
        all_marks.update(dump.get("marks", set()))

    merged["passes"] = every_pass
    merged["id"] = "-".join(id_parts)
    merged["marks"] = all_marks
    return ValidationCase.model_construct(**merged)
[docs]
def merged_product(
    *args: Sequence[ValidationCase], conditions: dict | None = None
) -> list[ValidationCase]:
    """
    Product of the given sequences of validation cases,
    merging each tuple, and respecting if they should be :meth:`.ValidationCase.skip`
    or not.

    Args:
        *args: Sequences of validation cases to take the product of.
        conditions: Optional mapping of attribute name to required value.
            When given, only merged cases whose attributes equal every
            value are kept (a missing attribute is compared as ``None``).

    Returns:
        A list of the merged cases that were neither skipped nor filtered out.

    Examples:

        .. code-block:: python

            shape_cases = [
                ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
                ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
            ]
            dtype_cases = [
                ValidationCase(dtype=float, passes=True, id="float"),
                ValidationCase(dtype=int, passes=False, id="int"),
            ]

            cases = merged_product(shape_cases, dtype_cases)
            cases[0]
            # ValidationCase(
            #     shape=(10, 10, 10),
            #     dtype=float,
            #     passes=True,
            #     id="valid shape-float"
            # )
            cases[1]
            # ValidationCase(
            #     shape=(10, 10, 10),
            #     dtype=int,
            #     passes=False,
            #     id="valid shape-int"
            # )

    """
    cases = []
    for case_tuple in product(*args):
        case = merge_cases(*case_tuple)
        if case.skip():
            continue
        # generator so the comparison stops at the first mismatching condition
        if conditions and not all(
            getattr(case, k, None) == v for k, v in conditions.items()
        ):
            continue
        cases.append(case)
    return cases