@@ -305,19 +305,21 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
305
305
// Set value for synced device/host array
306
306
if (impl -> d_array ) {
307
307
CeedScalar * copy_array ;
308
+ Ceed ceed ;
308
309
310
+ CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
309
311
CeedCallBackend (CeedVectorGetArray (vec_copy , CEED_MEM_DEVICE , & copy_array ));
310
312
#if (HIP_VERSION >= 60000000 )
311
313
hipblasHandle_t handle ;
312
- Ceed ceed ;
313
-
314
- CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
314
+ hipStream_t stream ;
315
315
CeedCallBackend (CeedGetHipblasHandle_Hip (ceed , & handle ));
316
+ CeedCallHipblas (ceed , hipblasGetStream (handle , & stream ));
316
317
#if defined(CEED_SCALAR_IS_FP32 )
317
318
CeedCallHipblas (ceed , hipblasScopy_64 (handle , (int64_t )(stop - start ), impl -> d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
318
319
#else /* CEED_SCALAR */
319
320
CeedCallHipblas (ceed , hipblasDcopy_64 (handle , (int64_t )(stop - start ), impl -> d_array + start , (int64_t )step , copy_array + start , (int64_t )step ));
320
321
#endif /* CEED_SCALAR */
322
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
321
323
#else /* HIP_VERSION */
322
324
CeedCallBackend (CeedDeviceCopyStrided_Hip (impl -> d_array , start , stop , step , copy_array ));
323
325
#endif /* HIP_VERSION */
@@ -557,14 +559,15 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
557
559
const CeedScalar * d_array ;
558
560
CeedVector_Hip * impl ;
559
561
hipblasHandle_t handle ;
562
+ hipStream_t stream ;
560
563
Ceed_Hip * hip_data ;
561
564
562
565
CeedCallBackend (CeedVectorGetCeed (vec , & ceed ));
563
566
CeedCallBackend (CeedGetData (ceed , & hip_data ));
564
567
CeedCallBackend (CeedVectorGetData (vec , & impl ));
565
568
CeedCallBackend (CeedVectorGetLength (vec , & length ));
566
569
CeedCallBackend (CeedGetHipblasHandle_Hip (ceed , & handle ));
567
-
570
+ CeedCallHipblas ( ceed , hipblasGetStream ( handle , & stream ));
568
571
#if (HIP_VERSION < 60000000 )
569
572
// With ROCm 6, we can use the 64-bit integer interface. Prior to that,
570
573
// we need to check if the vector is too long to handle with int32,
@@ -581,6 +584,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
581
584
#if defined(CEED_SCALAR_IS_FP32 )
582
585
#if (HIP_VERSION >= 60000000 ) // We have ROCm 6, and can use 64-bit integers
583
586
CeedCallHipblas (ceed , hipblasSasum_64 (handle , (int64_t )length , (float * )d_array , 1 , (float * )norm ));
587
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
584
588
#else /* HIP_VERSION */
585
589
float sub_norm = 0.0 ;
586
590
float * d_array_start ;
@@ -591,12 +595,14 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
591
595
CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
592
596
593
597
CeedCallHipblas (ceed , hipblasSasum (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & sub_norm ));
598
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
594
599
* norm += sub_norm ;
595
600
}
596
601
#endif /* HIP_VERSION */
597
602
#else /* CEED_SCALAR */
598
603
#if (HIP_VERSION >= 60000000 )
599
604
CeedCallHipblas (ceed , hipblasDasum_64 (handle , (int64_t )length , (double * )d_array , 1 , (double * )norm ));
605
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
600
606
#else /* HIP_VERSION */
601
607
double sub_norm = 0.0 ;
602
608
double * d_array_start ;
@@ -607,6 +613,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
607
613
CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
608
614
609
615
CeedCallHipblas (ceed , hipblasDasum (handle , (CeedInt )sub_length , (double * )d_array_start , 1 , & sub_norm ));
616
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
610
617
* norm += sub_norm ;
611
618
}
612
619
#endif /* HIP_VERSION */
@@ -617,6 +624,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
617
624
#if defined(CEED_SCALAR_IS_FP32 )
618
625
#if (HIP_VERSION >= 60000000 )
619
626
CeedCallHipblas (ceed , hipblasSnrm2_64 (handle , (int64_t )length , (float * )d_array , 1 , (float * )norm ));
627
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
620
628
#else /* HIP_VERSION */
621
629
float sub_norm = 0.0 , norm_sum = 0.0 ;
622
630
float * d_array_start ;
@@ -627,13 +635,15 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
627
635
CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
628
636
629
637
CeedCallHipblas (ceed , hipblasSnrm2 (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & sub_norm ));
638
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
630
639
norm_sum += sub_norm * sub_norm ;
631
640
}
632
641
* norm = sqrt (norm_sum );
633
642
#endif /* HIP_VERSION */
634
643
#else /* CEED_SCALAR */
635
644
#if (HIP_VERSION >= 60000000 )
636
645
CeedCallHipblas (ceed , hipblasDnrm2_64 (handle , (int64_t )length , (double * )d_array , 1 , (double * )norm ));
646
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
637
647
#else /* HIP_VERSION */
638
648
double sub_norm = 0.0 , norm_sum = 0.0 ;
639
649
double * d_array_start ;
@@ -644,6 +654,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
644
654
CeedInt sub_length = (i == num_calls - 1 ) ? (CeedInt )(remaining_length ) : INT_MAX ;
645
655
646
656
CeedCallHipblas (ceed , hipblasDnrm2 (handle , (CeedInt )sub_length , (double * )d_array_start , 1 , & sub_norm ));
657
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
647
658
norm_sum += sub_norm * sub_norm ;
648
659
}
649
660
* norm = sqrt (norm_sum );
@@ -658,7 +669,8 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
658
669
CeedScalar norm_no_abs ;
659
670
660
671
CeedCallHipblas (ceed , hipblasIsamax_64 (handle , (int64_t )length , (float * )d_array , 1 , & index ));
661
- CeedCallHip (ceed , hipMemcpy (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
672
+ CeedCallHip (ceed , hipMemcpyAsync (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost , stream ));
673
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
662
674
* norm = fabs (norm_no_abs );
663
675
#else /* HIP_VERSION */
664
676
CeedInt index ;
@@ -672,10 +684,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
672
684
673
685
CeedCallHipblas (ceed , hipblasIsamax (handle , (CeedInt )sub_length , (float * )d_array_start , 1 , & index ));
674
686
if (hip_data -> has_unified_addressing ) {
675
- CeedCallHip (ceed , hipDeviceSynchronize ( ));
687
+ CeedCallHip (ceed , hipStreamSynchronize ( stream ));
676
688
sub_max = fabs (d_array [index - 1 ]);
677
689
} else {
678
- CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
690
+ CeedCallHip (ceed , hipMemcpyAsync (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost , stream ));
691
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
679
692
}
680
693
if (fabs (sub_max ) > current_max ) current_max = fabs (sub_max );
681
694
}
@@ -688,10 +701,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
688
701
689
702
CeedCallHipblas (ceed , hipblasIdamax_64 (handle , (int64_t )length , (double * )d_array , 1 , & index ));
690
703
if (hip_data -> has_unified_addressing ) {
691
- CeedCallHip (ceed , hipDeviceSynchronize ( ));
704
+ CeedCallHip (ceed , hipStreamSynchronize ( stream ));
692
705
norm_no_abs = fabs (d_array [index - 1 ]);
693
706
} else {
694
- CeedCallHip (ceed , hipMemcpy (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
707
+ CeedCallHip (ceed , hipMemcpyAsync (& norm_no_abs , impl -> d_array + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost , stream ));
708
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
695
709
}
696
710
* norm = fabs (norm_no_abs );
697
711
#else /* HIP_VERSION */
@@ -706,10 +720,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
706
720
707
721
CeedCallHipblas (ceed , hipblasIdamax (handle , (CeedInt )sub_length , (double * )d_array_start , 1 , & index ));
708
722
if (hip_data -> has_unified_addressing ) {
709
- CeedCallHip (ceed , hipDeviceSynchronize ( ));
723
+ CeedCallHip (ceed , hipStreamSynchronize ( stream ));
710
724
sub_max = fabs (d_array [index - 1 ]);
711
725
} else {
712
- CeedCallHip (ceed , hipMemcpy (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost ));
726
+ CeedCallHip (ceed , hipMemcpyAsync (& sub_max , d_array_start + index - 1 , sizeof (CeedScalar ), hipMemcpyDeviceToHost , stream ));
727
+ CeedCallHip (ceed , hipStreamSynchronize (stream ));
713
728
}
714
729
if (fabs (sub_max ) > current_max ) current_max = fabs (sub_max );
715
730
}
@@ -780,13 +795,16 @@ static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
780
795
if (impl -> d_array ) {
781
796
#if (HIP_VERSION >= 60000000 )
782
797
hipblasHandle_t handle ;
798
+ hipStream_t stream ;
783
799
784
800
CeedCallBackend (CeedGetHipblasHandle_Hip (CeedVectorReturnCeed (x ), & handle ));
801
+ CeedCallHipblas (CeedVectorReturnCeed (x ), hipblasGetStream (handle , & stream ));
785
802
#if defined(CEED_SCALAR_IS_FP32 )
786
803
CeedCallHipblas (CeedVectorReturnCeed (x ), hipblasSscal_64 (handle , (int64_t )length , & alpha , impl -> d_array , 1 ));
787
804
#else /* CEED_SCALAR */
788
805
CeedCallHipblas (CeedVectorReturnCeed (x ), hipblasDscal_64 (handle , (int64_t )length , & alpha , impl -> d_array , 1 ));
789
806
#endif /* CEED_SCALAR */
807
+ CeedCallHip (CeedVectorReturnCeed (x ), hipStreamSynchronize (stream ));
790
808
#else /* HIP_VERSION */
791
809
CeedCallBackend (CeedDeviceScale_Hip (impl -> d_array , alpha , length ));
792
810
#endif /* HIP_VERSION */
@@ -827,13 +845,16 @@ static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
827
845
CeedCallBackend (CeedVectorSyncArray (x , CEED_MEM_DEVICE ));
828
846
#if (HIP_VERSION >= 60000000 )
829
847
hipblasHandle_t handle ;
848
+ hipStream_t stream ;
830
849
831
- CeedCallBackend (CeedGetHipblasHandle_Hip (CeedVectorReturnCeed (y ), & handle ));
850
+ CeedCallBackend (CeedGetHipblasHandle_Hip (CeedVectorReturnCeed (x ), & handle ));
851
+ CeedCallHipblas (CeedVectorReturnCeed (y ), hipblasGetStream (handle , & stream ));
832
852
#if defined(CEED_SCALAR_IS_FP32 )
833
853
CeedCallHipblas (CeedVectorReturnCeed (y ), hipblasSaxpy_64 (handle , (int64_t )length , & alpha , x_impl -> d_array , 1 , y_impl -> d_array , 1 ));
834
854
#else /* CEED_SCALAR */
835
855
CeedCallHipblas (CeedVectorReturnCeed (y ), hipblasDaxpy_64 (handle , (int64_t )length , & alpha , x_impl -> d_array , 1 , y_impl -> d_array , 1 ));
836
856
#endif /* CEED_SCALAR */
857
+ CeedCallHip (CeedVectorReturnCeed (y ), hipStreamSynchronize (stream ));
837
858
#else /* HIP_VERSION */
838
859
CeedCallBackend (CeedDeviceAXPY_Hip (y_impl -> d_array , alpha , x_impl -> d_array , length ));
839
860
#endif /* HIP_VERSION */
0 commit comments