diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index de84548da..f3d1ba5fb 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -41,7 +41,7 @@ def get_td_spec(
         self,
         ir_module: ir.Module,
         compilation_info: iree_codegen.CompilationInfoAttr,
-        problem_size: ProblemSize,
+        problem_size: Optional[ProblemSize],
     ) -> ir.Module:
         """Generate a transform dialect spec that applies the compilation info attr."""
         pass
@@ -67,7 +67,7 @@ def get_td_spec(
         self,
         ir_module: ir.Module,
         compilation_info: iree_codegen.CompilationInfoAttr,
-        problem_size: ProblemSize,
+        problem_size: Optional[ProblemSize],
     ) -> ir.Module:
         contraction_op: ir.Operation = self.get_contraction_operation(ir_module)
         lhs_type = ir.ShapedType(contraction_op.operands[0].type)
@@ -88,8 +88,10 @@ def get_td_spec(
         self,
         ir_module: ir.Module,
         compilation_info: iree_codegen.CompilationInfoAttr,
-        problem_size: ProblemSize,
+        problem_size: Optional[ProblemSize],
     ) -> ir.Module:
+        assert problem_size, "Problem size not found"
+
         conv_op: ir.Operation = self.get_conv_operation(ir_module)
         lhs_type = ir.ShapedType(conv_op.operands[0].type)
         rhs_type = ir.ShapedType(conv_op.operands[1].type)
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index 6a62e90e4..a65265d0e 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -85,7 +85,7 @@ def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
     ir_module = ir.Module.parse(module_str, context)
 
     tuner = candidate_gen.ContractionOpInterfaceTuner()
-    td_spec_module = tuner.get_td_spec(ir_module, compilation_info)
+    td_spec_module = tuner.get_td_spec(ir_module, compilation_info, None)
     assert td_spec_module
 
     named_sequence_ops: list[
@@ -167,7 +167,8 @@ def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None:
     ir_module = ir.Module.parse(module_str, context)
 
     tuner = candidate_gen.ConvolutionOpInterfaceTuner()
-    td_spec_module = tuner.get_td_spec(ir_module, compilation_info)
+    problem_size = tuner.get_shapes(str(ir_module).splitlines())
+    td_spec_module = tuner.get_td_spec(ir_module, compilation_info, problem_size)
     assert td_spec_module
 
     named_sequence_ops: list[
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py
index a6c71026d..83b671d55 100644
--- a/tuner/tuner/common_test.py
+++ b/tuner/tuner/common_test.py
@@ -125,6 +125,10 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
             common.ShapedType([2048, 1280], tuner_ctx.type.f32),
             common.DispatchKind.contraction,
             common.ContractionDimensions([0], [1], [2]),
+            None,
+            None,
+            None,
+            None,
         ),
         [
             iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
@@ -145,6 +149,10 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
             common.ShapedType([2048, 1280], tuner_ctx.type.i32),
             common.DispatchKind.contraction,
             common.ContractionDimensions([0], [1], [2]),
+            None,
+            None,
+            None,
+            None,
        ),
         [
             iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
@@ -166,6 +174,10 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
             common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
             common.DispatchKind.contraction,
             common.ContractionDimensions([1], [2], [3], [0]),
+            None,
+            None,
+            None,
+            None,
         ),
         [
             iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py
index 8bd743bc9..b90703273 100644
--- a/tuner/tuner/dispatch_constraints_test.py
+++ b/tuner/tuner/dispatch_constraints_test.py
@@ -43,6 +43,10 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
         res_type,
         common.DispatchKind.contraction,
         contraction_dims,
+        None,
+        None,
+        None,
+        None,
     )
     configs = dispatch_constraints.generate_solutions(
         tuner_ctx,
@@ -72,6 +76,10 @@ def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext)
         res_type,
         common.DispatchKind.contraction,
         contraction_dims,
+        None,
+        None,
+        None,
+        None,
     )
     assert (
         dispatch_constraints.calculate_shared_memory_usage_in_bytes(
@@ -88,6 +96,10 @@ def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext)
         res_type,
         common.DispatchKind.contraction,
         contraction_dims,
+        None,
+        None,
+        None,
+        None,
     )
     assert (
         dispatch_constraints.calculate_shared_memory_usage_in_bytes(
@@ -104,6 +116,10 @@ def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext)
         res_type,
         common.DispatchKind.contraction,
         contraction_dims,
+        None,
+        None,
+        None,
+        None,
     )
     assert (
         dispatch_constraints.calculate_shared_memory_usage_in_bytes(
@@ -146,6 +162,10 @@ def test_adjust_problem_size_for_pipeline(
         res_type,
         common.DispatchKind.contraction,
         contraction_dims,
+        None,
+        None,
+        None,
+        None,
     )
     pipeline_options_space = dispatch_constraints.PipelineOptionsSearchSpace(
         prefetch_shared_memory=[True],
@@ -187,6 +207,16 @@ def test_adjust_problem_size_for_pipeline(
     lhs_type = common.ShapedType([2, 34, 34, 512], tuner_ctx.type.f16)
     rhs_type = common.ShapedType([3, 3, 512, 256], tuner_ctx.type.f16)
     res_type = common.ShapedType([2, 32, 32, 256], tuner_ctx.type.f32)
+    conv_dims = common.ConvolutionDimensions(
+        batch=[0],
+        outputImage=[1, 2],
+        outputChannel=[3],
+        filterLoop=[4, 5],
+        inputChannel=[6],
+        depth=[],
+        strides=[1, 1],
+        dilations=[1, 1],
+    )
     conv_problem_size = common.ProblemSize(
         conv_size,
         lhs_type,
@@ -194,6 +224,10 @@ def test_adjust_problem_size_for_pipeline(
         res_type,
         common.DispatchKind.conv,
         contraction_dims,
+        [[0], [1, 4], [2, 5], [6]],
+        [[4], [5], [6], [3]],
+        [[0], [1], [2], [3]],
+        conv_dims,
     )
     vec_dist_pipeline = (
         iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
@@ -265,6 +299,10 @@ def test_generate_tile_and_fuse_constraints_valid_input(
         res_type,
         common.DispatchKind.contraction,
         contraction_dims,
+        None,
+        None,
+        None,
+        None,
     )
     # Define input parameters as z3 Ints
     m, n, k = (
@@ -337,6 +375,10 @@ def test_generate_tile_and_fuse_constraints_invalid_input(
         res_type,
         common.DispatchKind.contraction,
         contraction_dims,
+        None,
+        None,
+        None,
+        None,
     )
     # Define input parameters as z3 Ints
     m, n, k = (
@@ -399,6 +441,10 @@ def test_generate_vector_distribute_constraints_valid_input(
         res_type,
         common.DispatchKind.contraction,
         contraction_dims,
+        None,
+        None,
+        None,
+        None,
     )
     # Define input parameters as z3 Ints
     m, n, k = (
@@ -457,6 +503,10 @@ def test_generate_vector_distribute_constraints_invalid_input(
         res_type,
         common.DispatchKind.contraction,
         contraction_dims,
+        None,
+        None,
+        None,
+        None,
     )
     m, n, k = (
         [z3.Int("m")],
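
For reference, a minimal usage sketch of the updated get_td_spec signature, mirroring the test changes above. Only ConvolutionOpInterfaceTuner, get_shapes, and get_td_spec come from this patch; the import paths and the helper name build_conv_td_spec, along with module_str, compilation_info, and context, are assumptions for illustration. The contraction tuner accepts None for problem_size, while the convolution tuner now asserts that it is present.

    # Minimal sketch, not part of the patch; import paths are assumed to match the tuner package.
    from iree.compiler import ir          # assumed MLIR binding import
    from tuner import candidate_gen       # assumed package-level import


    def build_conv_td_spec(module_str, compilation_info, context):
        # Hypothetical helper: module_str, compilation_info, and context are placeholders.
        ir_module = ir.Module.parse(module_str, context)
        tuner = candidate_gen.ConvolutionOpInterfaceTuner()
        # The convolution path asserts on problem_size, so recover it from the
        # module text before requesting the transform dialect spec.
        problem_size = tuner.get_shapes(str(ir_module).splitlines())
        return tuner.get_td_spec(ir_module, compilation_info, problem_size)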