Skip to content

Commit

Permalink
Fixing unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: jerryyin <[email protected]>
  • Loading branch information
jerryyin committed Feb 4, 2025
1 parent ba4f482 commit c0e267b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 5 deletions.
8 changes: 5 additions & 3 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
problem_size: ProblemSize,
problem_size: Optional[ProblemSize],
) -> ir.Module:
"""Generate a transform dialect spec that applies the compilation info attr."""
pass
Expand All @@ -67,7 +67,7 @@ def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
problem_size: ProblemSize,
problem_size: Optional[ProblemSize],
) -> ir.Module:
contraction_op: ir.Operation = self.get_contraction_operation(ir_module)
lhs_type = ir.ShapedType(contraction_op.operands[0].type)
Expand All @@ -88,8 +88,10 @@ def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
problem_size: ProblemSize,
problem_size: Optional[ProblemSize],
) -> ir.Module:
assert problem_size, "Problem size not found"

conv_op: ir.Operation = self.get_conv_operation(ir_module)
lhs_type = ir.ShapedType(conv_op.operands[0].type)
rhs_type = ir.ShapedType(conv_op.operands[1].type)
Expand Down
5 changes: 3 additions & 2 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
ir_module = ir.Module.parse(module_str, context)

tuner = candidate_gen.ContractionOpInterfaceTuner()
td_spec_module = tuner.get_td_spec(ir_module, compilation_info)
td_spec_module = tuner.get_td_spec(ir_module, compilation_info, None)
assert td_spec_module

named_sequence_ops: list[
Expand Down Expand Up @@ -167,7 +167,8 @@ def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None:
ir_module = ir.Module.parse(module_str, context)

tuner = candidate_gen.ConvolutionOpInterfaceTuner()
td_spec_module = tuner.get_td_spec(ir_module, compilation_info)
problem_size = tuner.get_shapes(str(ir_module).splitlines())
td_spec_module = tuner.get_td_spec(ir_module, compilation_info, problem_size)
assert td_spec_module

named_sequence_ops: list[
Expand Down
12 changes: 12 additions & 0 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([2048, 1280], tuner_ctx.type.f32),
common.DispatchKind.contraction,
common.ContractionDimensions([0], [1], [2]),
None,
None,
None,
None,
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
Expand All @@ -145,6 +149,10 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([2048, 1280], tuner_ctx.type.i32),
common.DispatchKind.contraction,
common.ContractionDimensions([0], [1], [2]),
None,
None,
None,
None,
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
Expand All @@ -166,6 +174,10 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.contraction,
common.ContractionDimensions([1], [2], [3], [0]),
None,
None,
None,
None,
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
Expand Down
50 changes: 50 additions & 0 deletions tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
res_type,
common.DispatchKind.contraction,
contraction_dims,
None,
None,
None,
None,
)
configs = dispatch_constraints.generate_solutions(
tuner_ctx,
Expand Down Expand Up @@ -72,6 +76,10 @@ def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext)
res_type,
common.DispatchKind.contraction,
contraction_dims,
None,
None,
None,
None,
)
assert (
dispatch_constraints.calculate_shared_memory_usage_in_bytes(
Expand All @@ -88,6 +96,10 @@ def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext)
res_type,
common.DispatchKind.contraction,
contraction_dims,
None,
None,
None,
None,
)
assert (
dispatch_constraints.calculate_shared_memory_usage_in_bytes(
Expand All @@ -104,6 +116,10 @@ def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext)
res_type,
common.DispatchKind.contraction,
contraction_dims,
None,
None,
None,
None,
)
assert (
dispatch_constraints.calculate_shared_memory_usage_in_bytes(
Expand Down Expand Up @@ -146,6 +162,10 @@ def test_adjust_problem_size_for_pipeline(
res_type,
common.DispatchKind.contraction,
contraction_dims,
None,
None,
None,
None,
)
pipeline_options_space = dispatch_constraints.PipelineOptionsSearchSpace(
prefetch_shared_memory=[True],
Expand Down Expand Up @@ -187,13 +207,27 @@ def test_adjust_problem_size_for_pipeline(
lhs_type = common.ShapedType([2, 34, 34, 512], tuner_ctx.type.f16)
rhs_type = common.ShapedType([3, 3, 512, 256], tuner_ctx.type.f16)
res_type = common.ShapedType([2, 32, 32, 256], tuner_ctx.type.f32)
conv_dims = common.ConvolutionDimensions(
batch=[0],
outputImage=[1, 2],
outputChannel=[3],
filterLoop=[4, 5],
inputChannel=[6],
depth=[],
strides=[1, 1],
dilations=[1, 1],
)
conv_problem_size = common.ProblemSize(
conv_size,
lhs_type,
rhs_type,
res_type,
common.DispatchKind.conv,
contraction_dims,
[[0], [1, 4], [2, 5], [6]],
[[4], [5], [6], [3]],
[[0], [1], [2], [3]],
conv_dims,
)
vec_dist_pipeline = (
iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
Expand Down Expand Up @@ -265,6 +299,10 @@ def test_generate_tile_and_fuse_constraints_valid_input(
res_type,
common.DispatchKind.contraction,
contraction_dims,
None,
None,
None,
None,
)
# Define input parameters as z3 Ints
m, n, k = (
Expand Down Expand Up @@ -337,6 +375,10 @@ def test_generate_tile_and_fuse_constraints_invalid_input(
res_type,
common.DispatchKind.contraction,
contraction_dims,
None,
None,
None,
None,
)
# Define input parameters as z3 Ints
m, n, k = (
Expand Down Expand Up @@ -399,6 +441,10 @@ def test_generate_vector_distribute_constraints_valid_input(
res_type,
common.DispatchKind.contraction,
contraction_dims,
None,
None,
None,
None,
)
# Define input parameters as z3 Ints
m, n, k = (
Expand Down Expand Up @@ -457,6 +503,10 @@ def test_generate_vector_distribute_constraints_invalid_input(
res_type,
common.DispatchKind.contraction,
contraction_dims,
None,
None,
None,
None,
)
m, n, k = (
[z3.Int("m")],
Expand Down

0 comments on commit c0e267b

Please sign in to comment.