"""Interface for Dask arrays"""fromtypingimportAny,Iterable,List,Literal,Optional,UnionimportnumpyasnpfrompydanticimportBaseModel,SerializationInfofromnumpydantic.interface.interfaceimportInterface,JsonDictfromnumpydantic.typesimportDtypeType,NDArrayTypetry:fromdask.arrayimportfrom_arrayfromdask.array.coreimportArrayasDaskArrayexceptImportError:# pragma: no coverDaskArray=Nonedef_as_tuple(a_list:Any)->tuple:"""Make a list of list into a tuple of tuples"""returntuple([_as_tuple(item)ifisinstance(item,list)elseitemforitemina_list])
class DaskJsonDict(JsonDict):
    """
    JSON-serialized form of a dask array that can be turned back into an array
    with :meth:`.to_array_input`
    """

    # Discriminator tag, always the literal string "dask"
    type: Literal["dask"]
    # Passed to ``from_array`` as ``name``
    name: str
    # Passed to ``from_array`` as ``chunks`` (after conversion to nested tuples)
    chunks: Iterable[tuple[int, ...]]
    # Numpy dtype string used to rebuild the in-memory array from ``value``
    dtype: str
    # Array contents as (nested) lists
    value: list

    def to_array_input(self) -> DaskArray:
        """
        Rebuild a dask array from the serialized fields.

        ``value`` is first loaded into a numpy array with the stored ``dtype``,
        then wrapped by ``from_array`` using the stored ``name`` and ``chunks``.
        """
        materialized = np.array(self.value, dtype=self.dtype)
        chunk_spec = _as_tuple(self.chunks)
        return from_array(materialized, name=self.name, chunks=chunk_spec)
class DaskInterface(Interface):
    """
    Interface for Dask :class:`~dask.array.core.Array`
    """

    name = "dask"
    input_types = (DaskArray, dict)
    return_type = DaskArray
    json_model = DaskJsonDict

    @classmethod
    def check(cls, array: Any) -> bool:
        """
        Check whether ``array`` can be handled by this interface.

        Returns ``True`` for dask arrays and for dicts that
        :class:`.DaskJsonDict` reports as valid; ``False`` otherwise, including
        whenever dask could not be imported.
        """
        if DaskArray is None:  # pragma: no cover - no tests for interface deps atm
            return False
        elif isinstance(array, DaskArray):
            return True
        elif isinstance(array, dict):
            return DaskJsonDict.is_valid(array)
        else:
            return False

    def before_validation(self, array: DaskArray) -> NDArrayType:
        """
        Try and coerce dicts that should be model objects into the model objects.

        When ``self.dtype`` is a pydantic model class and the first element of
        the array is a ``dict``, every block of the array is mapped so that each
        item is instantiated as ``self.dtype(**item)``. Otherwise the array is
        returned unchanged.
        """
        try:
            # Only the first element is computed to decide whether coercion is needed
            if issubclass(self.dtype, BaseModel) and isinstance(
                array.reshape(-1)[0].compute(), dict
            ):

                def _chunked_to_model(array: np.ndarray) -> np.ndarray:
                    """Convert every item of one numpy block to the model type"""

                    def _vectorized_to_model(
                        item: Union[dict, BaseModel],
                    ) -> BaseModel:
                        if not isinstance(item, self.dtype):
                            return self.dtype(**item)
                        else:  # pragma: no cover
                            return item

                    return np.vectorize(_vectorized_to_model)(array)

                array = array.map_blocks(_chunked_to_model, dtype=self.dtype)
        except TypeError:
            # fine, dtype isn't a type
            pass
        return array

    def get_object_dtype(self, array: NDArrayType) -> DtypeType:
        """
        Return the python type of the first element of ``array``.

        Dask arrays require a ``compute()`` call to retrieve a single value.
        """
        return type(array.reshape(-1)[0].compute())

    @classmethod
    def enabled(cls) -> bool:
        """Check if we successfully imported dask"""
        return DaskArray is not None

    @classmethod
    def to_json(
        cls, array: DaskArray, info: Optional[SerializationInfo] = None
    ) -> Union[List, DaskJsonDict]:
        """
        Convert an array to a JSON serializable array by first converting to a numpy
        array and then to a list.

        When ``info`` is given and ``info.round_trip`` is true, the list is wrapped
        in a :class:`.DaskJsonDict` that also records the array's name, chunks and
        dtype. When ``info`` is ``None`` (the default) the plain list is returned.

        .. note::

            This is likely a very memory intensive operation if you are using
            dask for large arrays. This can't be avoided, since the creation
            of the json string happens in-memory with Pydantic, so you are likely
            looking for a different method of serialization here using the python
            object itself rather than its JSON representation.
        """
        np_array = np.array(array)
        as_json = np_array.tolist()
        # ``info`` is optional: without it there is no round-trip request to honor
        if info is not None and info.round_trip:
            as_json = DaskJsonDict(
                type=cls.name,
                value=as_json,
                name=array.name,
                chunks=array.chunks,
                dtype=str(np_array.dtype),
            )
        return as_json