18
18
19
19
import org .tensorflow .ndarray .NdArray ;
20
20
import org .tensorflow .ndarray .NdArraySequence ;
21
+ import org .tensorflow .ndarray .Shape ;
21
22
import org .tensorflow .ndarray .impl .AbstractNdArray ;
22
23
import org .tensorflow .ndarray .impl .dimension .RelativeDimensionalSpace ;
23
24
import org .tensorflow .ndarray .impl .sequence .FastElementSequence ;
@@ -43,18 +44,29 @@ public NdArraySequence<U> elements(int dimensionIdx) {
43
44
DimensionalSpace elemDims = dimensions ().from (dimensionIdx + 1 );
44
45
try {
45
46
DataBufferWindow <? extends DataBuffer <T >> elemWindow = buffer ().window (elemDims .physicalSize ());
46
- U element = instantiate (elemWindow .buffer (), elemDims );
47
+ U element = instantiateView (elemWindow .buffer (), elemDims );
47
48
return new FastElementSequence (this , dimensionIdx , element , elemWindow );
48
49
} catch (UnsupportedOperationException e ) {
49
50
// If buffer windows are not supported, fallback to slicing (and slower) sequence
50
51
return new SlicingElementSequence <>(this , dimensionIdx , elemDims );
51
52
}
52
53
}
53
54
55
+ @ Override
56
+ public U withShape (Shape shape ) {
57
+ if (shape == null || shape .isUnknown () || shape .size () != this .shape ().size ()) {
58
+ throw new IllegalArgumentException ("Shape " + shape + " cannot be used to reshape ndarray of shape " + this .shape ());
59
+ }
60
+ if (shape .equals (this .shape ())) {
61
+ return (U )this ;
62
+ }
63
+ return instantiateView (buffer (), DimensionalSpace .create (shape ));
64
+ }
65
+
54
66
@ Override
55
67
public U slice (long position , DimensionalSpace sliceDimensions ) {
56
68
DataBuffer <T > sliceBuffer = buffer ().slice (position , sliceDimensions .physicalSize ());
57
- return instantiate (sliceBuffer , sliceDimensions );
69
+ return instantiateView (sliceBuffer , sliceDimensions );
58
70
}
59
71
60
72
@ Override
@@ -147,7 +159,7 @@ protected AbstractDenseNdArray(DimensionalSpace dimensions) {
147
159
148
160
abstract protected DataBuffer <T > buffer ();
149
161
150
- abstract U instantiate (DataBuffer <T > buffer , DimensionalSpace dimensions );
162
+ abstract U instantiateView (DataBuffer <T > buffer , DimensionalSpace dimensions );
151
163
152
164
long positionOf (long [] coords , boolean isValue ) {
153
165
if (coords == null || coords .length == 0 ) {
0 commit comments