@@ -28,6 +28,8 @@ static int CeedOperatorDestroy_Cuda_gen(CeedOperator op) {
28
28
CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
29
29
CeedCallBackend (CeedOperatorGetData (op , & impl ));
30
30
if (impl -> module ) CeedCallCuda (ceed , cuModuleUnload (impl -> module ));
31
+ if (impl -> module_assemble_full ) CeedCallCuda (ceed , cuModuleUnload (impl -> module_assemble_full ));
32
+ if (impl -> module_assemble_diagonal ) CeedCallCuda (ceed , cuModuleUnload (impl -> module_assemble_diagonal ));
31
33
if (impl -> points .num_per_elem ) CeedCallCuda (ceed , cudaFree ((void * * )impl -> points .num_per_elem ));
32
34
CeedCallBackend (CeedFree (& impl ));
33
35
CeedCallBackend (CeedDestroy (& ceed ));
@@ -333,11 +335,173 @@ static int CeedOperatorApplyAddComposite_Cuda_gen(CeedOperator op, CeedVector in
333
335
return CEED_ERROR_SUCCESS ;
334
336
}
335
337
338
+ //------------------------------------------------------------------------------
339
+ // AtPoints diagonal assembly
340
+ //------------------------------------------------------------------------------
341
+ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen (CeedOperator op , CeedVector assembled , CeedRequest * request ) {
342
+ Ceed ceed ;
343
+ CeedOperator_Cuda_gen * data ;
344
+
345
+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
346
+ CeedCallBackend (CeedOperatorGetData (op , & data ));
347
+
348
+ // Build the assembly kernel
349
+ if (!data -> assemble_diagonal && !data -> use_assembly_fallback ) {
350
+ bool is_build_good = false;
351
+ CeedInt num_active_bases_in , num_active_bases_out ;
352
+ CeedOperatorAssemblyData assembly_data ;
353
+
354
+ CeedCallBackend (CeedOperatorGetOperatorAssemblyData (op , & assembly_data ));
355
+ CeedCallBackend (
356
+ CeedOperatorAssemblyDataGetEvalModes (assembly_data , & num_active_bases_in , NULL , NULL , NULL , & num_active_bases_out , NULL , NULL , NULL , NULL ));
357
+ if (num_active_bases_in == num_active_bases_out ) {
358
+ CeedCallBackend (CeedOperatorBuildKernel_Cuda_gen (op , & is_build_good ));
359
+ if (is_build_good ) CeedCallBackend (CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Cuda_gen (op , & is_build_good ));
360
+ }
361
+ if (!is_build_good ) data -> use_assembly_fallback = true;
362
+ }
363
+
364
+ // Try assembly
365
+ if (!data -> use_assembly_fallback ) {
366
+ bool is_run_good = true;
367
+ Ceed_Cuda * cuda_data ;
368
+ CeedInt num_elem , num_input_fields , num_output_fields ;
369
+ CeedEvalMode eval_mode ;
370
+ CeedScalar * assembled_array ;
371
+ CeedQFunctionField * qf_input_fields , * qf_output_fields ;
372
+ CeedQFunction_Cuda_gen * qf_data ;
373
+ CeedQFunction qf ;
374
+ CeedOperatorField * op_input_fields , * op_output_fields ;
375
+
376
+ CeedCallBackend (CeedGetData (ceed , & cuda_data ));
377
+ CeedCallBackend (CeedOperatorGetQFunction (op , & qf ));
378
+ CeedCallBackend (CeedQFunctionGetData (qf , & qf_data ));
379
+ CeedCallBackend (CeedOperatorGetNumElements (op , & num_elem ));
380
+ CeedCallBackend (CeedOperatorGetFields (op , & num_input_fields , & op_input_fields , & num_output_fields , & op_output_fields ));
381
+ CeedCallBackend (CeedQFunctionGetFields (qf , NULL , & qf_input_fields , NULL , & qf_output_fields ));
382
+
383
+ // Input vectors
384
+ for (CeedInt i = 0 ; i < num_input_fields ; i ++ ) {
385
+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_input_fields [i ], & eval_mode ));
386
+ if (eval_mode == CEED_EVAL_WEIGHT ) { // Skip
387
+ data -> fields .inputs [i ] = NULL ;
388
+ } else {
389
+ bool is_active ;
390
+ CeedVector vec ;
391
+
392
+ // Get input vector
393
+ CeedCallBackend (CeedOperatorFieldGetVector (op_input_fields [i ], & vec ));
394
+ is_active = vec == CEED_VECTOR_ACTIVE ;
395
+ if (is_active ) data -> fields .inputs [i ] = NULL ;
396
+ else CeedCallBackend (CeedVectorGetArrayRead (vec , CEED_MEM_DEVICE , & data -> fields .inputs [i ]));
397
+ CeedCallBackend (CeedVectorDestroy (& vec ));
398
+ }
399
+ }
400
+
401
+ // Point coordinates
402
+ {
403
+ CeedVector vec ;
404
+
405
+ CeedCallBackend (CeedOperatorAtPointsGetPoints (op , NULL , & vec ));
406
+ CeedCallBackend (CeedVectorGetArrayRead (vec , CEED_MEM_DEVICE , & data -> points .coords ));
407
+ CeedCallBackend (CeedVectorDestroy (& vec ));
408
+
409
+ // Points per elem
410
+ if (num_elem != data -> points .num_elem ) {
411
+ CeedInt * points_per_elem ;
412
+ const CeedInt num_bytes = num_elem * sizeof (CeedInt );
413
+ CeedElemRestriction rstr_points = NULL ;
414
+
415
+ data -> points .num_elem = num_elem ;
416
+ CeedCallBackend (CeedOperatorAtPointsGetPoints (op , & rstr_points , NULL ));
417
+ CeedCallBackend (CeedCalloc (num_elem , & points_per_elem ));
418
+ for (CeedInt e = 0 ; e < num_elem ; e ++ ) {
419
+ CeedInt num_points_elem ;
420
+
421
+ CeedCallBackend (CeedElemRestrictionGetNumPointsInElement (rstr_points , e , & num_points_elem ));
422
+ points_per_elem [e ] = num_points_elem ;
423
+ }
424
+ if (data -> points .num_per_elem ) CeedCallCuda (ceed , cudaFree ((void * * )data -> points .num_per_elem ));
425
+ CeedCallCuda (ceed , cudaMalloc ((void * * )& data -> points .num_per_elem , num_bytes ));
426
+ CeedCallCuda (ceed , cudaMemcpy ((void * )data -> points .num_per_elem , points_per_elem , num_bytes , cudaMemcpyHostToDevice ));
427
+ CeedCallBackend (CeedElemRestrictionDestroy (& rstr_points ));
428
+ CeedCallBackend (CeedFree (& points_per_elem ));
429
+ }
430
+ }
431
+
432
+ // Get context data
433
+ CeedCallBackend (CeedQFunctionGetInnerContextData (qf , CEED_MEM_DEVICE , & qf_data -> d_c ));
434
+
435
+ // Assembly array
436
+ CeedCallBackend (CeedVectorGetArray (assembled , CEED_MEM_DEVICE , & assembled_array ));
437
+
438
+ // Assemble diagonal
439
+ void * opargs [] = {(void * )& num_elem , & qf_data -> d_c , & data -> indices , & data -> fields , & data -> B , & data -> G , & data -> W , & data -> points , & assembled_array };
440
+ int max_threads_per_block , min_grid_size , grid ;
441
+
442
+ CeedCallCuda (ceed , cuOccupancyMaxPotentialBlockSize (& min_grid_size , & max_threads_per_block , data -> op , dynamicSMemSize , 0 , 0x10000 ));
443
+ int block [3 ] = {data -> thread_1d , (data -> dim == 1 ? 1 : data -> thread_1d ), -1 };
444
+
445
+ CeedCallBackend (BlockGridCalculate (num_elem , min_grid_size / cuda_data -> device_prop .multiProcessorCount , 1 ,
446
+ cuda_data -> device_prop .maxThreadsDim [2 ], cuda_data -> device_prop .warpSize , block , & grid ));
447
+ CeedInt shared_mem = block [0 ] * block [1 ] * block [2 ] * sizeof (CeedScalar );
448
+
449
+ CeedCallBackend (
450
+ CeedTryRunKernelDimShared_Cuda (ceed , data -> assemble_diagonal , NULL , grid , block [0 ], block [1 ], block [2 ], shared_mem , & is_run_good , opargs ));
451
+
452
+ // Restore input arrays
453
+ for (CeedInt i = 0 ; i < num_input_fields ; i ++ ) {
454
+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_input_fields [i ], & eval_mode ));
455
+ if (eval_mode == CEED_EVAL_WEIGHT ) { // Skip
456
+ } else {
457
+ bool is_active ;
458
+ CeedVector vec ;
459
+
460
+ CeedCallBackend (CeedOperatorFieldGetVector (op_input_fields [i ], & vec ));
461
+ is_active = vec == CEED_VECTOR_ACTIVE ;
462
+ if (!is_active ) CeedCallBackend (CeedVectorRestoreArrayRead (vec , & data -> fields .inputs [i ]));
463
+ CeedCallBackend (CeedVectorDestroy (& vec ));
464
+ }
465
+ }
466
+
467
+ // Restore point coordinates
468
+ {
469
+ CeedVector vec ;
470
+
471
+ CeedCallBackend (CeedOperatorAtPointsGetPoints (op , NULL , & vec ));
472
+ CeedCallBackend (CeedVectorRestoreArrayRead (vec , & data -> points .coords ));
473
+ CeedCallBackend (CeedVectorDestroy (& vec ));
474
+ }
475
+
476
+ // Restore context data
477
+ CeedCallBackend (CeedQFunctionRestoreInnerContextData (qf , & qf_data -> d_c ));
478
+
479
+ // Restore assembly array
480
+ CeedCallBackend (CeedVectorRestoreArray (assembled , & assembled_array ));
481
+
482
+ // Cleanup
483
+ CeedCallBackend (CeedQFunctionDestroy (& qf ));
484
+ if (!is_run_good ) data -> use_assembly_fallback = true;
485
+ }
486
+ CeedCallBackend (CeedDestroy (& ceed ));
487
+
488
+ // Fallback, if needed
489
+ if (data -> use_assembly_fallback ) {
490
+ CeedOperator op_fallback ;
491
+
492
+ CeedDebug256 (CeedOperatorReturnCeed (op ), CEED_DEBUG_COLOR_SUCCESS , "Falling back to /gpu/cuda/ref CeedOperator" );
493
+ CeedCallBackend (CeedOperatorGetFallback (op , & op_fallback ));
494
+ CeedCallBackend (CeedOperatorLinearAssembleAddDiagonal (op_fallback , assembled , request ));
495
+ return CEED_ERROR_SUCCESS ;
496
+ }
497
+ return CEED_ERROR_SUCCESS ;
498
+ }
499
+
336
500
//------------------------------------------------------------------------------
337
501
// Create operator
338
502
//------------------------------------------------------------------------------
339
503
int CeedOperatorCreate_Cuda_gen (CeedOperator op ) {
340
- bool is_composite ;
504
+ bool is_composite , is_at_points ;
341
505
Ceed ceed ;
342
506
CeedOperator_Cuda_gen * impl ;
343
507
@@ -350,6 +514,11 @@ int CeedOperatorCreate_Cuda_gen(CeedOperator op) {
350
514
} else {
351
515
CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "ApplyAdd" , CeedOperatorApplyAdd_Cuda_gen ));
352
516
}
517
+ CeedCall (CeedOperatorIsAtPoints (op , & is_at_points ));
518
+ if (is_at_points ) {
519
+ CeedCallBackend (
520
+ CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleAddDiagonal" , CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen ));
521
+ }
353
522
CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "Destroy" , CeedOperatorDestroy_Cuda_gen ));
354
523
CeedCallBackend (CeedDestroy (& ceed ));
355
524
return CEED_ERROR_SUCCESS ;
0 commit comments