@@ -361,8 +361,10 @@ def std(x: array,
361
361
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
362
362
# implement it here for now.
363
363
364
- # if isinstance(correction, float):
365
- # correction = int(correction)
364
+ if isinstance (correction , float ):
365
+ _correction = int (correction )
366
+ if correction != _correction :
367
+ raise NotImplementedError ("float correction in torch std() is not yet supported" )
366
368
367
369
# https://github.com/pytorch/pytorch/issues/29137
368
370
if axis == ():
@@ -372,10 +374,10 @@ def std(x: array,
372
374
if axis is None :
373
375
# torch doesn't support keepdims with axis=None
374
376
# (https://github.com/pytorch/pytorch/issues/71209)
375
- res = torch .std (x , tuple (range (x .ndim )), correction = correction , ** kwargs )
377
+ res = torch .std (x , tuple (range (x .ndim )), correction = _correction , ** kwargs )
376
378
res = _axis_none_keepdims (res , x .ndim , keepdims )
377
379
return res
378
- return torch .std (x , axis , correction = correction , keepdims = keepdims , ** kwargs )
380
+ return torch .std (x , axis , correction = _correction , keepdims = keepdims , ** kwargs )
379
381
380
382
def var (x : array ,
381
383
/ ,
@@ -519,6 +521,28 @@ def full(shape: Union[int, Tuple[int, ...]],
519
521
520
522
return torch .full (shape , fill_value , dtype = dtype , device = device , ** kwargs )
521
523
524
+ # ones, zeros, and empty do not accept shape as a keyword argument
525
+ def ones (shape : Union [int , Tuple [int , ...]],
526
+ * ,
527
+ dtype : Optional [Dtype ] = None ,
528
+ device : Optional [Device ] = None ,
529
+ ** kwargs ) -> array :
530
+ return torch .ones (shape , dtype = dtype , device = device , ** kwargs )
531
+
532
+ def zeros (shape : Union [int , Tuple [int , ...]],
533
+ * ,
534
+ dtype : Optional [Dtype ] = None ,
535
+ device : Optional [Device ] = None ,
536
+ ** kwargs ) -> array :
537
+ return torch .zeros (shape , dtype = dtype , device = device , ** kwargs )
538
+
539
+ def empty (shape : Union [int , Tuple [int , ...]],
540
+ * ,
541
+ dtype : Optional [Dtype ] = None ,
542
+ device : Optional [Device ] = None ,
543
+ ** kwargs ) -> array :
544
+ return torch .empty (shape , dtype = dtype , device = device , ** kwargs )
545
+
522
546
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
523
547
def expand_dims (x : array , / , * , axis : int = 0 ) -> array :
524
548
return torch .unsqueeze (x , axis )
@@ -585,7 +609,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
585
609
'logaddexp' , 'multiply' , 'not_equal' , 'pow' , 'remainder' ,
586
610
'subtract' , 'max' , 'min' , 'sort' , 'prod' , 'sum' , 'any' , 'all' ,
587
611
'mean' , 'std' , 'var' , 'concat' , 'squeeze' , 'flip' , 'roll' ,
588
- 'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' ,
589
- 'expand_dims ' , 'astype ' , 'broadcast_arrays ' , 'unique_all ' ,
590
- 'unique_counts' , 'unique_inverse' , 'unique_values' ,
612
+ 'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' , 'ones' ,
613
+ 'zeros ' , 'empty ' , 'expand_dims ' , 'astype' , 'broadcast_arrays ' ,
614
+ 'unique_all' , ' unique_counts' , 'unique_inverse' , 'unique_values' ,
591
615
'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' ]
0 commit comments