@@ -155,6 +155,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
155
155
bitwise_or = _two_arg (torch .bitwise_or )
156
156
bitwise_right_shift = _two_arg (torch .bitwise_right_shift )
157
157
bitwise_xor = _two_arg (torch .bitwise_xor )
158
+ copysign = _two_arg (torch .copysign )
158
159
divide = _two_arg (torch .divide )
159
160
# Also a rename. torch.equal does not broadcast
160
161
equal = _two_arg (torch .eq )
@@ -702,9 +703,9 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
702
703
703
704
__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' ,
704
705
'newaxis' , 'add' , 'atan2' , 'bitwise_and' , 'bitwise_left_shift' ,
705
- 'bitwise_or' , 'bitwise_right_shift' , 'bitwise_xor' , 'divide ' ,
706
- 'equal ' , 'floor_divide ' , 'greater ' , 'greater_equal ' , 'less ' ,
707
- 'less_equal' , 'logaddexp' , 'multiply' , 'not_equal' , 'pow' ,
706
+ 'bitwise_or' , 'bitwise_right_shift' , 'bitwise_xor' , 'copysign ' ,
707
+ 'divide ' , 'equal ' , 'floor_divide ' , 'greater ' , 'greater_equal ' ,
708
+ 'less' , ' less_equal' , 'logaddexp' , 'multiply' , 'not_equal' , 'pow' ,
708
709
'remainder' , 'subtract' , 'max' , 'min' , 'sort' , 'prod' , 'sum' ,
709
710
'any' , 'all' , 'mean' , 'std' , 'var' , 'concat' , 'squeeze' ,
710
711
'broadcast_to' , 'flip' , 'roll' , 'nonzero' , 'where' , 'reshape' ,
0 commit comments