Skip to content

Commit 34e038c

Browse files
bokrzesiigcbot
authored andcommitted
Adding support for TargetExtensionTypes for JointMatrixFuncsResolutionPass
In this patch I've added support for TargetExtensionTypes (TET) which are described here: https://reviews.llvm.org/D135202 https://discourse.llvm.org/t/rfc-adding-opaque-types-to-llvm-ir/65326 For JointMatrixFuncsResolutionPass it means that instead of parsing type names we have mature primitive for representing e.g SPIRV types. Before, we were extracting those values `float 16 16 3 3 2` from typename string e.g: ``` %spirv.JointMatrixINTEL._float_16_16_3_3_2 = type opaque %spirv.JointMatrixINTEL._float_8_16_3_3_2 = type opaque %spirv.JointMatrixINTEL._float_32_64_3_3_2 = type opaque ``` Since TET: it is LLVM IR primitive Which is represented like that (right side) ``` %spirv.JointMatrixINTEL._float_16_16_3_3_2 = type { target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) } %spirv.JointMatrixINTEL._float_8_16_3_3_2 = type { target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) } %spirv.JointMatrixINTEL._float_32_64_3_3_2 = type { target("spirv.JointMatrixINTEL", float, 32, 64, 3, 3, 2) } ``` and can be accessed like that (first example): ``` llvm::Type* typeParam = *(targetTy->type_param_begin()); // float a = targetTy->int_params()[0]; // 16 b = targetTy->int_params()[1]; // 16 c = targetTy->int_params()[2]; // 3 d = targetTy->int_params()[3]; // 3 e = targetTy->int_params()[4]; // 2 ``` Additionally, I've extracted some common logic between non-TET and TET handlers e.g for `Use` and `Layout` calculation and GetMatrixTypeName.
1 parent 693cfe5 commit 34e038c

16 files changed

+1613
-97
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.cpp

Lines changed: 311 additions & 97 deletions
Large diffs are not rendered by default.

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ namespace IGC
8282

8383
bool parseMatrixTypeNameLegacy(const llvm::Type *opaqueType, JointMatrixTypeDescription *outDescription);
8484
bool ParseMatrixTypeName(llvm::Type *opaqueType, JointMatrixTypeDescription *outDescription);
85+
bool ParseMatrixTypeNameNonExtTypeDetails(llvm::Type* opaqueType,
86+
llvm::StringRef name,
87+
bool IsJointMatrix,
88+
JointMatrixTypeDescription* outDescription);
89+
#if LLVM_VERSION_MAJOR >= 16
90+
bool ParseMatrixTypeNameExtTypeDetails(llvm::Type* opaqueType, bool IsJointMatrix, IGC::JointMatrixTypeDescription* outDescription);
91+
#endif
92+
93+
llvm::StringRef GetMatrixTypeName(llvm::Type* opaqueType);
94+
bool SetLayoutFromUse(unsigned int use, IGC::JointMatrixTypeDescription* outDescription);
95+
unsigned GetUseFromLegacyLayout(unsigned int legacyLayout);
8596

8697
unsigned getNumRowsPerWI(const JointMatrixTypeDescription *desc);
8798
llvm::Type *ResolveType(llvm::Type *opaqueType, JointMatrixTypeDescription *outDesc);
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2025 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
;
9+
; REQUIRES: llvm-16-plus
10+
; RUN: igc_opt --typed-pointers --platformpvc --igc-joint-matrix-resolution -S 2>&1 < %s | FileCheck %s
11+
; ------------------------------------------------
12+
; JointMatrixFuncsResolutionPass
13+
; ------------------------------------------------
14+
15+
; CHECK-LABEL: define spir_kernel void @test_fill_store(
16+
; CHECK-SAME: float addrspace(1)* [[DST0:%.*]], float addrspace(1)* [[DST1:%.*]], float addrspace(1)* [[DST2:%.*]]) {
17+
define spir_kernel void @test_fill_store(float addrspace(1)* %dst0, float addrspace(1)* %dst1, float addrspace(1)* %dst2){
18+
; CHECK-NEXT: [[TMP5:%.*]] = alloca { <64 x float>, <64 x float> }
19+
; CHECK-NEXT: [[TMP3:%.*]] = alloca <8 x float>
20+
; CHECK-NEXT: [[TMP1:%.*]] = alloca <16 x float>
21+
; CHECK-NEXT: store <16 x float> <float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00>, <16 x float>* [[TMP1]]
22+
%1 = call spir_func target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z26__spirv_CompositeConstructf(float 5.000000e+00)
23+
24+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <16 x float>* [[TMP1]] to i8*
25+
; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_global_pi64_v8i8(float addrspace(1)* [[DST0]], i8* [[TMP2]], i64 16, i32 0)
26+
call spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS1fPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2liii(float addrspace(1)* %dst0, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) %1, i64 16, i32 0, i32 3, i32 0)
27+
28+
; CHECK-NEXT: store <8 x float> <float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00>, <8 x float>* [[TMP3]]
29+
%2 = call spir_func target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) @_Z26__spirv_CompositeConstructf.1(float 5.000000e+00)
30+
31+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x float>* [[TMP3]] to i8*
32+
; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(float addrspace(1)* [[DST1]], i8* [[TMP4]], i64 16, i32 0)
33+
call spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS1fPU3AS142__spirv_JointMatrixINTEL__float_8_16_3_3_2liii(float addrspace(1)* %dst1, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) %2, i64 16, i32 0, i32 3, i32 0)
34+
35+
; CHECK-NEXT: store { <64 x float>, <64 x float> } { <64 x float> <float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00>, <64 x float> <float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00, float 5.000000e+00> }, { <64 x float>, <64 x float> }* [[TMP5]]
36+
%3 = call spir_func target("spirv.JointMatrixINTEL", float, 32, 64, 3, 3, 2) @_Z26__spirv_CompositeConstructf.2(float 5.000000e+00)
37+
38+
; CHECK-NEXT: [[TMP6:%.*]] = bitcast { <64 x float>, <64 x float> }* [[TMP5]] to i8*
39+
; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_global_pi64_v8i8(float addrspace(1)* [[DST2]], i8* [[TMP6]], i64 64, i32 0)
40+
call spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS1fPU3AS143__spirv_JointMatrixINTEL__float_32_64_3_3_2liii(float addrspace(1)* %dst2, target("spirv.JointMatrixINTEL", float, 32, 64, 3, 3, 2) %3, i64 64, i32 0, i32 3, i32 0)
41+
42+
ret void
43+
}
44+
45+
declare spir_func target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z26__spirv_CompositeConstructf(float)
46+
declare spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS1fPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2liii(float addrspace(1)*, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2), i64, i32, i32, i32)
47+
48+
declare spir_func target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) @_Z26__spirv_CompositeConstructf.1(float)
49+
declare spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS1fPU3AS142__spirv_JointMatrixINTEL__float_8_16_3_3_2liii(float addrspace(1)*, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2), i64, i32, i32, i32)
50+
51+
declare spir_func target("spirv.JointMatrixINTEL", float, 32, 64, 3, 3, 2) @_Z26__spirv_CompositeConstructf.2(float)
52+
declare spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS1fPU3AS143__spirv_JointMatrixINTEL__float_32_64_3_3_2liii(float addrspace(1)*, target("spirv.JointMatrixINTEL", float, 32, 64, 3, 3, 2), i64, i32, i32, i32)
53+
54+
!igc.functions = !{!0}
55+
!0 = !{void (float addrspace(1)*, float addrspace(1)*, float addrspace(1)*)* @test_fill_store, !1}
56+
!1 = !{!2, !3}
57+
!2 = !{!"function_type", i32 0}
58+
!3 = !{!"sub_group_size", i32 16}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2025 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
; REQUIRES: llvm-16-plus
9+
; RUN: igc_opt --typed-pointers -platformpvc -igc-joint-matrix-resolution -S 2>&1 < %s | FileCheck %s
10+
; ------------------------------------------------
11+
; JointMatrixFuncsResolutionPass
12+
; ------------------------------------------------
13+
; Checks for multiple uses of __spirv_AccessChain function call - load plus store
14+
; it must result in extract and then insert an element to the matrix's slice
15+
16+
; CHECK: [[SLICE:%.*]] = load <8 x i16>, <8 x i16>* %{{.*}}, align 8
17+
; CHECK: [[ELEMENT:%.*]] = extractelement <8 x i16> [[SLICE]], i64 4, !joint_matrix_apply
18+
; CHECK: [[ADD:%.*]] = add i16 [[ELEMENT]], 1
19+
; CHECK: [[INSERT:%.*]] = insertelement <8 x i16> [[SLICE]], i16 [[ADD]], i64 4
20+
; CHECK: store <8 x i16> [[INSERT]], <8 x i16>* %{{.*}}, align 8
21+
22+
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32"
23+
target triple = "spir64-unknown-unknown"
24+
25+
26+
; Function Attrs: nounwind
27+
define spir_kernel void @_ZTS5logicILm8ELm16EE(i16 addrspace(1)* %arg) {
28+
entry:
29+
%0 = alloca target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0 )
30+
%1 = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0 ) @_Z86__spirv_CooperativeMatrixLoadKHR_RPU3AS143__spirv_CooperativeMatrixKHR__short_3_8_16_0PU3AS1slii(i16 addrspace(1)* %arg, i32 0, i64 64, i32 0)
31+
store target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0 ) %1, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0 )* %0
32+
%ptr = call spir_func i16 addrspace(4)* @_Z19__spirv_AccessChainPU3AS4PU3AS143__spirv_CooperativeMatrixKHR._short_3_8_16_0l(target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0 )* %0, i64 4)
33+
%extract = load i16, i16 addrspace(4)* %ptr
34+
%add = add i16 %extract, 1
35+
store i16 %add, i16 addrspace(4)* %ptr
36+
ret void
37+
}
38+
39+
; Function Attrs: nounwind
40+
declare spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0 ) @_Z86__spirv_CooperativeMatrixLoadKHR_RPU3AS143__spirv_CooperativeMatrixKHR__short_3_8_16_0PU3AS1slii(i16 addrspace(1)* %0, i32 %1, i64 %2, i32 %3)
41+
42+
; Function Attrs: nounwind
43+
declare spir_func i16 addrspace(4)* @_Z19__spirv_AccessChainPU3AS4PU3AS143__spirv_CooperativeMatrixKHR._short_3_8_16_0l(target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0 )* %0, i64 %1)
44+
45+
!spirv.MemoryModel = !{!0}
46+
!spirv.Source = !{!1}
47+
!spirv.Generator = !{!2}
48+
!igc.functions = !{!3}
49+
50+
!0 = !{i32 2, i32 2}
51+
!1 = !{i32 4, i32 100000}
52+
!2 = !{i16 6, i16 14}
53+
!3 = !{void (i16 addrspace(1)*)* @_ZTS5logicILm8ELm16EE, !4}
54+
!4 = !{!5}
55+
!5 = !{!"function_type", i32 0}

0 commit comments

Comments
 (0)