Skip to content

Commit 7a9606b

Browse files
committed
Remove _OVERLOAD layer
1 parent 9d5d769 commit 7a9606b

File tree

3 files changed

+56
-255
lines changed

3 files changed

+56
-255
lines changed

src/trans/gpu/algor/hicblas_mod.F90

-233
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,8 @@
77
! nor does it submit to any jurisdiction.
88
!
99

10-
#if defined CUDAGPU
11-
#define ACC_GET_HIP_STREAM ACC_GET_CUDA_STREAM
12-
#define OPENACC_LIB OPENACC
13-
#endif
14-
1510
MODULE HICBLAS_MOD
1611

17-
USE EC_PARKIND, ONLY: JPIM, JPRM, JPRD, JPIB
18-
USE GROWING_ALLOCATOR_MOD, ONLY: GROWING_ALLOCATION_TYPE
19-
#ifdef ACCGPU
20-
USE OPENACC_LIB, ONLY: ACC_GET_HIP_STREAM
21-
#endif
22-
#ifdef OMPGPU
23-
#endif
24-
2512
IMPLICIT NONE
2613

2714
INTERFACE
@@ -118,224 +105,4 @@ SUBROUTINE HIP_SGEMM_GROUPED( &
118105
END SUBROUTINE HIP_SGEMM_GROUPED
119106
END INTERFACE
120107

121-
CONTAINS
122-
123-
SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD( &
124-
& TRANSA, TRANSB, &
125-
& M, N, K, &
126-
& ALPHA, &
127-
& AARRAY, LDA, STRIDEA, &
128-
& BARRAY, LDB, STRIDEB, &
129-
& BETA, &
130-
& CARRAY, LDC, STRIDEC, &
131-
& BATCHCOUNT, STREAM, ALLOC)
132-
USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_LONG, C_LOC
133-
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
134-
INTEGER(KIND=JPIM) :: M
135-
INTEGER(KIND=JPIM) :: N
136-
INTEGER(KIND=JPIM) :: K
137-
REAL(KIND=JPRD) :: ALPHA
138-
REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
139-
INTEGER(KIND=JPIM) :: LDA
140-
INTEGER(KIND=JPIM) :: STRIDEA
141-
REAL(KIND=JPRD), DIMENSION(:,:) :: BARRAY
142-
INTEGER(KIND=JPIM) :: LDB
143-
INTEGER(KIND=JPIM) :: STRIDEB
144-
REAL(KIND=JPRD) :: BETA
145-
REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
146-
INTEGER(KIND=JPIM) :: LDC
147-
INTEGER(KIND=JPIM) :: STRIDEC
148-
INTEGER(KIND=JPIM) :: BATCHCOUNT
149-
INTEGER(KIND=C_INT) :: STREAM
150-
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC
151-
152-
INTEGER(KIND=C_LONG) :: HIP_STREAM
153-
154-
#ifdef ACCGPU
155-
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
156-
#endif
157-
#ifdef OMPGPU
158-
#endif
159-
160-
#if defined(_CRAYFTN)
161-
#ifdef ACCGPU
162-
!$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
163-
#endif
164-
#endif
165-
CALL HIP_DGEMM_BATCHED( &
166-
& TRANSA, TRANSB, &
167-
& M, N, K, &
168-
& ALPHA, &
169-
& AARRAY, LDA, STRIDEA, &
170-
& BARRAY, LDB, STRIDEB, &
171-
& BETA, &
172-
& CARRAY, LDC, STRIDEC, &
173-
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
174-
#if defined(_CRAYFTN)
175-
#ifdef ACCGPU
176-
!$ACC END HOST_DATA
177-
#endif
178-
#endif
179-
END SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD
180-
181-
SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD( &
182-
& TRANSA, TRANSB, &
183-
& M, N, K, &
184-
& ALPHA, &
185-
& AARRAY, LDA, STRIDEA, &
186-
& BARRAY, LDB, STRIDEB, &
187-
& BETA, &
188-
& CARRAY, LDC, STRIDEC, &
189-
& BATCHCOUNT, STREAM, ALLOC)
190-
USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_LONG, C_LOC
191-
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
192-
INTEGER(KIND=JPIM) :: M
193-
INTEGER(KIND=JPIM) :: N
194-
INTEGER(KIND=JPIM) :: K
195-
REAL(KIND=JPRM) :: ALPHA
196-
REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
197-
INTEGER(KIND=JPIM) :: LDA
198-
INTEGER(KIND=JPIM) :: STRIDEA
199-
REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
200-
INTEGER(KIND=JPIM) :: LDB
201-
INTEGER(KIND=JPIM) :: STRIDEB
202-
REAL(KIND=JPRM) :: BETA
203-
REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
204-
INTEGER(KIND=JPIM) :: LDC
205-
INTEGER(KIND=JPIM) :: STRIDEC
206-
INTEGER(KIND=JPIM) :: BATCHCOUNT
207-
INTEGER(KIND=C_INT) :: STREAM
208-
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC
209-
210-
INTEGER(KIND=C_LONG) :: HIP_STREAM
211-
212-
#ifdef ACCGPU
213-
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
214-
#endif
215-
#ifdef OMPGPU
216-
#endif
217-
218-
CALL HIP_SGEMM_BATCHED( &
219-
& TRANSA, TRANSB, &
220-
& M, N, K, &
221-
& ALPHA, &
222-
& AARRAY, LDA, STRIDEA, &
223-
& BARRAY, LDB, STRIDEB, &
224-
& BETA, &
225-
& CARRAY, LDC, STRIDEC, &
226-
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
227-
END SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD
228-
229-
SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD( &
230-
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
231-
& M, N, K, &
232-
& ALPHA, &
233-
& AARRAY, LDA, OFFSETA, &
234-
& BARRAY, LDB, OFFSETB, &
235-
& BETA, &
236-
& CARRAY, LDC, OFFSETC, &
237-
& BATCHCOUNT, STREAM, ALLOC)
238-
USE ISO_C_BINDING, ONLY: C_INT, C_CHAR, C_LONG, C_LOC
239-
INTEGER(KIND=C_INT), INTENT(IN) :: RESOL_ID
240-
INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
241-
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
242-
INTEGER(KIND=JPIM) :: M
243-
INTEGER(KIND=JPIM) :: N(:)
244-
INTEGER(KIND=JPIM) :: K(:)
245-
REAL(KIND=JPRD) :: ALPHA
246-
REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
247-
INTEGER(KIND=JPIM) :: LDA
248-
INTEGER(KIND=JPIB) :: OFFSETA(:)
249-
REAL(KIND=JPRD), DIMENSION(*) :: BARRAY
250-
INTEGER(KIND=JPIM) :: LDB(:)
251-
INTEGER(KIND=JPIB) :: OFFSETB(:)
252-
REAL(KIND=JPRD) :: BETA
253-
REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
254-
INTEGER(KIND=JPIM) :: LDC
255-
INTEGER(KIND=JPIB) :: OFFSETC(:)
256-
INTEGER(KIND=JPIM) :: BATCHCOUNT
257-
INTEGER(KIND=C_INT) :: STREAM
258-
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC
259-
260-
INTEGER(KIND=C_LONG) :: HIP_STREAM
261-
262-
#ifdef ACCGPU
263-
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
264-
#endif
265-
#ifdef OMPGPU
266-
#endif
267-
268-
CALL HIP_DGEMM_GROUPED( &
269-
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
270-
& M, N, K, &
271-
& ALPHA, &
272-
& AARRAY, LDA, OFFSETA, &
273-
& BARRAY, LDB, OFFSETB, &
274-
& BETA, &
275-
& CARRAY, LDC, OFFSETC, &
276-
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
277-
278-
END SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD
279-
280-
SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD(&
281-
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
282-
& M, N, K, &
283-
& ALPHA, &
284-
& AARRAY, LDA, OFFSETA, &
285-
& BARRAY, LDB, OFFSETB, &
286-
& BETA, &
287-
& CARRAY, LDC, OFFSETC, &
288-
& BATCHCOUNT, STREAM, ALLOC)
289-
USE ISO_C_BINDING, ONLY: C_INT, C_CHAR, C_LONG, C_LOC
290-
INTEGER(KIND=C_INT), INTENT(IN) :: RESOL_ID
291-
INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
292-
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
293-
INTEGER(KIND=JPIM) :: M
294-
INTEGER(KIND=JPIM) :: N(:)
295-
INTEGER(KIND=JPIM) :: K(:)
296-
REAL(KIND=JPRM) :: ALPHA
297-
REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
298-
INTEGER(KIND=JPIM) :: LDA
299-
INTEGER(KIND=JPIB) :: OFFSETA(:)
300-
REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
301-
INTEGER(KIND=JPIM) :: LDB(:)
302-
INTEGER(KIND=JPIB) :: OFFSETB(:)
303-
REAL(KIND=JPRM) :: BETA
304-
REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
305-
INTEGER(KIND=JPIM) :: LDC
306-
INTEGER(KIND=JPIB) :: OFFSETC(:)
307-
INTEGER(KIND=JPIM) :: BATCHCOUNT
308-
INTEGER(KIND=C_INT) :: STREAM
309-
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC
310-
311-
INTEGER(KIND=C_LONG) :: HIP_STREAM
312-
313-
#ifdef ACCGPU
314-
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
315-
#endif
316-
#ifdef OMPGPU
317-
#endif
318-
319-
#if defined(_CRAYFTN)
320-
#ifdef ACCGPU
321-
!$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
322-
#endif
323-
#endif
324-
CALL HIP_SGEMM_GROUPED( &
325-
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
326-
& M, N, K, &
327-
& ALPHA, &
328-
& AARRAY, LDA, OFFSETA, &
329-
& BARRAY, LDB, OFFSETB, &
330-
& BETA, &
331-
& CARRAY, LDC, OFFSETC, &
332-
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
333-
#if defined(_CRAYFTN)
334-
#ifdef ACCGPU
335-
!$ACC END HOST_DATA
336-
#endif
337-
#endif
338-
339-
END SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD
340-
341108
END MODULE HICBLAS_MOD

src/trans/gpu/internal/ledir_mod.F90

+28-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
#define ALIGN(I, A) (((I)+(A)-1)/(A)*(A))
2+
#if defined CUDAGPU
3+
#define ACC_GET_HIP_STREAM ACC_GET_CUDA_STREAM
4+
#define OPENACC_LIB OPENACC
5+
#endif
6+
27
! (C) Copyright 2000- ECMWF.
38
! (C) Copyright 2000- Meteo-France.
49
! (C) Copyright 2022- NVIDIA.
@@ -106,16 +111,19 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
106111
USE TPM_GEOMETRY, ONLY: G
107112
USE TPM_FIELDS_GPU, ONLY: FG
108113
USE TPM_DISTR, ONLY: D
109-
USE HICBLAS_MOD, ONLY: HIP_DGEMM_BATCHED_OVERLOAD, &
110-
& HIP_DGEMM_GROUPED_OVERLOAD, HIP_SGEMM_GROUPED_OVERLOAD
114+
USE HICBLAS_MOD, ONLY: HIP_DGEMM_BATCHED, &
115+
& HIP_DGEMM_GROUPED, HIP_SGEMM_GROUPED
111116
USE MPL_MODULE, ONLY: MPL_BARRIER,MPL_ALL_MS_COMM
112117
USE TPM_STATS, ONLY: GSTATS => GSTATS_NVTX
113-
USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_INT
118+
USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_INT, C_LONG, C_LOC
119+
#ifdef ACCGPU
120+
USE OPENACC_LIB, ONLY: ACC_GET_HIP_STREAM
121+
#endif
114122

115123
#ifdef TRANS_SINGLE
116-
#define HIP_GEMM HIP_SGEMM_GROUPED_OVERLOAD
124+
#define HIP_GEMM HIP_SGEMM_GROUPED
117125
#else
118-
#define HIP_GEMM HIP_DGEMM_GROUPED_OVERLOAD
126+
#define HIP_GEMM HIP_DGEMM_GROUPED
119127
#endif
120128

121129
IMPLICIT NONE
@@ -149,12 +157,21 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
149157
INTEGER(KIND=JPIM) :: IIN0_STRIDES0, IIN0_STRIDES1
150158
INTEGER(KIND=8) :: ALLOC_SZ, ALLOC_POS
151159

160+
INTEGER(KIND=C_LONG) :: HIP_STREAM
161+
152162
ASSOCIATE(D_NUMP=>D%NUMP, R_NSMAX=>R%NSMAX, R_NTMAX=>R%NTMAX, G_NDGLU=>G%NDGLU, &
153163
& D_MYMS=>D%MYMS, D_OFFSETS_GEMM1=>D%OFFSETS_GEMM1, &
154164
& D_OFFSETS_GEMM2=>D%OFFSETS_GEMM2, &
155165
& ZAA=>FG%ZAA, ZAS=>FG%ZAS, ZAA0=>FG%ZAA0, ZAS0=>FG%ZAS0)
156166
IF (LHOOK) CALL DR_HOOK('LE_DGEMM',0,ZHOOK_HANDLE)
157167

168+
#ifdef ACCGPU
169+
HIP_STREAM = INT(ACC_GET_HIP_STREAM(1_C_INT), C_LONG)
170+
#endif
171+
#ifdef OMPGPU
172+
HIP_STREAM = 1_C_LONG
173+
#endif
174+
158175
CALL LEDIR_STRIDES(KF_FS,IOUT_STRIDES0,IOUT_STRIDES1,IIN_STRIDES0,IIN_STRIDES1,&
159176
IOUT0_STRIDES0,IOUT0_STRIDES1,IIN0_STRIDES0,IIN0_STRIDES1)
160177

@@ -187,15 +204,15 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
187204
#ifdef ACCGPU
188205
!$ACC HOST_DATA USE_DEVICE(ZAA0,ZINPA0,ZOUT0)
189206
#endif
190-
CALL HIP_DGEMM_BATCHED_OVERLOAD( &
207+
CALL HIP_DGEMM_BATCHED( &
191208
& 'N', 'N', &
192209
& KF_FS, (R_NSMAX+2)/2, G_NDGLU(0), &
193210
& 1.0_JPRD, &
194211
& ZINPA0, IIN0_STRIDES0, 0, &
195212
& ZAA0, SIZE(ZAA0,1), 0, &
196213
& 0.0_JPRD, &
197214
& ZOUT0, IOUT0_STRIDES0, 0, &
198-
& 1, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
215+
& 1, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
199216
#ifdef OMPGPU
200217
!$OMP END TARGET DATA
201218
#endif
@@ -233,7 +250,7 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
233250
& ZAA, D%LEGENDRE_MATRIX_STRIDES, BOFFSETS, &
234251
& 0.0_JPRBT, &
235252
& ZOUT, IOUT_STRIDES0, COFFSETS, &
236-
& D_NUMP, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
253+
& D_NUMP, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
237254
#ifdef OMPGPU
238255
!$OMP END TARGET DATA
239256
#endif
@@ -306,15 +323,15 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
306323
!$ACC HOST_DATA USE_DEVICE(ZAS0,ZINPS0,ZOUT0)
307324
#endif
308325
! compute m=0 in double precision:
309-
call HIP_DGEMM_BATCHED_OVERLOAD( &
326+
call HIP_DGEMM_BATCHED( &
310327
& 'N', 'N', &
311328
& KF_FS, (R_NSMAX+3)/2, G_NDGLU(0), &
312329
& 1.0_JPRD, &
313330
& ZINPS0, IIN0_STRIDES0, 0, &
314331
& ZAS0, SIZE(ZAS0,1), 0, &
315332
& 0.0_JPRD, &
316333
& ZOUT0, IOUT0_STRIDES0, 0, &
317-
& 1, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
334+
& 1, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
318335
#ifdef OMPGPU
319336
!$OMP END TARGET DATA
320337
#endif
@@ -353,7 +370,7 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
353370
& ZAS, D%LEGENDRE_MATRIX_STRIDES, BOFFSETS, &
354371
& 0.0_JPRBT, &
355372
& ZOUT, IOUT_STRIDES0, COFFSETS, &
356-
& D_NUMP, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
373+
& D_NUMP, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
357374
#ifdef OMPGPU
358375
!$OMP END TARGET DATA
359376
#endif

0 commit comments

Comments
 (0)