Skip to content

Commit ae0b3c0

Browse files
committed
Remove _OVERLOAD layer
1 parent 9d5d769 commit ae0b3c0

File tree

3 files changed

+46
-255
lines changed

3 files changed

+46
-255
lines changed

src/trans/gpu/algor/hicblas_mod.F90

-233
Original file line number | Diff line number | Diff line change
@@ -7,21 +7,8 @@
77
! nor does it submit to any jurisdiction.
88
!
99

10-
#if defined CUDAGPU
11-
#define ACC_GET_HIP_STREAM ACC_GET_CUDA_STREAM
12-
#define OPENACC_LIB OPENACC
13-
#endif
14-
1510
MODULE HICBLAS_MOD
1611

17-
USE EC_PARKIND, ONLY: JPIM, JPRM, JPRD, JPIB
18-
USE GROWING_ALLOCATOR_MOD, ONLY: GROWING_ALLOCATION_TYPE
19-
#ifdef ACCGPU
20-
USE OPENACC_LIB, ONLY: ACC_GET_HIP_STREAM
21-
#endif
22-
#ifdef OMPGPU
23-
#endif
24-
2512
IMPLICIT NONE
2613

2714
INTERFACE
@@ -118,224 +105,4 @@ SUBROUTINE HIP_SGEMM_GROUPED( &
118105
END SUBROUTINE HIP_SGEMM_GROUPED
119106
END INTERFACE
120107

121-
CONTAINS
122-
123-
SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD( &
124-
& TRANSA, TRANSB, &
125-
& M, N, K, &
126-
& ALPHA, &
127-
& AARRAY, LDA, STRIDEA, &
128-
& BARRAY, LDB, STRIDEB, &
129-
& BETA, &
130-
& CARRAY, LDC, STRIDEC, &
131-
& BATCHCOUNT, STREAM, ALLOC)
132-
USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_LONG, C_LOC
133-
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
134-
INTEGER(KIND=JPIM) :: M
135-
INTEGER(KIND=JPIM) :: N
136-
INTEGER(KIND=JPIM) :: K
137-
REAL(KIND=JPRD) :: ALPHA
138-
REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
139-
INTEGER(KIND=JPIM) :: LDA
140-
INTEGER(KIND=JPIM) :: STRIDEA
141-
REAL(KIND=JPRD), DIMENSION(:,:) :: BARRAY
142-
INTEGER(KIND=JPIM) :: LDB
143-
INTEGER(KIND=JPIM) :: STRIDEB
144-
REAL(KIND=JPRD) :: BETA
145-
REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
146-
INTEGER(KIND=JPIM) :: LDC
147-
INTEGER(KIND=JPIM) :: STRIDEC
148-
INTEGER(KIND=JPIM) :: BATCHCOUNT
149-
INTEGER(KIND=C_INT) :: STREAM
150-
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC
151-
152-
INTEGER(KIND=C_LONG) :: HIP_STREAM
153-
154-
#ifdef ACCGPU
155-
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
156-
#endif
157-
#ifdef OMPGPU
158-
#endif
159-
160-
#if defined(_CRAYFTN)
161-
#ifdef ACCGPU
162-
!$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
163-
#endif
164-
#endif
165-
CALL HIP_DGEMM_BATCHED( &
166-
& TRANSA, TRANSB, &
167-
& M, N, K, &
168-
& ALPHA, &
169-
& AARRAY, LDA, STRIDEA, &
170-
& BARRAY, LDB, STRIDEB, &
171-
& BETA, &
172-
& CARRAY, LDC, STRIDEC, &
173-
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
174-
#if defined(_CRAYFTN)
175-
#ifdef ACCGPU
176-
!$ACC END HOST_DATA
177-
#endif
178-
#endif
179-
END SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD
180-
181-
SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD( &
182-
& TRANSA, TRANSB, &
183-
& M, N, K, &
184-
& ALPHA, &
185-
& AARRAY, LDA, STRIDEA, &
186-
& BARRAY, LDB, STRIDEB, &
187-
& BETA, &
188-
& CARRAY, LDC, STRIDEC, &
189-
& BATCHCOUNT, STREAM, ALLOC)
190-
USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_LONG, C_LOC
191-
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
192-
INTEGER(KIND=JPIM) :: M
193-
INTEGER(KIND=JPIM) :: N
194-
INTEGER(KIND=JPIM) :: K
195-
REAL(KIND=JPRM) :: ALPHA
196-
REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
197-
INTEGER(KIND=JPIM) :: LDA
198-
INTEGER(KIND=JPIM) :: STRIDEA
199-
REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
200-
INTEGER(KIND=JPIM) :: LDB
201-
INTEGER(KIND=JPIM) :: STRIDEB
202-
REAL(KIND=JPRM) :: BETA
203-
REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
204-
INTEGER(KIND=JPIM) :: LDC
205-
INTEGER(KIND=JPIM) :: STRIDEC
206-
INTEGER(KIND=JPIM) :: BATCHCOUNT
207-
INTEGER(KIND=C_INT) :: STREAM
208-
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC
209-
210-
INTEGER(KIND=C_LONG) :: HIP_STREAM
211-
212-
#ifdef ACCGPU
213-
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
214-
#endif
215-
#ifdef OMPGPU
216-
#endif
217-
218-
CALL HIP_SGEMM_BATCHED( &
219-
& TRANSA, TRANSB, &
220-
& M, N, K, &
221-
& ALPHA, &
222-
& AARRAY, LDA, STRIDEA, &
223-
& BARRAY, LDB, STRIDEB, &
224-
& BETA, &
225-
& CARRAY, LDC, STRIDEC, &
226-
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
227-
END SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD
228-
229-
SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD( &
230-
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
231-
& M, N, K, &
232-
& ALPHA, &
233-
& AARRAY, LDA, OFFSETA, &
234-
& BARRAY, LDB, OFFSETB, &
235-
& BETA, &
236-
& CARRAY, LDC, OFFSETC, &
237-
& BATCHCOUNT, STREAM, ALLOC)
238-
USE ISO_C_BINDING, ONLY: C_INT, C_CHAR, C_LONG, C_LOC
239-
INTEGER(KIND=C_INT), INTENT(IN) :: RESOL_ID
240-
INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
241-
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
242-
INTEGER(KIND=JPIM) :: M
243-
INTEGER(KIND=JPIM) :: N(:)
244-
INTEGER(KIND=JPIM) :: K(:)
245-
REAL(KIND=JPRD) :: ALPHA
246-
REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
247-
INTEGER(KIND=JPIM) :: LDA
248-
INTEGER(KIND=JPIB) :: OFFSETA(:)
249-
REAL(KIND=JPRD), DIMENSION(*) :: BARRAY
250-
INTEGER(KIND=JPIM) :: LDB(:)
251-
INTEGER(KIND=JPIB) :: OFFSETB(:)
252-
REAL(KIND=JPRD) :: BETA
253-
REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
254-
INTEGER(KIND=JPIM) :: LDC
255-
INTEGER(KIND=JPIB) :: OFFSETC(:)
256-
INTEGER(KIND=JPIM) :: BATCHCOUNT
257-
INTEGER(KIND=C_INT) :: STREAM
258-
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC
259-
260-
INTEGER(KIND=C_LONG) :: HIP_STREAM
261-
262-
#ifdef ACCGPU
263-
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
264-
#endif
265-
#ifdef OMPGPU
266-
#endif
267-
268-
CALL HIP_DGEMM_GROUPED( &
269-
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
270-
& M, N, K, &
271-
& ALPHA, &
272-
& AARRAY, LDA, OFFSETA, &
273-
& BARRAY, LDB, OFFSETB, &
274-
& BETA, &
275-
& CARRAY, LDC, OFFSETC, &
276-
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
277-
278-
END SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD
279-
280-
SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD(&
281-
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
282-
& M, N, K, &
283-
& ALPHA, &
284-
& AARRAY, LDA, OFFSETA, &
285-
& BARRAY, LDB, OFFSETB, &
286-
& BETA, &
287-
& CARRAY, LDC, OFFSETC, &
288-
& BATCHCOUNT, STREAM, ALLOC)
289-
USE ISO_C_BINDING, ONLY: C_INT, C_CHAR, C_LONG, C_LOC
290-
INTEGER(KIND=C_INT), INTENT(IN) :: RESOL_ID
291-
INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
292-
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
293-
INTEGER(KIND=JPIM) :: M
294-
INTEGER(KIND=JPIM) :: N(:)
295-
INTEGER(KIND=JPIM) :: K(:)
296-
REAL(KIND=JPRM) :: ALPHA
297-
REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
298-
INTEGER(KIND=JPIM) :: LDA
299-
INTEGER(KIND=JPIB) :: OFFSETA(:)
300-
REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
301-
INTEGER(KIND=JPIM) :: LDB(:)
302-
INTEGER(KIND=JPIB) :: OFFSETB(:)
303-
REAL(KIND=JPRM) :: BETA
304-
REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
305-
INTEGER(KIND=JPIM) :: LDC
306-
INTEGER(KIND=JPIB) :: OFFSETC(:)
307-
INTEGER(KIND=JPIM) :: BATCHCOUNT
308-
INTEGER(KIND=C_INT) :: STREAM
309-
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC
310-
311-
INTEGER(KIND=C_LONG) :: HIP_STREAM
312-
313-
#ifdef ACCGPU
314-
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
315-
#endif
316-
#ifdef OMPGPU
317-
#endif
318-
319-
#if defined(_CRAYFTN)
320-
#ifdef ACCGPU
321-
!$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
322-
#endif
323-
#endif
324-
CALL HIP_SGEMM_GROUPED( &
325-
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
326-
& M, N, K, &
327-
& ALPHA, &
328-
& AARRAY, LDA, OFFSETA, &
329-
& BARRAY, LDB, OFFSETB, &
330-
& BETA, &
331-
& CARRAY, LDC, OFFSETC, &
332-
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
333-
#if defined(_CRAYFTN)
334-
#ifdef ACCGPU
335-
!$ACC END HOST_DATA
336-
#endif
337-
#endif
338-
339-
END SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD
340-
341108
END MODULE HICBLAS_MOD

src/trans/gpu/internal/ledir_mod.F90

+23 -11
Original file line number | Diff line number | Diff line change
@@ -106,16 +106,19 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
106106
USE TPM_GEOMETRY, ONLY: G
107107
USE TPM_FIELDS_GPU, ONLY: FG
108108
USE TPM_DISTR, ONLY: D
109-
USE HICBLAS_MOD, ONLY: HIP_DGEMM_BATCHED_OVERLOAD, &
110-
& HIP_DGEMM_GROUPED_OVERLOAD, HIP_SGEMM_GROUPED_OVERLOAD
109+
USE HICBLAS_MOD, ONLY: HIP_DGEMM_BATCHED, &
110+
& HIP_DGEMM_GROUPED, HIP_SGEMM_GROUPED
111111
USE MPL_MODULE, ONLY: MPL_BARRIER,MPL_ALL_MS_COMM
112112
USE TPM_STATS, ONLY: GSTATS => GSTATS_NVTX
113-
USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_INT
113+
USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_INT, C_LONG, C_LOC
114+
#ifdef ACCGPU
115+
USE OPENACC_LIB, ONLY: ACC_GET_HIP_STREAM
116+
#endif
114117

115118
#ifdef TRANS_SINGLE
116-
#define HIP_GEMM HIP_SGEMM_GROUPED_OVERLOAD
119+
#define HIP_GEMM HIP_SGEMM_GROUPED
117120
#else
118-
#define HIP_GEMM HIP_DGEMM_GROUPED_OVERLOAD
121+
#define HIP_GEMM HIP_DGEMM_GROUPED
119122
#endif
120123

121124
IMPLICIT NONE
@@ -149,12 +152,21 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
149152
INTEGER(KIND=JPIM) :: IIN0_STRIDES0, IIN0_STRIDES1
150153
INTEGER(KIND=8) :: ALLOC_SZ, ALLOC_POS
151154

155+
INTEGER(KIND=C_LONG) :: HIP_STREAM
156+
152157
ASSOCIATE(D_NUMP=>D%NUMP, R_NSMAX=>R%NSMAX, R_NTMAX=>R%NTMAX, G_NDGLU=>G%NDGLU, &
153158
& D_MYMS=>D%MYMS, D_OFFSETS_GEMM1=>D%OFFSETS_GEMM1, &
154159
& D_OFFSETS_GEMM2=>D%OFFSETS_GEMM2, &
155160
& ZAA=>FG%ZAA, ZAS=>FG%ZAS, ZAA0=>FG%ZAA0, ZAS0=>FG%ZAS0)
156161
IF (LHOOK) CALL DR_HOOK('LE_DGEMM',0,ZHOOK_HANDLE)
157162

163+
#ifdef ACCGPU
164+
HIP_STREAM = INT(ACC_GET_HIP_STREAM(1_C_INT), C_LONG)
165+
#endif
166+
#ifdef OMPGPU
167+
HIP_STREAM = 1_C_LONG
168+
#endif
169+
158170
CALL LEDIR_STRIDES(KF_FS,IOUT_STRIDES0,IOUT_STRIDES1,IIN_STRIDES0,IIN_STRIDES1,&
159171
IOUT0_STRIDES0,IOUT0_STRIDES1,IIN0_STRIDES0,IIN0_STRIDES1)
160172

@@ -187,15 +199,15 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
187199
#ifdef ACCGPU
188200
!$ACC HOST_DATA USE_DEVICE(ZAA0,ZINPA0,ZOUT0)
189201
#endif
190-
CALL HIP_DGEMM_BATCHED_OVERLOAD( &
202+
CALL HIP_DGEMM_BATCHED( &
191203
& 'N', 'N', &
192204
& KF_FS, (R_NSMAX+2)/2, G_NDGLU(0), &
193205
& 1.0_JPRD, &
194206
& ZINPA0, IIN0_STRIDES0, 0, &
195207
& ZAA0, SIZE(ZAA0,1), 0, &
196208
& 0.0_JPRD, &
197209
& ZOUT0, IOUT0_STRIDES0, 0, &
198-
& 1, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
210+
& 1, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
199211
#ifdef OMPGPU
200212
!$OMP END TARGET DATA
201213
#endif
@@ -233,7 +245,7 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
233245
& ZAA, D%LEGENDRE_MATRIX_STRIDES, BOFFSETS, &
234246
& 0.0_JPRBT, &
235247
& ZOUT, IOUT_STRIDES0, COFFSETS, &
236-
& D_NUMP, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
248+
& D_NUMP, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
237249
#ifdef OMPGPU
238250
!$OMP END TARGET DATA
239251
#endif
@@ -306,15 +318,15 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
306318
!$ACC HOST_DATA USE_DEVICE(ZAS0,ZINPS0,ZOUT0)
307319
#endif
308320
! compute m=0 in double precision:
309-
call HIP_DGEMM_BATCHED_OVERLOAD( &
321+
call HIP_DGEMM_BATCHED( &
310322
& 'N', 'N', &
311323
& KF_FS, (R_NSMAX+3)/2, G_NDGLU(0), &
312324
& 1.0_JPRD, &
313325
& ZINPS0, IIN0_STRIDES0, 0, &
314326
& ZAS0, SIZE(ZAS0,1), 0, &
315327
& 0.0_JPRD, &
316328
& ZOUT0, IOUT0_STRIDES0, 0, &
317-
& 1, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
329+
& 1, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
318330
#ifdef OMPGPU
319331
!$OMP END TARGET DATA
320332
#endif
@@ -353,7 +365,7 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
353365
& ZAS, D%LEGENDRE_MATRIX_STRIDES, BOFFSETS, &
354366
& 0.0_JPRBT, &
355367
& ZOUT, IOUT_STRIDES0, COFFSETS, &
356-
& D_NUMP, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
368+
& D_NUMP, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
357369
#ifdef OMPGPU
358370
!$OMP END TARGET DATA
359371
#endif

0 commit comments

Comments
 (0)