From 5d83f60c0914d6e6d79254cf2e24ec9aa37f226d Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Tue, 4 Mar 2025 07:44:27 -0800 Subject: [PATCH] Use all frames of the stack trace when importing --- python/torch_mlir/extras/fx_importer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 42a383d5cef8..e81d61c6f6a3 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1174,10 +1174,12 @@ def get_node_location(self, node: torch_fx.Node) -> Optional[Location]: # https://github.com/pytorch/pytorch/issues/91000 stack_trace = node.stack_trace if stack_trace: - m = re.search(r"""File "([^"]+)", line ([0-9]+),""", stack_trace) - if m: - filename, line = m.group(1), int(m.group(2)) - return Location.file(filename, line, col=0, context=self._c) + matches = re.findall(r"""File "([^"]+)", line ([0-9]+),""", stack_trace) + locations = [Location.file(m[0], int(m[1]), col=0, context=self._c) for m in matches] + if len(locations) > 1: + return Location.callsite(locations[-1], locations[-2::-1], context=self._c) + elif len(locations) == 1: + return locations[0] return Location.unknown(context=self._c) def set_symbolic_guards(