Skip to content

Commit 6001dc9

Browse files
authored
Merge pull request #1839 from CEED/zach/gen-full-assembly-at-points
GPU Gen Full Assembly at Points
2 parents 0183ed6 + a34b87f commit 6001dc9

File tree

7 files changed

+545
-14
lines changed

7 files changed

+545
-14
lines changed

backends/cuda-gen/ceed-cuda-gen-operator-build.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1995,7 +1995,19 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(CeedOperator op, boo
19951995

19961996
// ---- Restriction
19971997
if (is_full) {
1998-
// TODO: UPDATE OUTPUTS FOR FULL ASSEMBLY
1998+
std::string var_suffix = "_out_" + std::to_string(i);
1999+
CeedInt comp_stride;
2000+
CeedSize l_size;
2001+
CeedElemRestriction elem_rstr;
2002+
2003+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
2004+
CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
2005+
code << tab << "const CeedInt l_size" << var_suffix << " = " << l_size << ";\n";
2006+
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
2007+
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n";
2008+
code << tab << "WriteLVecStandard" << max_dim << "d_Assembly<num_comp" << var_suffix << ", comp_stride" << var_suffix << ", P_1d" + var_suffix
2009+
<< ">(data, l_size" << var_suffix << ", elem, n, r_e" << var_suffix << ", values_array);\n";
2010+
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
19992011
} else {
20002012
std::string var_suffix = "_out_" + std::to_string(i);
20012013
CeedInt comp_stride;
@@ -2055,4 +2067,8 @@ extern "C" int CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Cuda_gen(CeedOper
20552067
return CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(op, false, is_good_build);
20562068
}
20572069

2070+
extern "C" int CeedOperatorBuildKernelFullAssemblyAtPoints_Cuda_gen(CeedOperator op, bool *is_good_build) {
2071+
return CeedOperatorBuildKernelAssemblyAtPoints_Cuda_gen(op, true, is_good_build);
2072+
}
2073+
20582074
//------------------------------------------------------------------------------

backends/cuda-gen/ceed-cuda-gen-operator.c

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen(CeedOperator o
448448

449449
CeedCallBackend(
450450
CeedTryRunKernelDimShared_Cuda(ceed, data->assemble_diagonal, NULL, grid, block[0], block[1], block[2], shared_mem, &is_run_good, opargs));
451+
CeedCallCuda(ceed, cudaDeviceSynchronize());
451452

452453
// Restore input arrays
453454
for (CeedInt i = 0; i < num_input_fields; i++) {
@@ -497,6 +498,171 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen(CeedOperator o
497498
return CEED_ERROR_SUCCESS;
498499
}
499500

501+
//------------------------------------------------------------------------------
502+
// AtPoints full assembly
503+
//------------------------------------------------------------------------------
504+
static int CeedSingleOperatorAssembleAtPoints_Cuda_gen(CeedOperator op, CeedInt offset, CeedVector assembled) {
505+
Ceed ceed;
506+
CeedOperator_Cuda_gen *data;
507+
508+
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
509+
CeedCallBackend(CeedOperatorGetData(op, &data));
510+
511+
// Build the assembly kernel
512+
if (!data->assemble_full && !data->use_assembly_fallback) {
513+
bool is_build_good = false;
514+
CeedInt num_active_bases_in, num_active_bases_out;
515+
CeedOperatorAssemblyData assembly_data;
516+
517+
CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data));
518+
CeedCallBackend(
519+
CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL, NULL, NULL));
520+
if (num_active_bases_in == num_active_bases_out) {
521+
CeedCallBackend(CeedOperatorBuildKernel_Cuda_gen(op, &is_build_good));
522+
if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelFullAssemblyAtPoints_Cuda_gen(op, &is_build_good));
523+
}
524+
if (!is_build_good) data->use_assembly_fallback = true;
525+
}
526+
527+
// Try assembly
528+
if (!data->use_assembly_fallback) {
529+
bool is_run_good = true;
530+
Ceed_Cuda *cuda_data;
531+
CeedInt num_elem, num_input_fields, num_output_fields;
532+
CeedEvalMode eval_mode;
533+
CeedScalar *assembled_array;
534+
CeedQFunctionField *qf_input_fields, *qf_output_fields;
535+
CeedQFunction_Cuda_gen *qf_data;
536+
CeedQFunction qf;
537+
CeedOperatorField *op_input_fields, *op_output_fields;
538+
539+
CeedCallBackend(CeedGetData(ceed, &cuda_data));
540+
CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
541+
CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
542+
CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
543+
CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
544+
CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
545+
546+
// Input vectors
547+
for (CeedInt i = 0; i < num_input_fields; i++) {
548+
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
549+
if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
550+
data->fields.inputs[i] = NULL;
551+
} else {
552+
bool is_active;
553+
CeedVector vec;
554+
555+
// Get input vector
556+
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
557+
is_active = vec == CEED_VECTOR_ACTIVE;
558+
if (is_active) data->fields.inputs[i] = NULL;
559+
else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
560+
CeedCallBackend(CeedVectorDestroy(&vec));
561+
}
562+
}
563+
564+
// Point coordinates
565+
{
566+
CeedVector vec;
567+
568+
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
569+
CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
570+
CeedCallBackend(CeedVectorDestroy(&vec));
571+
572+
// Points per elem
573+
if (num_elem != data->points.num_elem) {
574+
CeedInt *points_per_elem;
575+
const CeedInt num_bytes = num_elem * sizeof(CeedInt);
576+
CeedElemRestriction rstr_points = NULL;
577+
578+
data->points.num_elem = num_elem;
579+
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
580+
CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
581+
for (CeedInt e = 0; e < num_elem; e++) {
582+
CeedInt num_points_elem;
583+
584+
CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
585+
points_per_elem[e] = num_points_elem;
586+
}
587+
if (data->points.num_per_elem) CeedCallCuda(ceed, cudaFree((void **)data->points.num_per_elem));
588+
CeedCallCuda(ceed, cudaMalloc((void **)&data->points.num_per_elem, num_bytes));
589+
CeedCallCuda(ceed, cudaMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, cudaMemcpyHostToDevice));
590+
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
591+
CeedCallBackend(CeedFree(&points_per_elem));
592+
}
593+
}
594+
595+
// Get context data
596+
CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
597+
598+
// Assembly array
599+
CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array));
600+
CeedScalar *assembled_offset_array = &assembled_array[offset];
601+
602+
// Assemble diagonal
603+
void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B,
604+
&data->G, &data->W, &data->points, &assembled_offset_array};
605+
int max_threads_per_block, min_grid_size, grid;
606+
607+
CeedCallCuda(ceed, cuOccupancyMaxPotentialBlockSize(&min_grid_size, &max_threads_per_block, data->op, dynamicSMemSize, 0, 0x10000));
608+
int block[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1};
609+
610+
CeedCallBackend(BlockGridCalculate(num_elem, min_grid_size / cuda_data->device_prop.multiProcessorCount, 1,
611+
cuda_data->device_prop.maxThreadsDim[2], cuda_data->device_prop.warpSize, block, &grid));
612+
CeedInt shared_mem = block[0] * block[1] * block[2] * sizeof(CeedScalar);
613+
614+
CeedCallBackend(
615+
CeedTryRunKernelDimShared_Cuda(ceed, data->assemble_full, NULL, grid, block[0], block[1], block[2], shared_mem, &is_run_good, opargs));
616+
CeedCallCuda(ceed, cudaDeviceSynchronize());
617+
618+
// Restore input arrays
619+
for (CeedInt i = 0; i < num_input_fields; i++) {
620+
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
621+
if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
622+
} else {
623+
bool is_active;
624+
CeedVector vec;
625+
626+
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
627+
is_active = vec == CEED_VECTOR_ACTIVE;
628+
if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
629+
CeedCallBackend(CeedVectorDestroy(&vec));
630+
}
631+
}
632+
633+
// Restore point coordinates
634+
{
635+
CeedVector vec;
636+
637+
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
638+
CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
639+
CeedCallBackend(CeedVectorDestroy(&vec));
640+
}
641+
642+
// Restore context data
643+
CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
644+
645+
// Restore assembly array
646+
CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array));
647+
648+
// Cleanup
649+
CeedCallBackend(CeedQFunctionDestroy(&qf));
650+
if (!is_run_good) data->use_assembly_fallback = true;
651+
}
652+
CeedCallBackend(CeedDestroy(&ceed));
653+
654+
// Fallback, if needed
655+
if (data->use_assembly_fallback) {
656+
CeedOperator op_fallback;
657+
658+
CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/cuda/ref CeedOperator");
659+
CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
660+
CeedCallBackend(CeedSingleOperatorAssemble(op_fallback, offset, assembled));
661+
return CEED_ERROR_SUCCESS;
662+
}
663+
return CEED_ERROR_SUCCESS;
664+
}
665+
500666
//------------------------------------------------------------------------------
501667
// Create operator
502668
//------------------------------------------------------------------------------
@@ -518,6 +684,7 @@ int CeedOperatorCreate_Cuda_gen(CeedOperator op) {
518684
if (is_at_points) {
519685
CeedCallBackend(
520686
CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen));
687+
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssembleAtPoints_Cuda_gen));
521688
}
522689
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda_gen));
523690
CeedCallBackend(CeedDestroy(&ceed));

backends/hip-gen/ceed-hip-gen-operator-build.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2001,7 +2001,19 @@ static int CeedOperatorBuildKernelAssemblyAtPoints_Hip_gen(CeedOperator op, bool
20012001

20022002
// ---- Restriction
20032003
if (is_full) {
2004-
// TODO: UPDATE OUTPUTS FOR FULL ASSEMBLY
2004+
std::string var_suffix = "_out_" + std::to_string(i);
2005+
CeedInt comp_stride;
2006+
CeedSize l_size;
2007+
CeedElemRestriction elem_rstr;
2008+
2009+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
2010+
CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
2011+
code << tab << "const CeedInt l_size" << var_suffix << " = " << l_size << ";\n";
2012+
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
2013+
code << tab << "const CeedInt comp_stride" << var_suffix << " = " << comp_stride << ";\n";
2014+
code << tab << "WriteLVecStandard" << max_dim << "d_Assembly<num_comp" << var_suffix << ", comp_stride" << var_suffix << ", P_1d" + var_suffix
2015+
<< ">(data, l_size" << var_suffix << ", elem, n, r_e" << var_suffix << ", values_array);\n";
2016+
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
20052017
} else {
20062018
std::string var_suffix = "_out_" + std::to_string(i);
20072019
CeedInt comp_stride;
@@ -2067,4 +2079,8 @@ extern "C" int CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Hip_gen(CeedOpera
20672079
return CeedOperatorBuildKernelAssemblyAtPoints_Hip_gen(op, false, is_good_build);
20682080
}
20692081

2082+
extern "C" int CeedOperatorBuildKernelFullAssemblyAtPoints_Hip_gen(CeedOperator op, bool *is_good_build) {
2083+
return CeedOperatorBuildKernelAssemblyAtPoints_Hip_gen(op, true, is_good_build);
2084+
}
2085+
20702086
//------------------------------------------------------------------------------

0 commit comments

Comments
 (0)