diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 42a383d5cef8..e81d61c6f6a3 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1174,10 +1174,12 @@ def get_node_location(self, node: torch_fx.Node) -> Optional[Location]: # https://github.com/pytorch/pytorch/issues/91000 stack_trace = node.stack_trace if stack_trace: - m = re.search(r"""File "([^"]+)", line ([0-9]+),""", stack_trace) - if m: - filename, line = m.group(1), int(m.group(2)) - return Location.file(filename, line, col=0, context=self._c) + matches = re.findall(r"""File "([^"]+)", line ([0-9]+),""", stack_trace) + locations = [Location.file(m[0], int(m[1]), col=0, context=self._c) for m in matches] + if len(locations) > 1: + return Location.callsite(locations[-1], locations[-2::-1], context=self._c) + elif len(locations) == 1: + return locations[0] return Location.unknown(context=self._c) def set_symbolic_guards(