@@ -448,6 +448,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen(CeedOperator o
448
448
449
449
CeedCallBackend (
450
450
CeedTryRunKernelDimShared_Cuda (ceed , data -> assemble_diagonal , NULL , grid , block [0 ], block [1 ], block [2 ], shared_mem , & is_run_good , opargs ));
451
+ CeedCallCuda (ceed , cudaDeviceSynchronize ());
451
452
452
453
// Restore input arrays
453
454
for (CeedInt i = 0 ; i < num_input_fields ; i ++ ) {
@@ -497,6 +498,171 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen(CeedOperator o
497
498
return CEED_ERROR_SUCCESS ;
498
499
}
499
500
501
+ //------------------------------------------------------------------------------
502
+ // AtPoints full assembly
503
+ //------------------------------------------------------------------------------
504
+ static int CeedSingleOperatorAssembleAtPoints_Cuda_gen (CeedOperator op , CeedInt offset , CeedVector assembled ) {
505
+ Ceed ceed ;
506
+ CeedOperator_Cuda_gen * data ;
507
+
508
+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
509
+ CeedCallBackend (CeedOperatorGetData (op , & data ));
510
+
511
+ // Build the assembly kernel
512
+ if (!data -> assemble_full && !data -> use_assembly_fallback ) {
513
+ bool is_build_good = false;
514
+ CeedInt num_active_bases_in , num_active_bases_out ;
515
+ CeedOperatorAssemblyData assembly_data ;
516
+
517
+ CeedCallBackend (CeedOperatorGetOperatorAssemblyData (op , & assembly_data ));
518
+ CeedCallBackend (
519
+ CeedOperatorAssemblyDataGetEvalModes (assembly_data , & num_active_bases_in , NULL , NULL , NULL , & num_active_bases_out , NULL , NULL , NULL , NULL ));
520
+ if (num_active_bases_in == num_active_bases_out ) {
521
+ CeedCallBackend (CeedOperatorBuildKernel_Cuda_gen (op , & is_build_good ));
522
+ if (is_build_good ) CeedCallBackend (CeedOperatorBuildKernelFullAssemblyAtPoints_Cuda_gen (op , & is_build_good ));
523
+ }
524
+ if (!is_build_good ) data -> use_assembly_fallback = true;
525
+ }
526
+
527
+ // Try assembly
528
+ if (!data -> use_assembly_fallback ) {
529
+ bool is_run_good = true;
530
+ Ceed_Cuda * cuda_data ;
531
+ CeedInt num_elem , num_input_fields , num_output_fields ;
532
+ CeedEvalMode eval_mode ;
533
+ CeedScalar * assembled_array ;
534
+ CeedQFunctionField * qf_input_fields , * qf_output_fields ;
535
+ CeedQFunction_Cuda_gen * qf_data ;
536
+ CeedQFunction qf ;
537
+ CeedOperatorField * op_input_fields , * op_output_fields ;
538
+
539
+ CeedCallBackend (CeedGetData (ceed , & cuda_data ));
540
+ CeedCallBackend (CeedOperatorGetQFunction (op , & qf ));
541
+ CeedCallBackend (CeedQFunctionGetData (qf , & qf_data ));
542
+ CeedCallBackend (CeedOperatorGetNumElements (op , & num_elem ));
543
+ CeedCallBackend (CeedOperatorGetFields (op , & num_input_fields , & op_input_fields , & num_output_fields , & op_output_fields ));
544
+ CeedCallBackend (CeedQFunctionGetFields (qf , NULL , & qf_input_fields , NULL , & qf_output_fields ));
545
+
546
+ // Input vectors
547
+ for (CeedInt i = 0 ; i < num_input_fields ; i ++ ) {
548
+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_input_fields [i ], & eval_mode ));
549
+ if (eval_mode == CEED_EVAL_WEIGHT ) { // Skip
550
+ data -> fields .inputs [i ] = NULL ;
551
+ } else {
552
+ bool is_active ;
553
+ CeedVector vec ;
554
+
555
+ // Get input vector
556
+ CeedCallBackend (CeedOperatorFieldGetVector (op_input_fields [i ], & vec ));
557
+ is_active = vec == CEED_VECTOR_ACTIVE ;
558
+ if (is_active ) data -> fields .inputs [i ] = NULL ;
559
+ else CeedCallBackend (CeedVectorGetArrayRead (vec , CEED_MEM_DEVICE , & data -> fields .inputs [i ]));
560
+ CeedCallBackend (CeedVectorDestroy (& vec ));
561
+ }
562
+ }
563
+
564
+ // Point coordinates
565
+ {
566
+ CeedVector vec ;
567
+
568
+ CeedCallBackend (CeedOperatorAtPointsGetPoints (op , NULL , & vec ));
569
+ CeedCallBackend (CeedVectorGetArrayRead (vec , CEED_MEM_DEVICE , & data -> points .coords ));
570
+ CeedCallBackend (CeedVectorDestroy (& vec ));
571
+
572
+ // Points per elem
573
+ if (num_elem != data -> points .num_elem ) {
574
+ CeedInt * points_per_elem ;
575
+ const CeedInt num_bytes = num_elem * sizeof (CeedInt );
576
+ CeedElemRestriction rstr_points = NULL ;
577
+
578
+ data -> points .num_elem = num_elem ;
579
+ CeedCallBackend (CeedOperatorAtPointsGetPoints (op , & rstr_points , NULL ));
580
+ CeedCallBackend (CeedCalloc (num_elem , & points_per_elem ));
581
+ for (CeedInt e = 0 ; e < num_elem ; e ++ ) {
582
+ CeedInt num_points_elem ;
583
+
584
+ CeedCallBackend (CeedElemRestrictionGetNumPointsInElement (rstr_points , e , & num_points_elem ));
585
+ points_per_elem [e ] = num_points_elem ;
586
+ }
587
+ if (data -> points .num_per_elem ) CeedCallCuda (ceed , cudaFree ((void * * )data -> points .num_per_elem ));
588
+ CeedCallCuda (ceed , cudaMalloc ((void * * )& data -> points .num_per_elem , num_bytes ));
589
+ CeedCallCuda (ceed , cudaMemcpy ((void * )data -> points .num_per_elem , points_per_elem , num_bytes , cudaMemcpyHostToDevice ));
590
+ CeedCallBackend (CeedElemRestrictionDestroy (& rstr_points ));
591
+ CeedCallBackend (CeedFree (& points_per_elem ));
592
+ }
593
+ }
594
+
595
+ // Get context data
596
+ CeedCallBackend (CeedQFunctionGetInnerContextData (qf , CEED_MEM_DEVICE , & qf_data -> d_c ));
597
+
598
+ // Assembly array
599
+ CeedCallBackend (CeedVectorGetArray (assembled , CEED_MEM_DEVICE , & assembled_array ));
600
+ CeedScalar * assembled_offset_array = & assembled_array [offset ];
601
+
602
+ // Assemble diagonal
603
+ void * opargs [] = {(void * )& num_elem , & qf_data -> d_c , & data -> indices , & data -> fields , & data -> B ,
604
+ & data -> G , & data -> W , & data -> points , & assembled_offset_array };
605
+ int max_threads_per_block , min_grid_size , grid ;
606
+
607
+ CeedCallCuda (ceed , cuOccupancyMaxPotentialBlockSize (& min_grid_size , & max_threads_per_block , data -> op , dynamicSMemSize , 0 , 0x10000 ));
608
+ int block [3 ] = {data -> thread_1d , (data -> dim == 1 ? 1 : data -> thread_1d ), -1 };
609
+
610
+ CeedCallBackend (BlockGridCalculate (num_elem , min_grid_size / cuda_data -> device_prop .multiProcessorCount , 1 ,
611
+ cuda_data -> device_prop .maxThreadsDim [2 ], cuda_data -> device_prop .warpSize , block , & grid ));
612
+ CeedInt shared_mem = block [0 ] * block [1 ] * block [2 ] * sizeof (CeedScalar );
613
+
614
+ CeedCallBackend (
615
+ CeedTryRunKernelDimShared_Cuda (ceed , data -> assemble_full , NULL , grid , block [0 ], block [1 ], block [2 ], shared_mem , & is_run_good , opargs ));
616
+ CeedCallCuda (ceed , cudaDeviceSynchronize ());
617
+
618
+ // Restore input arrays
619
+ for (CeedInt i = 0 ; i < num_input_fields ; i ++ ) {
620
+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_input_fields [i ], & eval_mode ));
621
+ if (eval_mode == CEED_EVAL_WEIGHT ) { // Skip
622
+ } else {
623
+ bool is_active ;
624
+ CeedVector vec ;
625
+
626
+ CeedCallBackend (CeedOperatorFieldGetVector (op_input_fields [i ], & vec ));
627
+ is_active = vec == CEED_VECTOR_ACTIVE ;
628
+ if (!is_active ) CeedCallBackend (CeedVectorRestoreArrayRead (vec , & data -> fields .inputs [i ]));
629
+ CeedCallBackend (CeedVectorDestroy (& vec ));
630
+ }
631
+ }
632
+
633
+ // Restore point coordinates
634
+ {
635
+ CeedVector vec ;
636
+
637
+ CeedCallBackend (CeedOperatorAtPointsGetPoints (op , NULL , & vec ));
638
+ CeedCallBackend (CeedVectorRestoreArrayRead (vec , & data -> points .coords ));
639
+ CeedCallBackend (CeedVectorDestroy (& vec ));
640
+ }
641
+
642
+ // Restore context data
643
+ CeedCallBackend (CeedQFunctionRestoreInnerContextData (qf , & qf_data -> d_c ));
644
+
645
+ // Restore assembly array
646
+ CeedCallBackend (CeedVectorRestoreArray (assembled , & assembled_array ));
647
+
648
+ // Cleanup
649
+ CeedCallBackend (CeedQFunctionDestroy (& qf ));
650
+ if (!is_run_good ) data -> use_assembly_fallback = true;
651
+ }
652
+ CeedCallBackend (CeedDestroy (& ceed ));
653
+
654
+ // Fallback, if needed
655
+ if (data -> use_assembly_fallback ) {
656
+ CeedOperator op_fallback ;
657
+
658
+ CeedDebug256 (CeedOperatorReturnCeed (op ), CEED_DEBUG_COLOR_SUCCESS , "Falling back to /gpu/cuda/ref CeedOperator" );
659
+ CeedCallBackend (CeedOperatorGetFallback (op , & op_fallback ));
660
+ CeedCallBackend (CeedSingleOperatorAssemble (op_fallback , offset , assembled ));
661
+ return CEED_ERROR_SUCCESS ;
662
+ }
663
+ return CEED_ERROR_SUCCESS ;
664
+ }
665
+
500
666
//------------------------------------------------------------------------------
501
667
// Create operator
502
668
//------------------------------------------------------------------------------
@@ -518,6 +684,7 @@ int CeedOperatorCreate_Cuda_gen(CeedOperator op) {
518
684
if (is_at_points ) {
519
685
CeedCallBackend (
520
686
CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleAddDiagonal" , CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen ));
687
+ CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleSingle" , CeedSingleOperatorAssembleAtPoints_Cuda_gen ));
521
688
}
522
689
CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "Destroy" , CeedOperatorDestroy_Cuda_gen ));
523
690
CeedCallBackend (CeedDestroy (& ceed ));
0 commit comments