Skip to content

Commit 0ddac3f

Browse files
authored
[TensorRT] Fix dynamic linspace translation, relax dynamic verifier (#537)
Fixes dynamic linspace translation by passing a 1-D static shape tensor. Relaxes linspace op's verifier when the input step dimension is unknown.
1 parent 1af62a4 commit 0ddac3f

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,8 @@ nvinfer1::ILayer *NvInferNetworkEncoder::addFillLayer(
371371
nvinfer1::Dims shapeDims = dynamicShape->getDimensions();
372372
assert(shapeDims.nbDims == 1 && shapeDims.d[0] > 0 &&
373373
"invalid shape tensor shape");
374-
staticShape.nbDims = shapeDims.d[0];
375-
for (int32_t i = 0; i < shapeDims.nbDims; i++)
376-
staticShape.d[i] = 1;
374+
staticShape.nbDims = 1;
375+
staticShape.d[0] = 1;
377376
}
378377
nvinfer1::IFillLayer *layer =
379378
network->addFill(staticShape, fillOperation, elementType);

mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ LogicalResult tensorrt::LinspaceOp::verify() {
117117
if (getStep() == nullptr)
118118
return emitOpError("dynamic `step` must be specified if the result is "
119119
"greater than rank 1");
120-
if (getStep().getType().getDimSize(0) != getType().getRank())
120+
TensorType stepType = getStep().getType();
121+
if (!stepType.isDynamicDim(0) &&
122+
stepType.getDimSize(0) != getType().getRank())
121123
return emitOpError("dynamic `step` type dimension 0 length must be the "
122124
"same size as the rank of the result type");
123125
}

0 commit comments

Comments
 (0)