@@ -62,21 +62,21 @@ def inplace_normal(input, *args):
6262__all__ .append ('inplace_normal' )
6363
6464def getitem (input , slice ):
65- out = input . numpy ( )[slice ]
65+ out = np . zeros ( input . shape )[slice ]
6666 out = Tensor_ (init = 'none' , shape = out .shape , dtype = input .dtype )
6767 return mindtorch .Tensor (out )
6868
6969__all__ .append ('getitem' )
7070
71- def sub (input , other , alpha ):
71+ def sub (input , other , alpha = 1.0 ):
7272 if isinstance (input , mindtorch .Tensor ):
7373 return input
7474 return other
7575
7676__all__ .append ('sub' )
7777
7878def pad_v3 (input , pad , mode , value ):
79- out = np .pad (input . numpy ( ), pad , mode , constant_values = value )
79+ out = np .pad (np . zeros ( input . shape ), pad , mode , constant_values = value )
8080 out = Tensor_ (init = 'none' , shape = out .shape , dtype = input .dtype )
8181 return mindtorch .Tensor (out )
8282
@@ -94,7 +94,7 @@ def cast(input, dtype):
9494__all__ .append ('cast' )
9595
9696def index_select (input , dim , index ):
97- out = np .take (input . numpy ( ), index . numpy ( ), dim )
97+ out = np .take (np . zeros ( input . shape ), np . zeros ( index . shape , dtype = np . int64 ), dim )
9898 out = Tensor_ (init = 'none' , shape = out .shape , dtype = input .dtype )
9999 return mindtorch .Tensor (out )
100100
@@ -146,6 +146,9 @@ def tril(input, k):
146146__all__ .append ('tril' )
147147
148148def reshape (input , shape ):
149+ if - 1 in shape :
150+ out = np .zeros (input .shape ).reshape (shape )
151+ shape = out .shape
149152 out = Tensor_ (init = 'none' , shape = tuple (shape ), dtype = input .dtype )
150153 return mindtorch .Tensor (out )
151154
@@ -414,4 +417,20 @@ def pad(input, pad, mode='constant', value=None):
414417 raise ValueError ('pad size must be 2, 4 or 6' )
415418
416419 out = Tensor_ (init = 'none' , shape = new_size , dtype = input .dtype )
420+ return mindtorch .Tensor (out )
421+
422+ def setitem (self , slice , value ):
423+ return self
424+
425+ def meshgrid (args , lambd ):
426+ res = np .meshgrid (* args , indexing = lambd )
427+ outs = ()
428+ for r in res :
429+ out = Tensor_ (init = 'none' , shape = r .shape , dtype = args [0 ].dtype )
430+ out = mindtorch .Tensor (out )
431+ outs += (out ,)
432+ return outs
433+
434+ def permute (input , dims ):
435+ out = Tensor_ (init = 'none' , shape = dims , dtype = input .dtype )
417436 return mindtorch .Tensor (out )
0 commit comments