 from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
 from torch._inductor.debug import save_args_for_compile_fx_inner
 from torch._inductor.output_code import (
+    CompiledAOTI,
     CompiledFxGraph,
     get_expanded_dims,
     index_expanded_dims,
+    OutputCode,
 )
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.utils import (
@@ -509,15 +511,15 @@ def __call__(
         gm: GraphModule,
         example_inputs: Sequence[InputType],
         **kwargs: Unpack[_CompileFxKwargs],
-    ) -> Union[CompiledFxGraph, str]:
+    ) -> OutputCode:
         ...


 def compile_fx_inner(
     gm: GraphModule,
     example_inputs: Sequence[InputType],
     **kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     kwargs.setdefault("cudagraphs", None)
     kwargs.setdefault("static_input_idxs", ())
     kwargs.setdefault("is_backward", False)
@@ -570,7 +572,7 @@ def _compile_fx_inner(
     gm: GraphModule,
     example_inputs: Sequence[InputType],
     **graph_kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     """
     Inductor API that compiles a single graph.

@@ -630,11 +632,7 @@ def _compile_fx_inner(
         ):
             input._is_inductor_static = True  # type: ignore[attr-defined]

-    # TODO: Remove this short circuit once types are unified here
-    if aot_mode:
-        return fx_codegen_and_compile(gm, example_inputs, inputs_to_check, **graph_kwargs)  # type: ignore[assignment]
-
-    mb_compiled_graph: Optional[CompiledFxGraph] = None
+    mb_compiled_graph: Optional[OutputCode] = None
     key_info = None
     cache_info = None
     remote_cache = None
@@ -668,31 +666,28 @@ def _compile_fx_inner(
     # determined the input is uncacheable)
     if cache_info is None or cache_info["cache_state"] == "bypass":
         assert mb_compiled_graph is None
-        r = fx_codegen_and_compile(
+        mb_compiled_graph = fx_codegen_and_compile(
             gm, example_inputs, inputs_to_check, **graph_kwargs
         )
-        assert not isinstance(r, str)  # due to aot test
-        mb_compiled_graph = r

     # CACHE MISS: Compile the graph and save to cache
     elif cache_info["cache_state"] == "miss":
         assert mb_compiled_graph is None
         assert key_info is not None
         TritonBundler.begin_compile()
         try:
-            r = fx_codegen_and_compile(
+            mb_compiled_graph = fx_codegen_and_compile(
                 gm, example_inputs, inputs_to_check, **graph_kwargs
             )
-            assert not isinstance(r, str)  # due to aot test
-            mb_compiled_graph = r
             assert mb_compiled_graph is not None
             mb_compiled_graph._time_taken_ns = time.time_ns() - start_time
             cache_key = key_info[0]
             mb_compiled_graph._fx_graph_cache_key = cache_key
             (
-                mb_compiled_graph._triton_bundle,
+                triton_bundle,
                 triton_bundler_meta,
             ) = TritonBundler.collect()
+            mb_compiled_graph.set_triton_bundle(triton_bundle)
         finally:
             TritonBundler.end_compile()
         if triton_bundler_meta is not None:
@@ -782,7 +777,7 @@ def fx_codegen_and_compile(
     # in explicitly because it's nontrivial to compute
     inputs_to_check: Sequence[int],
     **graph_kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     # Sorry about the mess, we need graph_kwargs to continue to be able
     # to propagate it further on
     # TODO: _CompileFxKwargs actually has stronger types than in the
@@ -979,6 +974,10 @@ def log_graph_runnable() -> str:

             _check_triton_bf16_support(graph)

+            # TODO: The switching between AOT mode and not here is a bit
+            # messy, but it's localized to the block of code below so I'm
+            # not going to touch it for now
+
             compiled_fn: Any

             with dynamo_timed(
@@ -1058,8 +1057,10 @@ def log_graph_runnable() -> str:
                 V.graph.disable_cudagraphs_reason = disable

             if V.aot_compilation is True:
-                return compiled_fn
+                assert isinstance(compiled_fn, (str, list))
+                return CompiledAOTI(compiled_fn)

+            # TODO: Hoist this above V.aot_compilation
             if cudagraphs and not V.graph.disable_cudagraphs_reason:
                 from torch._inductor.cudagraph_utils import (
                     check_lowering_disable_cudagraph,
@@ -1069,7 +1070,7 @@ def log_graph_runnable() -> str:
                     check_lowering_disable_cudagraph(V.graph.device_node_mapping)
                 )

-            compiled_graph = CompiledFxGraph(
+            return CompiledFxGraph(
                 compiled_fn,
                 graph,
                 gm,
@@ -1085,8 +1086,6 @@ def log_graph_runnable() -> str:
                 boxed_forward_device_index,
             )

-    return compiled_graph
-

 def get_input_idxs_to_check(
     inputs: Sequence[InputType],
@@ -1326,11 +1325,9 @@ def compile_fx_aot(
         config_patches=config_patches,
     )

-    assert isinstance(compiled_artifacts, str) or (
-        isinstance(compiled_artifacts, list)
-        and isinstance(compiled_artifacts[0], str)
-    )
-    return compiled_artifacts
+    assert isinstance(compiled_artifacts, CompiledAOTI)
+
+    return compiled_artifacts.filename


 _graph_counter = count(0)
@@ -1487,7 +1484,7 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:
 def compile_fx(
     model_: GraphModule,
     example_inputs_: Sequence[InputType],
-    inner_compile: Callable[..., Any] = compile_fx_inner,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
     config_patches: Optional[Dict[str, Any]] = None,
     decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
 ) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str, List[str]]:
@@ -1631,7 +1628,7 @@ def fw_compiler_base(
         model: GraphModule,
         example_inputs: List[InputType],
         is_inference: bool,
-    ) -> CompiledFxGraph:
+    ) -> OutputCode:
         with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
             if is_inference:
                 # partition_fn won't be called
@@ -1737,7 +1734,7 @@ def partition_fn(
     @compile_time_strobelight_meta(phase_name="backward")
     def bw_compiler(
         model: GraphModule, example_inputs: List[InputType]
-    ) -> Union[CompiledFxGraph, str]:
+    ) -> OutputCode:
         from torch._dynamo.convert_frame import compile_lock

         with dynamo_utils.dynamo_timed(