File tree Expand file tree Collapse file tree 1 file changed +11
-2
lines changed
py/torch_tensorrt/dynamo/lowering/passes Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -99,8 +99,17 @@ class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]
9999 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
100100 super ().__init__ (* args , ** kwargs )
101101
102- # TODO: Update this function when quantization is added
103102 def is_impure (self , node : torch .fx .node .Node ) -> bool :
104- if node .target in (torch .ops .tensorrt .quantize_op .default ,):
103+ # Set of known quantization ops to be excluded from constant folding.
104+ # Currently, we exclude all quantization ops coming from modelopt library.
105+ quantization_ops = {}
106+ try :
107+ # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
108+ import modelopt .torch .quantization as mtq
109+ assert torch .ops .tensorrt .quantize_op .default
110+ quantization_ops .add (torch .ops .tensorrt .quantize_op .default )
111+ except Exception as e :
112+ pass
113+ if quantization_ops and node .target in quantization_ops :
105114 return True
106115 return False
You can’t perform that action at this time.
0 commit comments