88from collections .abc import Hashable , Iterable , Mapping , Sequence
99from datetime import timedelta
1010from functools import partial
11- from typing import TYPE_CHECKING , Any , Callable , Literal , NoReturn
11+ from typing import TYPE_CHECKING , Any , Callable , Literal , NoReturn , cast
1212
1313import numpy as np
1414import pandas as pd
6666 PadModeOptions ,
6767 PadReflectOptions ,
6868 QuantileMethods ,
69+ T_DuckArray ,
6970 T_Variable ,
7071 )
7172
@@ -86,7 +87,7 @@ class MissingDimensionsError(ValueError):
8687 # TODO: move this to an xarray.exceptions module?
8788
8889
89- def as_variable (obj , name = None ) -> Variable | IndexVariable :
90+ def as_variable (obj : T_DuckArray | Any , name = None ) -> Variable | IndexVariable :
9091 """Convert an object into a Variable.
9192
9293 Parameters
@@ -142,7 +143,7 @@ def as_variable(obj, name=None) -> Variable | IndexVariable:
142143 elif isinstance (obj , (set , dict )):
143144 raise TypeError (f"variable { name !r} has invalid type { type (obj )!r} " )
144145 elif name is not None :
145- data = as_compatible_data (obj )
146+ data : T_DuckArray = as_compatible_data (obj )
146147 if data .ndim != 1 :
147148 raise MissingDimensionsError (
148149 f"cannot set variable { name !r} with { data .ndim !r} -dimensional data "
@@ -230,7 +231,9 @@ def _possibly_convert_datetime_or_timedelta_index(data):
230231 return data
231232
232233
233- def as_compatible_data (data , fastpath : bool = False ):
234+ def as_compatible_data (
235+ data : T_DuckArray | ArrayLike , fastpath : bool = False
236+ ) -> T_DuckArray :
234237 """Prepare and wrap data to put in a Variable.
235238
236239 - If data does not have the necessary attributes, convert it to ndarray.
@@ -243,7 +246,7 @@ def as_compatible_data(data, fastpath: bool = False):
243246 """
244247 if fastpath and getattr (data , "ndim" , 0 ) > 0 :
245248 # can't use fastpath (yet) for scalars
246- return _maybe_wrap_data (data )
249+ return cast ( "T_DuckArray" , _maybe_wrap_data (data ) )
247250
248251 from xarray .core .dataarray import DataArray
249252
@@ -252,7 +255,7 @@ def as_compatible_data(data, fastpath: bool = False):
252255
253256 if isinstance (data , NON_NUMPY_SUPPORTED_ARRAY_TYPES ):
254257 data = _possibly_convert_datetime_or_timedelta_index (data )
255- return _maybe_wrap_data (data )
258+ return cast ( "T_DuckArray" , _maybe_wrap_data (data ) )
256259
257260 if isinstance (data , tuple ):
258261 data = utils .to_0d_object_array (data )
@@ -279,7 +282,7 @@ def as_compatible_data(data, fastpath: bool = False):
279282 if not isinstance (data , np .ndarray ) and (
280283 hasattr (data , "__array_function__" ) or hasattr (data , "__array_namespace__" )
281284 ):
282- return data
285+ return cast ( "T_DuckArray" , data )
283286
284287 # validate whether the data is valid data types.
285288 data = np .asarray (data )
@@ -335,7 +338,14 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):
335338
336339 __slots__ = ("_dims" , "_data" , "_attrs" , "_encoding" )
337340
338- def __init__ (self , dims , data , attrs = None , encoding = None , fastpath = False ):
341+ def __init__ (
342+ self ,
343+ dims ,
344+ data : T_DuckArray | ArrayLike ,
345+ attrs = None ,
346+ encoding = None ,
347+ fastpath = False ,
348+ ):
339349 """
340350 Parameters
341351 ----------
@@ -355,9 +365,9 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
355365 Well-behaved code to serialize a Variable should ignore
356366 unrecognized encoding items.
357367 """
358- self ._data = as_compatible_data (data , fastpath = fastpath )
368+ self ._data : T_DuckArray = as_compatible_data (data , fastpath = fastpath )
359369 self ._dims = self ._parse_dimensions (dims )
360- self ._attrs = None
370+ self ._attrs : dict [ Any , Any ] | None = None
361371 self ._encoding = None
362372 if attrs is not None :
363373 self .attrs = attrs
@@ -410,7 +420,7 @@ def _in_memory(self):
410420 )
411421
412422 @property
413- def data (self ) -> Any :
423+ def data (self : T_Variable ) :
414424 """
415425 The Variable's data as an array. The underlying array type
416426 (e.g. dask, sparse, pint) is preserved.
@@ -429,12 +439,12 @@ def data(self) -> Any:
429439 return self .values
430440
431441 @data .setter
432- def data (self , data ) :
442+ def data (self : T_Variable , data : T_DuckArray | ArrayLike ) -> None :
433443 data = as_compatible_data (data )
434- if data .shape != self .shape :
444+ if data .shape != self .shape : # type: ignore[attr-defined]
435445 raise ValueError (
436446 f"replacement data must match the Variable's shape. "
437- f"replacement data has shape { data .shape } ; Variable has shape { self .shape } "
447+ f"replacement data has shape { data .shape } ; Variable has shape { self .shape } " # type: ignore[attr-defined]
438448 )
439449 self ._data = data
440450
@@ -996,7 +1006,7 @@ def reset_encoding(self: T_Variable) -> T_Variable:
9961006 return self ._replace (encoding = {})
9971007
9981008 def copy (
999- self : T_Variable , deep : bool = True , data : ArrayLike | None = None
1009+ self : T_Variable , deep : bool = True , data : T_DuckArray | ArrayLike | None = None
10001010 ) -> T_Variable :
10011011 """Returns a copy of this object.
10021012
@@ -1058,24 +1068,26 @@ def copy(
10581068 def _copy (
10591069 self : T_Variable ,
10601070 deep : bool = True ,
1061- data : ArrayLike | None = None ,
1071+ data : T_DuckArray | ArrayLike | None = None ,
10621072 memo : dict [int , Any ] | None = None ,
10631073 ) -> T_Variable :
10641074 if data is None :
1065- ndata = self ._data
1075+ data_old = self ._data
10661076
1067- if isinstance (ndata , indexing .MemoryCachedArray ):
1077+ if isinstance (data_old , indexing .MemoryCachedArray ):
10681078 # don't share caching between copies
1069- ndata = indexing .MemoryCachedArray (ndata .array )
1079+ ndata = indexing .MemoryCachedArray (data_old .array )
1080+ else :
1081+ ndata = data_old
10701082
10711083 if deep :
10721084 ndata = copy .deepcopy (ndata , memo )
10731085
10741086 else :
10751087 ndata = as_compatible_data (data )
1076- if self .shape != ndata .shape :
1088+ if self .shape != ndata .shape : # type: ignore[attr-defined]
10771089 raise ValueError (
1078- f"Data shape { ndata .shape } must match shape of object { self .shape } "
1090+ f"Data shape { ndata .shape } must match shape of object { self .shape } " # type: ignore[attr-defined]
10791091 )
10801092
10811093 attrs = copy .deepcopy (self ._attrs , memo ) if deep else copy .copy (self ._attrs )
@@ -1248,11 +1260,11 @@ def chunk(
12481260 inline_array = inline_array ,
12491261 )
12501262
1251- data = self ._data
1252- if chunkmanager .is_chunked_array (data ):
1253- data = chunkmanager .rechunk (data , chunks ) # type: ignore[arg-type]
1263+ data_old = self ._data
1264+ if chunkmanager .is_chunked_array (data_old ):
1265+ data_chunked = chunkmanager .rechunk (data_old , chunks ) # type: ignore[arg-type]
12541266 else :
1255- if isinstance (data , indexing .ExplicitlyIndexed ):
1267+ if isinstance (data_old , indexing .ExplicitlyIndexed ):
12561268 # Unambiguously handle array storage backends (like NetCDF4 and h5py)
12571269 # that can't handle general array indexing. For example, in netCDF4 you
12581270 # can do "outer" indexing along two dimensions independent, which works
@@ -1261,20 +1273,22 @@ def chunk(
12611273 # Using OuterIndexer is a pragmatic choice: dask does not yet handle
12621274 # different indexing types in an explicit way:
12631275 # https://github.com/dask/dask/issues/2883
1264- data = indexing .ImplicitToExplicitIndexingAdapter (
1265- data , indexing .OuterIndexer
1276+ ndata = indexing .ImplicitToExplicitIndexingAdapter (
1277+ data_old , indexing .OuterIndexer
12661278 )
1279+ else :
1280+ ndata = data_old
12671281
12681282 if utils .is_dict_like (chunks ):
1269- chunks = tuple (chunks .get (n , s ) for n , s in enumerate (data .shape ))
1283+ chunks = tuple (chunks .get (n , s ) for n , s in enumerate (ndata .shape ))
12701284
1271- data = chunkmanager .from_array (
1272- data ,
1285+ data_chunked = chunkmanager .from_array (
1286+ ndata ,
12731287 chunks , # type: ignore[arg-type]
12741288 ** _from_array_kwargs ,
12751289 )
12761290
1277- return self ._replace (data = data )
1291+ return self ._replace (data = data_chunked )
12781292
12791293 def to_numpy (self ) -> np .ndarray :
12801294 """Coerces wrapped data to numpy and returns a numpy.ndarray"""
0 commit comments