Skip to content

Commit 2a87188

Browse files
committed
compact legendre polynomials
1 parent 094eca2 commit 2a87188

File tree

9 files changed

+111
-94
lines changed

9 files changed

+111
-94
lines changed

src/trans/common/internal/tpm_distr.F90

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ MODULE TPM_DISTR
171171
REAL(KIND=JPRD) ,ALLOCATABLE :: RWEIGHT(:) ! Weight per grid-point (if weighted distribution)
172172
INTEGER(KIND=JPIM) ,ALLOCATABLE :: NPROCA_GP(:) ! Number of grid-points per a-set
173173

174-
INTEGER(KIND=JPIB), ALLOCATABLE :: OFFSETS_GEMM1(:), OFFSETS_GEMM2(:)
174+
INTEGER(KIND=JPIB), ALLOCATABLE :: OFFSETS_GEMM1(:), OFFSETS_GEMM2(:), OFFSETS_GEMM_MATRIX(:)
175+
INTEGER(KIND=JPIM), ALLOCATABLE :: LEGENDRE_MATRIX_STRIDES(:)
175176

176177
END TYPE DISTR_TYPE
177178

src/trans/gpu/algor/hicblas_cutlass.cuda.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,11 @@ class cutlass_sgemm_grouped<CutlassType::cutlass_fp32, TransA, TransB> {
154154

155155
} // namespace detail
156156
template <cublasOperation_t TransA, cublasOperation_t TransB>
157-
void cutlass_sgemm_wrapper_grouped_op(int resol_id, int blas_id, int m, int *n, int *k,
157+
void cutlass_sgemm_wrapper_grouped_op(int resol_id, int blas_id, int m, const int *n, const int *k,
158158
float alpha, const float *A, int lda,
159-
int64_t *offsetsA, const float *B, int ldb,
160-
int64_t *offsetsB, float beta, float *C,
161-
int ldc, int64_t *offsetsC, int batchCount,
159+
const int64_t *offsetsA, const float *B, const int *ldb,
160+
const int64_t *offsetsB, float beta, float *C,
161+
int ldc, const int64_t *offsetsC, int batchCount,
162162
cudaStream_t stream,
163163
void *growing_allocator) {
164164
using namespace detail;
@@ -182,10 +182,10 @@ void cutlass_sgemm_wrapper_grouped_op(int resol_id, int blas_id, int m, int *n,
182182
}
183183

184184
void cutlass_sgemm_wrapper_grouped(int resol_id, int blas_id, char transa, char transb,
185-
int m, int *n, int *k, float alpha,
186-
const float *A, int lda, int64_t *offsetsA,
187-
const float *B, int ldb, int64_t *offsetsB, float beta,
188-
float *C, int ldc, int64_t *offsetsC,
185+
int m, const int *n, const int *k, float alpha,
186+
const float *A, int lda, const int64_t *offsetsA,
187+
const float *B, const int *ldb, const int64_t *offsetsB, float beta,
188+
float *C, int ldc, const int64_t *offsetsC,
189189
int batchCount, cudaStream_t stream,
190190
void *growing_allocator) {
191191

src/trans/gpu/algor/hicblas_gemm.hip.cpp

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,11 @@ template <typename Gemm> void erase_from_caches(int resol_id) {
8989

9090
// this version is using graphs and caches the graphs
9191
template <typename Gemm, typename Real>
92-
void run_group_graph(Gemm &&gemm, int resol_id, int m, int *n, int *k,
93-
Real alpha, const Real *A, int lda, int64_t *offsetsA,
94-
const Real *B, int ldb, int64_t *offsetsB, Real beta,
95-
Real *C, int ldc, int64_t *offsetsC, int batchCount,
92+
void run_group_graph(Gemm &&gemm, int resol_id, int m, const int *n,
93+
const int *k, Real alpha, const Real *A, int lda,
94+
const int64_t *offsetsA, const Real *B, const int *ldb,
95+
const int64_t *offsetsB, Real beta, Real *C, int ldc,
96+
const int64_t *offsetsC, int batchCount,
9697
hipStream_t stream, int blas_id, void *growing_allocator) {
9798
growing_allocator_register_free_c(growing_allocator,
9899
free_gemm_graph_cache<Gemm>);
@@ -138,7 +139,7 @@ void run_group_graph(Gemm &&gemm, int resol_id, int m, int *n, int *k,
138139

139140
HIC_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
140141
gemm(stream, m, n[i], k[i], alpha, A + offsetsA[i], lda, B + offsetsB[i],
141-
ldb, beta, C + offsetsC[i], ldc);
142+
ldb[i], beta, C + offsetsC[i], ldc);
142143
hipGraph_t my_graph;
143144
HIC_CHECK(hipStreamEndCapture(stream, &my_graph));
144145
hipGraphNode_t my_node;
@@ -163,16 +164,16 @@ void run_group_graph(Gemm &&gemm, int resol_id, int m, int *n, int *k,
163164

164165
// stupid simple gemm calls
165166
template <typename Gemm, typename Real>
166-
void run_group(Gemm &&gemm, int resol_id, int m, int *n, int *k, Real alpha,
167-
const Real *A, int lda, int64_t *offsetsA, const Real *B,
168-
int ldb, int64_t *offsetsB, Real beta, Real *C, int ldc,
169-
int64_t *offsetsC, int batchCount, hipStream_t stream,
170-
int = -1) {
167+
void run_group(Gemm &&gemm, int resol_id, int m, const int *n, const int *k,
168+
Real alpha, const Real *A, int lda, const int64_t *offsetsA,
169+
const Real *B, const int *ldb, const int64_t *offsetsB,
170+
Real beta, Real *C, int ldc, const int64_t *offsetsC,
171+
int batchCount, hipStream_t stream, int = -1) {
171172
for (int i = 0; i < batchCount; ++i) {
172173
if (m == 0 || n[i] == 0 || k[i] == 0)
173174
continue;
174175
gemm(stream, m, n[i], k[i], alpha, A + offsetsA[i], lda, B + offsetsB[i],
175-
ldb, beta, C + offsetsC[i], ldc);
176+
ldb[i], beta, C + offsetsC[i], ldc);
176177
}
177178
}
178179

@@ -215,14 +216,12 @@ template <typename Real> struct hipblas_gemm_grouped {
215216

216217
#ifndef USE_CUTLASS
217218

218-
void hipblas_sgemm_wrapper_grouped(int resol_id, int blas_id, char transa,
219-
char transb, int m, int *n, int *k,
220-
float alpha, const float *A, int lda,
221-
int64_t *offsetsA, const float *B, int ldb,
222-
int64_t *offsetsB, float beta, float *C,
223-
int ldc, int64_t *offsetsC, int batchCount,
224-
hipStream_t stream,
225-
void *growing_allocator) {
219+
void hipblas_sgemm_wrapper_grouped(
220+
int resol_id, int blas_id, char transa, char transb, int m, const int *n,
221+
const int *k, float alpha, const float *A, int lda, const int64_t *offsetsA,
222+
const float *B, const int *ldb, const int64_t *offsetsB, float beta,
223+
float *C, int ldc, const int64_t *offsetsC, int batchCount,
224+
hipStream_t stream, void *growing_allocator) {
226225

227226
hipblasOperation_t op_t1 = HIPBLAS_OP_N, op_t2 = HIPBLAS_OP_N;
228227
if (transa == 'T' || transa == 't')
@@ -244,12 +243,13 @@ void hipblas_sgemm_wrapper_grouped(int resol_id, int blas_id, char transa,
244243
#endif
245244

246245
void hipblas_dgemm_wrapper_grouped(int resol_id, int blas_id, char transa,
247-
char transb, int m, int *n, int *k,
248-
double alpha, const double *A, int lda,
249-
int64_t *offsetsA, const double *B, int ldb,
250-
int64_t *offsetsB, double beta, double *C,
251-
int ldc, int64_t *offsetsC, int batchCount,
252-
hipStream_t stream, void *) {
246+
char transb, int m, const int *n,
247+
const int *k, double alpha, const double *A,
248+
int lda, const int64_t *offsetsA,
249+
double const *B, const int *ldb,
250+
const int64_t *offsetsB, double beta,
251+
double *C, int ldc, const int64_t *offsetsC,
252+
int batchCount, hipStream_t stream, void *) {
253253

254254
hipblasOperation_t op_t1 = HIPBLAS_OP_N, op_t2 = HIPBLAS_OP_N;
255255
if (transa == 'T' || transa == 't')
@@ -313,13 +313,12 @@ void hipblas_sgemm_wrapper(char transa, char transb, int m, int n, int k,
313313
batchCount));
314314
}
315315

316-
void hipblas_sgemm_wrapper_grouped(int resol_id, int blas_id, char transa,
317-
char transb, int m, int *n, int *k,
318-
float alpha, const float *A, int lda,
319-
int64_t *offsetsA, const float *B, int ldb,
320-
int64_t *offsetsB, float beta, float *C,
321-
int ldc, int64_t *offsetsC, int batchCount,
322-
size_t stream, void *growing_allocator) {
316+
void hipblas_sgemm_wrapper_grouped(
317+
int resol_id, int blas_id, char transa, char transb, int m, const int *n,
318+
const int *k, float alpha, const float *A, int lda, const int64_t *offsetsA,
319+
const float *B, const int *ldb, const int64_t *offsetsB, float beta,
320+
float *C, int ldc, const int64_t *offsetsC, int batchCount, size_t stream,
321+
void *growing_allocator) {
323322
#ifdef USE_CUTLASS
324323
cutlass_sgemm_wrapper_grouped(resol_id, blas_id, transa, transb, m, n, k,
325324
alpha, A, lda, offsetsA, B, ldb, offsetsB, beta,
@@ -334,12 +333,14 @@ void hipblas_sgemm_wrapper_grouped(int resol_id, int blas_id, char transa,
334333
}
335334

336335
void hipblas_dgemm_wrapper_grouped(int resol_id, int blas_id, char transa,
337-
char transb, int m, int *n, int *k,
338-
double alpha, const double *A, int lda,
339-
int64_t *offsetsA, const double *B, int ldb,
340-
int64_t *offsetsB, double beta, double *C,
341-
int ldc, int64_t *offsetsC, int batchCount,
342-
size_t stream, void *growing_allocator) {
336+
char transb, int m, const int *n,
337+
const int *k, double alpha, double const *A,
338+
int lda, const int64_t *offsetsA,
339+
double const *B, const int *ldb,
340+
const int64_t *offsetsB, double beta,
341+
double *C, int ldc, const int64_t *offsetsC,
342+
int batchCount, size_t stream,
343+
void *growing_allocator) {
343344
hipblas_dgemm_wrapper_grouped(resol_id, blas_id, transa, transb, m, n, k,
344345
alpha, A, lda, offsetsA, B, ldb, offsetsB, beta,
345346
C, ldc, offsetsC, batchCount,

src/trans/gpu/algor/hicblas_mod.F90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ SUBROUTINE HIP_DGEMM_GROUPED( &
8383
&) BIND(C, NAME='hipblas_dgemm_wrapper_grouped')
8484
USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_DOUBLE, C_SIZE_T, C_PTR, C_INT64_T
8585
CHARACTER(1,C_CHAR), VALUE :: CTA, CTB
86-
INTEGER(C_INT), VALUE :: RESOL_ID, BLAS_ID, M, LDA, LDB, LDC, BATCHCOUNT
87-
INTEGER(C_INT) :: N(*), K(*)
86+
INTEGER(C_INT), VALUE :: RESOL_ID, BLAS_ID, M, LDA, LDC, BATCHCOUNT
87+
INTEGER(C_INT) :: N(*), K(*), LDB(*)
8888
INTEGER(C_INT64_T) :: OFFSETA(*), OFFSETB(*), OFFSETC(*)
8989
REAL(C_DOUBLE), VALUE :: ALPHA,BETA
9090
REAL(C_DOUBLE) :: A(*), B(*), C(*)
@@ -104,8 +104,8 @@ SUBROUTINE HIP_SGEMM_GROUPED( &
104104
&) BIND(C, NAME='hipblas_sgemm_wrapper_grouped')
105105
USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_FLOAT, C_SIZE_T, C_PTR, C_INT64_T
106106
CHARACTER(1,C_CHAR), VALUE :: CTA, CTB
107-
INTEGER(C_INT), VALUE :: RESOL_ID, BLAS_ID, M, LDA, LDB, LDC, BATCHCOUNT
108-
INTEGER(C_INT) :: N(*), K(*)
107+
INTEGER(C_INT), VALUE :: RESOL_ID, BLAS_ID, M, LDA, LDC, BATCHCOUNT
108+
INTEGER(C_INT) :: N(*), K(*), LDB(*)
109109
INTEGER(C_INT64_T) :: OFFSETA(*), OFFSETB(*), OFFSETC(*)
110110
REAL(C_FLOAT), VALUE :: ALPHA,BETA
111111
REAL(C_FLOAT) :: A(*), B(*), C(*)
@@ -231,7 +231,7 @@ SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD( &
231231
INTEGER(KIND=JPIM) :: LDA
232232
INTEGER(KIND=JPIB) :: OFFSETA(:)
233233
REAL(KIND=JPRD), DIMENSION(*) :: BARRAY
234-
INTEGER(KIND=JPIM) :: LDB
234+
INTEGER(KIND=JPIM) :: LDB(:)
235235
INTEGER(KIND=JPIB) :: OFFSETB(:)
236236
REAL(KIND=JPRD) :: BETA
237237
REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
@@ -277,8 +277,8 @@ SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD(&
277277
REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
278278
INTEGER(KIND=JPIM) :: LDA
279279
INTEGER(KIND=JPIB) :: OFFSETA(:)
280-
REAL(KIND=JPRM), DIMENSION(:,:,:) :: BARRAY
281-
INTEGER(KIND=JPIM) :: LDB
280+
REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
281+
INTEGER(KIND=JPIM) :: LDB(:)
282282
INTEGER(KIND=JPIB) :: OFFSETB(:)
283283
REAL(KIND=JPRM) :: BETA
284284
REAL(KIND=JPRM), DIMENSION(:) :: CARRAY

src/trans/gpu/external/setup_trans.F90

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ SUBROUTINE SETUP_TRANS(KSMAX,KDGL,KDLON,KLOEN,LDSPLIT,PSTRET,&
9898
! R. El Khatib 07-Mar-2016 Better flexibility for Legendre polynomials computation in stretched mode
9999
! ------------------------------------------------------------------
100100

101-
USE PARKIND1, ONLY: JPIM, JPRB, JPRD
101+
USE PARKIND1, ONLY: JPIM, JPRB, JPRD, JPIB
102102
USE PARKIND_ECTRANS, ONLY: JPRBT
103103

104104
!ifndef INTERFACE
@@ -166,12 +166,10 @@ SUBROUTINE SETUP_TRANS(KSMAX,KDGL,KDLON,KLOEN,LDSPLIT,PSTRET,&
166166

167167
! Local variables
168168
INTEGER(KIND=JPIM) :: JGL,JRES,IDEF_RESOL
169-
INTEGER(KIND=JPIM) :: JMLOC, KM, ILA, ILS, KMLOC, KDGLU, JK, I, J
170-
171-
INTEGER(KIND=JPIM) :: IPROC, IPROCS, ISTAN, ISTAS, ISL, IGLS, JFLD, IMLOC0(1)
169+
INTEGER(KIND=JPIM) :: JMLOC, KM, ILA, ILS, KDGLU
170+
INTEGER(KIND=JPIM) :: IMLOC0
172171

173172
LOGICAL :: LLP1,LLP2, LLSPSETUPONLY
174-
REAL(KIND=JPRD) :: ZTIME0,ZTIME1,ZTIME2
175173
REAL(KIND=JPHOOK) :: ZHOOK_HANDLE
176174

177175
CHARACTER(LEN=8) :: CENV
@@ -181,7 +179,7 @@ SUBROUTINE SETUP_TRANS(KSMAX,KDGL,KDLON,KLOEN,LDSPLIT,PSTRET,&
181179
#endif
182180
INTEGER :: INUMDEVS, IUNIT, ISTAT, IDEV, MYGPU
183181

184-
#include "user_clock.intfb.h"
182+
REAL(KIND=JPRBT), POINTER :: LOCAL_ARR(:,:)
185183
! ------------------------------------------------------------------
186184

187185
IF (LHOOK) CALL DR_HOOK('SETUP_TRANS',0,ZHOOK_HANDLE)
@@ -470,49 +468,56 @@ SUBROUTINE SETUP_TRANS(KSMAX,KDGL,KDLON,KLOEN,LDSPLIT,PSTRET,&
470468

471469
! Initialize A arrays
472470

473-
ALLOCATE(FG%ZAA(ALIGN(R%NDGNH,8),ALIGN((R%NTMAX+2)/2,8),D%NUMP))
474-
ALLOCATE(FG%ZAS(ALIGN(R%NDGNH,8),ALIGN((R%NTMAX+3)/2,8),D%NUMP))
471+
ALLOCATE(FG%ZAA(D%OFFSETS_GEMM_MATRIX(D%NUMP+1)))
472+
ALLOCATE(FG%ZAS(D%OFFSETS_GEMM_MATRIX(D%NUMP+1)))
475473

476-
FG%ZAA(:,:,:) = 0._JPRBT
477-
FG%ZAS(:,:,:) = 0._JPRBT
474+
FG%ZAA(:) = 0._JPRBT
475+
FG%ZAS(:) = 0._JPRBT
478476

477+
IMLOC0 = 0
479478
DO JMLOC=1,D%NUMP
480479
KM = D%MYMS(JMLOC)
481480
KDGLU = G%NDGLU(KM)
482481
ILA = (R%NSMAX-KM+2)/2
483482
ILS = (R%NSMAX-KM+3)/2
484483

485-
FG%ZAA(1:KDGLU,1:ILA,JMLOC)=S%FA(JMLOC)%RPNMA(1:KDGLU,1:ILA)
486-
FG%ZAS(1:KDGLU,1:ILS,JMLOC)=S%FA(JMLOC)%RPNMS(1:KDGLU,1:ILS)
487-
ENDDO
484+
IF (KM /= 0) THEN
485+
CALL C_F_POINTER(C_LOC(FG%ZAA(1+D%OFFSETS_GEMM_MATRIX(JMLOC))), LOCAL_ARR, &
486+
& (/D%LEGENDRE_MATRIX_STRIDES(JMLOC),ILA/))
487+
LOCAL_ARR(1:KDGLU,1:ILA) = S%FA(JMLOC)%RPNMA(1:KDGLU,1:ILA)
488488

489-
! arrays for m=0 in ledir_mod:
490-
IMLOC0 = FINDLOC(D%MYMS,0)
491-
IF(IMLOC0(1) > 0) THEN
492-
ALLOCATE(FG%ZAA0(SIZE(FG%ZAA,1),SIZE(FG%ZAA,2)))
493-
ALLOCATE(FG%ZAS0(SIZE(FG%ZAS,1),SIZE(FG%ZAS,2)))
494-
FG%ZAA0 = FG%ZAA(:,:,IMLOC0(1))
495-
FG%ZAS0 = FG%ZAS(:,:,IMLOC0(1))
496-
ENDIF
489+
CALL C_F_POINTER(C_LOC(FG%ZAS(1+D%OFFSETS_GEMM_MATRIX(JMLOC))), LOCAL_ARR, &
490+
& (/D%LEGENDRE_MATRIX_STRIDES(JMLOC),ILS/))
491+
LOCAL_ARR(1:KDGLU,1:ILS) = S%FA(JMLOC)%RPNMS(1:KDGLU,1:ILS)
492+
ELSE
493+
IMLOC0 = JMLOC
494+
ALLOCATE(FG%ZAA0(ALIGN(KDGLU,8),ILA))
495+
ALLOCATE(FG%ZAS0(ALIGN(KDGLU,8),ILS))
496+
497+
FG%ZAA0(:,:) = 0
498+
FG%ZAS0(:,:) = 0
499+
FG%ZAA0(1:KDGLU,1:ILA)=S%FA(JMLOC)%RPNMA(1:KDGLU,1:ILA)
500+
FG%ZAS0(1:KDGLU,1:ILS)=S%FA(JMLOC)%RPNMS(1:KDGLU,1:ILS)
501+
ENDIF
502+
ENDDO
497503

498504
ALLOCATE(FG%ZEPSNM(D%NUMP,0:R%NTMAX+2))
499505
FG%ZEPSNM = 0._JPRBT
500-
CALL PREPSNM !Initialize on the host
501-
506+
CALL PREPSNM
502507
WRITE(NOUT,*)'setup_trans: sizes1 NUMP=',D%NUMP
503508
#ifdef ACCGPU
504509
WRITE(NOUT,*) 'Using OpenACC'
505510
#endif
506511
#ifdef OMPGPU
507512
WRITE(NOUT,*) 'Using OpenMP offloading'
508513
#endif
509-
WRITE(NOUT,'(A10,":",I9,"B")') 'FG%ZAS', C_SIZEOF(FG%ZAS(1,1,1))*SIZE(FG%ZAS)
510-
WRITE(NOUT,'(A10,":",I9,"B")') 'FG%ZAA', C_SIZEOF(FG%ZAA(1,1,1))*SIZE(FG%ZAA)
511-
WRITE(NOUT,'(A10,":",I9,"B")') 'FG%ZAS0', C_SIZEOF(FG%ZAS0(1,1))*SIZE(FG%ZAS0)
512-
WRITE(NOUT,'(A10,":",I9,"B")') 'FG%ZAA0', C_SIZEOF(FG%ZAA0(1,1))*SIZE(FG%ZAA0)
513-
WRITE(NOUT,'(A10,":",I9,"B")') 'FG%ZEPSNM', C_SIZEOF(FG%ZEPSNM(1,1))*SIZE(FG%ZEPSNM)
514+
WRITE(NOUT,'(A10,":",I11,"B")') 'FG%ZAS', C_SIZEOF(FG%ZAS(1))*SIZE(FG%ZAS)
515+
WRITE(NOUT,'(A10,":",I11,"B")') 'FG%ZAA', C_SIZEOF(FG%ZAA(1))*SIZE(FG%ZAA)
516+
WRITE(NOUT,'(A10,":",I11,"B")') 'FG%ZAS0', C_SIZEOF(FG%ZAS0(1,1))*SIZE(FG%ZAS0)
517+
WRITE(NOUT,'(A10,":",I11,"B")') 'FG%ZAA0', C_SIZEOF(FG%ZAA0(1,1))*SIZE(FG%ZAA0)
518+
WRITE(NOUT,'(A10,":",I11,"B")') 'FG%ZEPSNM', C_SIZEOF(FG%ZEPSNM(1,1))*SIZE(FG%ZEPSNM)
514519

515-
IF (IMLOC0(1) > 0) THEN
520+
IF (IMLOC0 > 0) THEN
516521
#ifdef ACCGPU
517522
!$ACC ENTER DATA COPYIN(FG%ZAA0,FG%ZAS0) ASYNC(1)
518523
#endif

src/trans/gpu/internal/ledir_mod.F90

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
211211
NS(KMLOC) = (R_NSMAX-KM+2)/2
212212
KS(KMLOC) = G_NDGLU(KM)
213213
AOFFSETS(KMLOC) = IIN_STRIDES0*D_OFFSETS_GEMM1(KMLOC)
214-
BOFFSETS(KMLOC) = SIZE(ZAA,1)*SIZE(ZAA,2)*(KMLOC-1)
214+
BOFFSETS(KMLOC) = D%OFFSETS_GEMM_MATRIX(KMLOC)
215215
COFFSETS(KMLOC) = IOUT_STRIDES0*D_OFFSETS_GEMM2(KMLOC)
216216
ENDDO
217217
IF(IMLOC0(1) > 0) THEN
@@ -230,7 +230,7 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
230230
& 2*KF_FS, NS(:), KS(:), &
231231
& 1.0_JPRBT, &
232232
& ZINPA, IIN_STRIDES0, AOFFSETS, &
233-
& ZAA, SIZE(ZAA,1), BOFFSETS, &
233+
& ZAA, D%LEGENDRE_MATRIX_STRIDES, BOFFSETS, &
234234
& 0.0_JPRBT, &
235235
& ZOUT, IOUT_STRIDES0, COFFSETS, &
236236
& D_NUMP, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
@@ -331,7 +331,7 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
331331
NS(KMLOC) = (R_NSMAX-KM+3)/2
332332
KS(KMLOC) = G_NDGLU(KM)
333333
AOFFSETS(KMLOC) = IIN_STRIDES0*D_OFFSETS_GEMM1(KMLOC)
334-
BOFFSETS(KMLOC) = SIZE(ZAS,1)*SIZE(ZAS,2)*(KMLOC-1)
334+
BOFFSETS(KMLOC) = D%OFFSETS_GEMM_MATRIX(KMLOC)
335335
COFFSETS(KMLOC) = IOUT_STRIDES0*D_OFFSETS_GEMM2(KMLOC)
336336
ENDDO
337337
IF(IMLOC0(1) > 0) THEN
@@ -350,7 +350,7 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
350350
& 2*KF_FS, NS(:), KS(:), &
351351
& 1.0_JPRBT, &
352352
& ZINPS, IIN_STRIDES0, AOFFSETS, &
353-
& ZAS, SIZE(ZAS,1), BOFFSETS, &
353+
& ZAS, D%LEGENDRE_MATRIX_STRIDES, BOFFSETS, &
354354
& 0.0_JPRBT, &
355355
& ZOUT, IOUT_STRIDES0, COFFSETS, &
356356
& D_NUMP, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)

src/trans/gpu/internal/leinv_mod.F90

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ SUBROUTINE LEINV(ALLOCATOR,PIA,ZINP,ZINP0,ZOUTS,ZOUTA,ZOUTS0,ZOUTA0,KF_LEG)
271271
KS(KMLOC) = (R_NSMAX-KM+2)/2
272272
NS(KMLOC) = G_NDGLU(KM)
273273
AOFFSETS(KMLOC) = IIN_STRIDES0*D_OFFSETS_GEMM2(KMLOC)
274-
BOFFSETS(KMLOC) = SIZE(ZAA,1)*SIZE(ZAA,2)*(KMLOC-1)
274+
BOFFSETS(KMLOC) = D%OFFSETS_GEMM_MATRIX(KMLOC)
275275
COFFSETS(KMLOC) = IOUT_STRIDES0*D_OFFSETS_GEMM1(KMLOC)
276276
ENDDO
277277
IF(IMLOC0(1) > 0) THEN
@@ -290,7 +290,7 @@ SUBROUTINE LEINV(ALLOCATOR,PIA,ZINP,ZINP0,ZOUTS,ZOUTA,ZOUTS0,ZOUTA0,KF_LEG)
290290
& 2*KF_LEG, NS(:), KS(:), &
291291
& 1.0_JPRBT, &
292292
& ZINP, IIN_STRIDES0, AOFFSETS, &
293-
& ZAA, SIZE(ZAA,1), BOFFSETS, &
293+
& ZAA, D%LEGENDRE_MATRIX_STRIDES, BOFFSETS, &
294294
& 0.0_JPRBT, &
295295
& ZOUTA, IOUT_STRIDES0, COFFSETS, &
296296
& D_NUMP, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
@@ -411,7 +411,7 @@ SUBROUTINE LEINV(ALLOCATOR,PIA,ZINP,ZINP0,ZOUTS,ZOUTA,ZOUTS0,ZOUTA0,KF_LEG)
411411
KS(KMLOC) = (R_NSMAX-KM+3)/2
412412
NS(KMLOC) = G_NDGLU(KM)
413413
AOFFSETS(KMLOC) = IIN_STRIDES0*D_OFFSETS_GEMM2(KMLOC)
414-
BOFFSETS(KMLOC) = SIZE(ZAS,1)*SIZE(ZAS,2)*(KMLOC-1)
414+
BOFFSETS(KMLOC) = D%OFFSETS_GEMM_MATRIX(KMLOC)
415415
COFFSETS(KMLOC) = IOUT_STRIDES0*D_OFFSETS_GEMM1(KMLOC)
416416
ENDDO
417417
IF(IMLOC0(1) > 0) THEN
@@ -430,7 +430,7 @@ SUBROUTINE LEINV(ALLOCATOR,PIA,ZINP,ZINP0,ZOUTS,ZOUTA,ZOUTS0,ZOUTA0,KF_LEG)
430430
& 2*KF_LEG, NS(:), KS(:), &
431431
& 1.0_JPRBT, &
432432
& ZINP, IIN_STRIDES0, AOFFSETS, &
433-
& ZAS, SIZE(ZAS,1), BOFFSETS, &
433+
& ZAS, D%LEGENDRE_MATRIX_STRIDES, BOFFSETS, &
434434
& 0.0_JPRBT, &
435435
& ZOUTS, IOUT_STRIDES0, COFFSETS, &
436436
& D_NUMP, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)

0 commit comments

Comments
 (0)