77import  warnings 
88from  collections .abc  import  Sequence 
99from  types  import  ModuleType 
10- from  typing  import  cast 
10+ from  typing  import  TYPE_CHECKING ,  cast 
1111
1212from  ._at  import  at 
1313from  ._utils  import  _compat , _helpers 
@@ -375,8 +375,8 @@ def expand_dims(
375375
376376
377377def  isclose (
378-     a : Array ,
379-     b : Array ,
378+     a : Array   |   complex ,
379+     b : Array   |   complex ,
380380    * ,
381381    rtol : float  =  1e-05 ,
382382    atol : float  =  1e-08 ,
@@ -385,6 +385,10 @@ def isclose(
385385) ->  Array :  # numpydoc ignore=PR01,RT01 
386386    """See docstring in array_api_extra._delegation.""" 
387387    a , b  =  asarrays (a , b , xp = xp )
388+     # FIXME https://github.com/microsoft/pyright/issues/10085 
389+     if  TYPE_CHECKING :  # pragma: nocover 
390+         assert  _compat .is_array_api_obj (a )
391+         assert  _compat .is_array_api_obj (b )
388392
389393    a_inexact  =  xp .isdtype (a .dtype , ("real floating" , "complex floating" ))
390394    b_inexact  =  xp .isdtype (b .dtype , ("real floating" , "complex floating" ))
@@ -419,7 +423,13 @@ def isclose(
419423    return  xp .abs (a  -  b ) <=  (atol  +  xp .abs (b ) //  nrtol )
420424
421425
422- def  kron (a : Array , b : Array , / , * , xp : ModuleType  |  None  =  None ) ->  Array :
426+ def  kron (
427+     a : Array  |  complex ,
428+     b : Array  |  complex ,
429+     / ,
430+     * ,
431+     xp : ModuleType  |  None  =  None ,
432+ ) ->  Array :
423433    """ 
424434    Kronecker product of two arrays. 
425435
@@ -495,9 +505,16 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
495505    if  xp  is  None :
496506        xp  =  array_namespace (a , b )
497507    a , b  =  asarrays (a , b , xp = xp )
508+     # FIXME https://github.com/microsoft/pyright/issues/10085 
509+     if  TYPE_CHECKING :  # pragma: nocover 
510+         assert  _compat .is_array_api_obj (a )
511+         assert  _compat .is_array_api_obj (b )
498512
499513    singletons  =  (1 ,) *  (b .ndim  -  a .ndim )
500514    a  =  xp .broadcast_to (a , singletons  +  a .shape )
515+     # FIXME https://github.com/microsoft/pyright/issues/10085 
516+     if  TYPE_CHECKING :  # pragma: nocover 
517+         assert  _compat .is_array_api_obj (a )
501518
502519    nd_b , nd_a  =  b .ndim , a .ndim 
503520    nd_max  =  max (nd_b , nd_a )
@@ -614,8 +631,8 @@ def pad(
614631
615632
616633def  setdiff1d (
617-     x1 : Array ,
618-     x2 : Array ,
634+     x1 : Array   |   complex ,
635+     x2 : Array   |   complex ,
619636    / ,
620637    * ,
621638    assume_unique : bool  =  False ,
@@ -628,7 +645,7 @@ def setdiff1d(
628645
629646    Parameters 
630647    ---------- 
631-     x1 : array 
648+     x1 : array | int | float | complex | bool  
632649        Input array. 
633650    x2 : array 
634651        Input comparison array. 
@@ -665,6 +682,11 @@ def setdiff1d(
665682    else :
666683        x1  =  xp .unique_values (x1 )
667684        x2  =  xp .unique_values (x2 )
685+ 
686+     # FIXME https://github.com/microsoft/pyright/issues/10085 
687+     if  TYPE_CHECKING :  # pragma: nocover 
688+         assert  _compat .is_array_api_obj (x1 )
689+ 
668690    return  x1 [_helpers .in1d (x1 , x2 , assume_unique = True , invert = True , xp = xp )]
669691
670692
0 commit comments