Skip to content

Commit 297af36

Browse files
authored
Merge pull request #1819 from CEED/zach/mi300a-fixes
Reworks the stream implementation for `/gpu/hip/gen` to avoid creating and destroying streams on every operator apply. Updates `hipblas` calls to synchronize only the `hipblas` stream instead of performing a full device sync; this matters on MI300A, where `hipblas` seems to use an async stream. Also makes work vectors come from the `Vector` object delegate to avoid bad ref behavior.
2 parents 7b3ff06 + b46df0d commit 297af36

File tree

6 files changed

+110
-26
lines changed

6 files changed

+110
-26
lines changed

backends/hip-gen/ceed-hip-gen-operator.c

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,20 @@
2222
static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
2323
Ceed ceed;
2424
CeedOperator_Hip_gen *impl;
25+
bool is_composite;
2526

2627
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
2728
CeedCallBackend(CeedOperatorGetData(op, &impl));
29+
CeedCallBackend(CeedOperatorIsComposite(op, &is_composite));
30+
if (is_composite) {
31+
CeedInt num_suboperators;
32+
33+
CeedCall(CeedCompositeOperatorGetNumSub(op, &num_suboperators));
34+
for (CeedInt i = 0; i < num_suboperators; i++) {
35+
if (impl->streams[i]) CeedCallHip(ceed, hipStreamDestroy(impl->streams[i]));
36+
impl->streams[i] = NULL;
37+
}
38+
}
2839
if (impl->module) CeedCallHip(ceed, hipModuleUnload(impl->module));
2940
if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem));
3041
CeedCallBackend(CeedFree(&impl));
@@ -239,28 +250,35 @@ static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, C
239250
}
240251

241252
static int CeedOperatorApplyAddComposite_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
242-
bool is_run_good[CEED_COMPOSITE_MAX] = {false};
243-
CeedInt num_suboperators;
244-
const CeedScalar *input_arr = NULL;
245-
CeedScalar *output_arr = NULL;
246-
Ceed ceed;
247-
CeedOperator *sub_operators;
253+
bool is_run_good[CEED_COMPOSITE_MAX] = {true};
254+
CeedInt num_suboperators;
255+
const CeedScalar *input_arr = NULL;
256+
CeedScalar *output_arr;
257+
Ceed ceed;
258+
CeedOperator_Hip_gen *impl;
259+
CeedOperator *sub_operators;
248260

249261
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
250-
CeedCall(CeedCompositeOperatorGetNumSub(op, &num_suboperators));
251-
CeedCall(CeedCompositeOperatorGetSubList(op, &sub_operators));
262+
CeedCallBackend(CeedOperatorGetData(op, &impl));
263+
CeedCallBackend(CeedCompositeOperatorGetNumSub(op, &num_suboperators));
264+
CeedCallBackend(CeedCompositeOperatorGetSubList(op, &sub_operators));
252265
if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(input_vec, CEED_MEM_DEVICE, &input_arr));
253266
if (output_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArray(output_vec, CEED_MEM_DEVICE, &output_arr));
254267
for (CeedInt i = 0; i < num_suboperators; i++) {
255268
CeedInt num_elem = 0;
256269

257-
CeedCall(CeedOperatorGetNumElements(sub_operators[i], &num_elem));
270+
CeedCallBackend(CeedOperatorGetNumElements(sub_operators[i], &num_elem));
258271
if (num_elem > 0) {
259-
hipStream_t stream = NULL;
272+
if (!impl->streams[i]) CeedCallHip(ceed, hipStreamCreate(&impl->streams[i]));
273+
CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], impl->streams[i], input_arr, output_arr, &is_run_good[i], request));
274+
} else {
275+
is_run_good[i] = true;
276+
}
277+
}
260278

261-
CeedCallHip(ceed, hipStreamCreate(&stream));
262-
CeedCallBackend(CeedOperatorApplyAddCore_Hip_gen(sub_operators[i], stream, input_arr, output_arr, &is_run_good[i], request));
263-
CeedCallHip(ceed, hipStreamDestroy(stream));
279+
for (CeedInt i = 0; i < num_suboperators; i++) {
280+
if (impl->streams[i]) {
281+
if (is_run_good[i]) CeedCallHip(ceed, hipStreamSynchronize(impl->streams[i]));
264282
}
265283
}
266284
if (input_vec != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorRestoreArrayRead(input_vec, &input_arr));

backends/hip-gen/ceed-hip-gen.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ typedef struct {
1717
CeedInt Q, Q_1d;
1818
CeedInt max_P_1d;
1919
CeedInt thread_1d;
20+
hipStream_t streams[CEED_COMPOSITE_MAX];
2021
hipModule_t module;
2122
hipFunction_t op;
2223
FieldsInt_Hip indices;

backends/hip-ref/ceed-hip-ref-vector.c

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -305,19 +305,21 @@ static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize st
305305
// Set value for synced device/host array
306306
if (impl->d_array) {
307307
CeedScalar *copy_array;
308+
Ceed ceed;
308309

310+
CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
309311
CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
310312
#if (HIP_VERSION >= 60000000)
311313
hipblasHandle_t handle;
312-
Ceed ceed;
313-
314-
CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
314+
hipStream_t stream;
315315
CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
316+
CeedCallHipblas(ceed, hipblasGetStream(handle, &stream));
316317
#if defined(CEED_SCALAR_IS_FP32)
317318
CeedCallHipblas(ceed, hipblasScopy_64(handle, (int64_t)(stop - start), impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
318319
#else /* CEED_SCALAR */
319320
CeedCallHipblas(ceed, hipblasDcopy_64(handle, (int64_t)(stop - start), impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
320321
#endif /* CEED_SCALAR */
322+
CeedCallHip(ceed, hipStreamSynchronize(stream));
321323
#else /* HIP_VERSION */
322324
CeedCallBackend(CeedDeviceCopyStrided_Hip(impl->d_array, start, stop, step, copy_array));
323325
#endif /* HIP_VERSION */
@@ -557,14 +559,15 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
557559
const CeedScalar *d_array;
558560
CeedVector_Hip *impl;
559561
hipblasHandle_t handle;
562+
hipStream_t stream;
560563
Ceed_Hip *hip_data;
561564

562565
CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
563566
CeedCallBackend(CeedGetData(ceed, &hip_data));
564567
CeedCallBackend(CeedVectorGetData(vec, &impl));
565568
CeedCallBackend(CeedVectorGetLength(vec, &length));
566569
CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
567-
570+
CeedCallHipblas(ceed, hipblasGetStream(handle, &stream));
568571
#if (HIP_VERSION < 60000000)
569572
// With ROCm 6, we can use the 64-bit integer interface. Prior to that,
570573
// we need to check if the vector is too long to handle with int32,
@@ -581,6 +584,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
581584
#if defined(CEED_SCALAR_IS_FP32)
582585
#if (HIP_VERSION >= 60000000) // We have ROCm 6, and can use 64-bit integers
583586
CeedCallHipblas(ceed, hipblasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
587+
CeedCallHip(ceed, hipStreamSynchronize(stream));
584588
#else /* HIP_VERSION */
585589
float sub_norm = 0.0;
586590
float *d_array_start;
@@ -591,12 +595,14 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
591595
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
592596

593597
CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
598+
CeedCallHip(ceed, hipStreamSynchronize(stream));
594599
*norm += sub_norm;
595600
}
596601
#endif /* HIP_VERSION */
597602
#else /* CEED_SCALAR */
598603
#if (HIP_VERSION >= 60000000)
599604
CeedCallHipblas(ceed, hipblasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
605+
CeedCallHip(ceed, hipStreamSynchronize(stream));
600606
#else /* HIP_VERSION */
601607
double sub_norm = 0.0;
602608
double *d_array_start;
@@ -607,6 +613,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
607613
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
608614

609615
CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
616+
CeedCallHip(ceed, hipStreamSynchronize(stream));
610617
*norm += sub_norm;
611618
}
612619
#endif /* HIP_VERSION */
@@ -617,6 +624,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
617624
#if defined(CEED_SCALAR_IS_FP32)
618625
#if (HIP_VERSION >= 60000000)
619626
CeedCallHipblas(ceed, hipblasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
627+
CeedCallHip(ceed, hipStreamSynchronize(stream));
620628
#else /* HIP_VERSION */
621629
float sub_norm = 0.0, norm_sum = 0.0;
622630
float *d_array_start;
@@ -627,13 +635,15 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
627635
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
628636

629637
CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
638+
CeedCallHip(ceed, hipStreamSynchronize(stream));
630639
norm_sum += sub_norm * sub_norm;
631640
}
632641
*norm = sqrt(norm_sum);
633642
#endif /* HIP_VERSION */
634643
#else /* CEED_SCALAR */
635644
#if (HIP_VERSION >= 60000000)
636645
CeedCallHipblas(ceed, hipblasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
646+
CeedCallHip(ceed, hipStreamSynchronize(stream));
637647
#else /* HIP_VERSION */
638648
double sub_norm = 0.0, norm_sum = 0.0;
639649
double *d_array_start;
@@ -644,6 +654,7 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
644654
CeedInt sub_length = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
645655

646656
CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
657+
CeedCallHip(ceed, hipStreamSynchronize(stream));
647658
norm_sum += sub_norm * sub_norm;
648659
}
649660
*norm = sqrt(norm_sum);
@@ -658,7 +669,8 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
658669
CeedScalar norm_no_abs;
659670

660671
CeedCallHipblas(ceed, hipblasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
661-
CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
672+
CeedCallHip(ceed, hipMemcpyAsync(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
673+
CeedCallHip(ceed, hipStreamSynchronize(stream));
662674
*norm = fabs(norm_no_abs);
663675
#else /* HIP_VERSION */
664676
CeedInt index;
@@ -672,10 +684,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
672684

673685
CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
674686
if (hip_data->has_unified_addressing) {
675-
CeedCallHip(ceed, hipDeviceSynchronize());
687+
CeedCallHip(ceed, hipStreamSynchronize(stream));
676688
sub_max = fabs(d_array[index - 1]);
677689
} else {
678-
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
690+
CeedCallHip(ceed, hipMemcpyAsync(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
691+
CeedCallHip(ceed, hipStreamSynchronize(stream));
679692
}
680693
if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
681694
}
@@ -688,10 +701,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
688701

689702
CeedCallHipblas(ceed, hipblasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
690703
if (hip_data->has_unified_addressing) {
691-
CeedCallHip(ceed, hipDeviceSynchronize());
704+
CeedCallHip(ceed, hipStreamSynchronize(stream));
692705
norm_no_abs = fabs(d_array[index - 1]);
693706
} else {
694-
CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
707+
CeedCallHip(ceed, hipMemcpyAsync(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
708+
CeedCallHip(ceed, hipStreamSynchronize(stream));
695709
}
696710
*norm = fabs(norm_no_abs);
697711
#else /* HIP_VERSION */
@@ -706,10 +720,11 @@ static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *nor
706720

707721
CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
708722
if (hip_data->has_unified_addressing) {
709-
CeedCallHip(ceed, hipDeviceSynchronize());
723+
CeedCallHip(ceed, hipStreamSynchronize(stream));
710724
sub_max = fabs(d_array[index - 1]);
711725
} else {
712-
CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
726+
CeedCallHip(ceed, hipMemcpyAsync(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost, stream));
727+
CeedCallHip(ceed, hipStreamSynchronize(stream));
713728
}
714729
if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
715730
}
@@ -780,13 +795,16 @@ static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
780795
if (impl->d_array) {
781796
#if (HIP_VERSION >= 60000000)
782797
hipblasHandle_t handle;
798+
hipStream_t stream;
783799

784800
CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(x), &handle));
801+
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasGetStream(handle, &stream));
785802
#if defined(CEED_SCALAR_IS_FP32)
786803
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasSscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
787804
#else /* CEED_SCALAR */
788805
CeedCallHipblas(CeedVectorReturnCeed(x), hipblasDscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
789806
#endif /* CEED_SCALAR */
807+
CeedCallHip(CeedVectorReturnCeed(x), hipStreamSynchronize(stream));
790808
#else /* HIP_VERSION */
791809
CeedCallBackend(CeedDeviceScale_Hip(impl->d_array, alpha, length));
792810
#endif /* HIP_VERSION */
@@ -827,13 +845,16 @@ static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
827845
CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
828846
#if (HIP_VERSION >= 60000000)
829847
hipblasHandle_t handle;
848+
hipStream_t stream;
830849

831-
CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(y), &handle));
850+
CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(x), &handle));
851+
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasGetStream(handle, &stream));
832852
#if defined(CEED_SCALAR_IS_FP32)
833853
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasSaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
834854
#else /* CEED_SCALAR */
835855
CeedCallHipblas(CeedVectorReturnCeed(y), hipblasDaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
836856
#endif /* CEED_SCALAR */
857+
CeedCallHip(CeedVectorReturnCeed(y), hipStreamSynchronize(stream));
837858
#else /* HIP_VERSION */
838859
CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
839860
#endif /* HIP_VERSION */

backends/hip-ref/ceed-hip-ref.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ int CeedGetHipblasHandle_Hip(Ceed ceed, hipblasHandle_t *handle) {
2929
Ceed_Hip *data;
3030

3131
CeedCallBackend(CeedGetData(ceed, &data));
32-
if (!data->hipblas_handle) CeedCallHipblas(ceed, hipblasCreate(&data->hipblas_handle));
32+
if (!data->hipblas_handle) {
33+
CeedCallHipblas(ceed, hipblasCreate(&data->hipblas_handle));
34+
CeedCallHipblas(ceed, hipblasSetPointerMode(data->hipblas_handle, HIPBLAS_POINTER_MODE_HOST));
35+
}
3336
*handle = data->hipblas_handle;
3437
return CEED_ERROR_SUCCESS;
3538
}

backends/hip-shared/ceed-hip-shared-basis.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ static int CeedBasisApplyAtPointsCore_Hip_shared(CeedBasis basis, bool apply_add
489489
CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
490490
if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
491491
if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
492+
CeedCallBackend(CeedDestroy(&ceed));
492493
return CEED_ERROR_SUCCESS;
493494
}
494495

@@ -644,6 +645,7 @@ static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
644645
CeedCallHip(ceed, hipFree(data->d_collo_grad_1d));
645646
CeedCallHip(ceed, hipFree(data->d_chebyshev_interp_1d));
646647
CeedCallBackend(CeedFree(&data));
648+
CeedCallBackend(CeedDestroy(&ceed));
647649
return CEED_ERROR_SUCCESS;
648650
}
649651

@@ -737,6 +739,7 @@ int CeedBasisCreateH1_Hip_shared(CeedElemTopology topo, CeedInt dim, CeedInt num
737739
if (((size_t)num_nodes * (size_t)num_qpts * (size_t)dim + (size_t)CeedIntMax(num_nodes, num_qpts)) * sizeof(CeedScalar) >
738740
hip_data->device_prop.sharedMemPerBlock) {
739741
CeedCallBackend(CeedBasisCreateH1Fallback(ceed, topo, dim, num_nodes, num_qpts, interp, grad, q_ref, q_weight, basis));
742+
CeedCallBackend(CeedDestroy(&ceed));
740743
return CEED_ERROR_SUCCESS;
741744
}
742745
}

interface/ceed.c

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,15 @@ int CeedReference(Ceed ceed) {
827827
@ref Developer
828828
**/
829829
int CeedGetWorkVectorMemoryUsage(Ceed ceed, CeedScalar *usage_mb) {
830+
if (!ceed->VectorCreate) {
831+
Ceed delegate;
832+
833+
CeedCall(CeedGetObjectDelegate(ceed, &delegate, "Vector"));
834+
CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement VectorCreate");
835+
CeedCall(CeedGetWorkVectorMemoryUsage(delegate, usage_mb));
836+
CeedCall(CeedDestroy(&delegate));
837+
return CEED_ERROR_SUCCESS;
838+
}
830839
*usage_mb = 0.0;
831840
if (ceed->work_vectors) {
832841
for (CeedInt i = 0; i < ceed->work_vectors->num_vecs; i++) {
@@ -852,6 +861,15 @@ int CeedGetWorkVectorMemoryUsage(Ceed ceed, CeedScalar *usage_mb) {
852861
@ref Backend
853862
**/
854863
int CeedClearWorkVectors(Ceed ceed, CeedSize min_len) {
864+
if (!ceed->VectorCreate) {
865+
Ceed delegate;
866+
867+
CeedCall(CeedGetObjectDelegate(ceed, &delegate, "Vector"));
868+
CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement VectorCreate");
869+
CeedCall(CeedClearWorkVectors(delegate, min_len));
870+
CeedCall(CeedDestroy(&delegate));
871+
return CEED_ERROR_SUCCESS;
872+
}
855873
if (!ceed->work_vectors) return CEED_ERROR_SUCCESS;
856874
for (CeedInt i = 0; i < ceed->work_vectors->num_vecs; i++) {
857875
if (ceed->work_vectors->is_in_use[i]) continue;
@@ -890,6 +908,16 @@ int CeedGetWorkVector(Ceed ceed, CeedSize len, CeedVector *vec) {
890908
CeedInt i = 0;
891909
CeedScalar usage_mb;
892910

911+
if (!ceed->VectorCreate) {
912+
Ceed delegate;
913+
914+
CeedCall(CeedGetObjectDelegate(ceed, &delegate, "Vector"));
915+
CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement VectorCreate");
916+
CeedCall(CeedGetWorkVector(delegate, len, vec));
917+
CeedCall(CeedDestroy(&delegate));
918+
return CEED_ERROR_SUCCESS;
919+
}
920+
893921
if (!ceed->work_vectors) CeedCall(CeedWorkVectorsCreate(ceed));
894922

895923
// Search for big enough work vector
@@ -936,6 +964,16 @@ int CeedGetWorkVector(Ceed ceed, CeedSize len, CeedVector *vec) {
936964
@ref Backend
937965
**/
938966
int CeedRestoreWorkVector(Ceed ceed, CeedVector *vec) {
967+
if (!ceed->VectorCreate) {
968+
Ceed delegate;
969+
970+
CeedCall(CeedGetObjectDelegate(ceed, &delegate, "Vector"));
971+
CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement VectorCreate");
972+
CeedCall(CeedRestoreWorkVector(delegate, vec));
973+
CeedCall(CeedDestroy(&delegate));
974+
return CEED_ERROR_SUCCESS;
975+
}
976+
939977
for (CeedInt i = 0; i < ceed->work_vectors->num_vecs; i++) {
940978
if (*vec == ceed->work_vectors->vecs[i]) {
941979
CeedCheck(ceed->work_vectors->is_in_use[i], ceed, CEED_ERROR_ACCESS, "Work vector %" CeedSize_FMT " was not checked out but is being returned");

0 commit comments

Comments
 (0)