Skip to content

Commit 0183ed6

Browse files
GPU Assembly AtPoints (#1833)
* cuda - AtPoints diagonal assembly for gen * hip - AtPoints diagonal assembly for gen * pc - use subops for LinearAssemble[Add]Diagonal if composite * gen - turn more numbers into named variables * gen - fix alignment for assembly * gen - check for only one active basis in/out * HIP gen at points syntax error fixes * hip - embarassing fix * gen - add Tab helper to manage indentation --------- Co-authored-by: Zach Atkins <[email protected]>
1 parent d6c19ee commit 0183ed6

13 files changed

+2083
-519
lines changed

backends/cuda-gen/ceed-cuda-gen-operator-build.cpp

Lines changed: 752 additions & 249 deletions
Large diffs are not rendered by default.

backends/cuda-gen/ceed-cuda-gen-operator-build.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@
77
#pragma once
88

99
CEED_INTERN int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op, bool *is_good_build);
10+
CEED_INTERN int CeedOperatorBuildKernelFullAssemblyAtPoints_Cuda_gen(CeedOperator op, bool *is_good_build);
11+
CEED_INTERN int CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Cuda_gen(CeedOperator op, bool *is_good_build);

backends/cuda-gen/ceed-cuda-gen-operator.c

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ static int CeedOperatorDestroy_Cuda_gen(CeedOperator op) {
2828
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
2929
CeedCallBackend(CeedOperatorGetData(op, &impl));
3030
if (impl->module) CeedCallCuda(ceed, cuModuleUnload(impl->module));
31+
if (impl->module_assemble_full) CeedCallCuda(ceed, cuModuleUnload(impl->module_assemble_full));
32+
if (impl->module_assemble_diagonal) CeedCallCuda(ceed, cuModuleUnload(impl->module_assemble_diagonal));
3133
if (impl->points.num_per_elem) CeedCallCuda(ceed, cudaFree((void **)impl->points.num_per_elem));
3234
CeedCallBackend(CeedFree(&impl));
3335
CeedCallBackend(CeedDestroy(&ceed));
@@ -333,11 +335,173 @@ static int CeedOperatorApplyAddComposite_Cuda_gen(CeedOperator op, CeedVector in
333335
return CEED_ERROR_SUCCESS;
334336
}
335337

338+
//------------------------------------------------------------------------------
339+
// AtPoints diagonal assembly
340+
//------------------------------------------------------------------------------
341+
static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen(CeedOperator op, CeedVector assembled, CeedRequest *request) {
342+
Ceed ceed;
343+
CeedOperator_Cuda_gen *data;
344+
345+
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
346+
CeedCallBackend(CeedOperatorGetData(op, &data));
347+
348+
// Build the assembly kernel
349+
if (!data->assemble_diagonal && !data->use_assembly_fallback) {
350+
bool is_build_good = false;
351+
CeedInt num_active_bases_in, num_active_bases_out;
352+
CeedOperatorAssemblyData assembly_data;
353+
354+
CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &assembly_data));
355+
CeedCallBackend(
356+
CeedOperatorAssemblyDataGetEvalModes(assembly_data, &num_active_bases_in, NULL, NULL, NULL, &num_active_bases_out, NULL, NULL, NULL, NULL));
357+
if (num_active_bases_in == num_active_bases_out) {
358+
CeedCallBackend(CeedOperatorBuildKernel_Cuda_gen(op, &is_build_good));
359+
if (is_build_good) CeedCallBackend(CeedOperatorBuildKernelDiagonalAssemblyAtPoints_Cuda_gen(op, &is_build_good));
360+
}
361+
if (!is_build_good) data->use_assembly_fallback = true;
362+
}
363+
364+
// Try assembly
365+
if (!data->use_assembly_fallback) {
366+
bool is_run_good = true;
367+
Ceed_Cuda *cuda_data;
368+
CeedInt num_elem, num_input_fields, num_output_fields;
369+
CeedEvalMode eval_mode;
370+
CeedScalar *assembled_array;
371+
CeedQFunctionField *qf_input_fields, *qf_output_fields;
372+
CeedQFunction_Cuda_gen *qf_data;
373+
CeedQFunction qf;
374+
CeedOperatorField *op_input_fields, *op_output_fields;
375+
376+
CeedCallBackend(CeedGetData(ceed, &cuda_data));
377+
CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
378+
CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
379+
CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
380+
CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
381+
CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
382+
383+
// Input vectors
384+
for (CeedInt i = 0; i < num_input_fields; i++) {
385+
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
386+
if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
387+
data->fields.inputs[i] = NULL;
388+
} else {
389+
bool is_active;
390+
CeedVector vec;
391+
392+
// Get input vector
393+
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
394+
is_active = vec == CEED_VECTOR_ACTIVE;
395+
if (is_active) data->fields.inputs[i] = NULL;
396+
else CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
397+
CeedCallBackend(CeedVectorDestroy(&vec));
398+
}
399+
}
400+
401+
// Point coordinates
402+
{
403+
CeedVector vec;
404+
405+
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
406+
CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
407+
CeedCallBackend(CeedVectorDestroy(&vec));
408+
409+
// Points per elem
410+
if (num_elem != data->points.num_elem) {
411+
CeedInt *points_per_elem;
412+
const CeedInt num_bytes = num_elem * sizeof(CeedInt);
413+
CeedElemRestriction rstr_points = NULL;
414+
415+
data->points.num_elem = num_elem;
416+
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
417+
CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
418+
for (CeedInt e = 0; e < num_elem; e++) {
419+
CeedInt num_points_elem;
420+
421+
CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
422+
points_per_elem[e] = num_points_elem;
423+
}
424+
if (data->points.num_per_elem) CeedCallCuda(ceed, cudaFree((void **)data->points.num_per_elem));
425+
CeedCallCuda(ceed, cudaMalloc((void **)&data->points.num_per_elem, num_bytes));
426+
CeedCallCuda(ceed, cudaMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, cudaMemcpyHostToDevice));
427+
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
428+
CeedCallBackend(CeedFree(&points_per_elem));
429+
}
430+
}
431+
432+
// Get context data
433+
CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
434+
435+
// Assembly array
436+
CeedCallBackend(CeedVectorGetArray(assembled, CEED_MEM_DEVICE, &assembled_array));
437+
438+
// Assemble diagonal
439+
void *opargs[] = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points, &assembled_array};
440+
int max_threads_per_block, min_grid_size, grid;
441+
442+
CeedCallCuda(ceed, cuOccupancyMaxPotentialBlockSize(&min_grid_size, &max_threads_per_block, data->op, dynamicSMemSize, 0, 0x10000));
443+
int block[3] = {data->thread_1d, (data->dim == 1 ? 1 : data->thread_1d), -1};
444+
445+
CeedCallBackend(BlockGridCalculate(num_elem, min_grid_size / cuda_data->device_prop.multiProcessorCount, 1,
446+
cuda_data->device_prop.maxThreadsDim[2], cuda_data->device_prop.warpSize, block, &grid));
447+
CeedInt shared_mem = block[0] * block[1] * block[2] * sizeof(CeedScalar);
448+
449+
CeedCallBackend(
450+
CeedTryRunKernelDimShared_Cuda(ceed, data->assemble_diagonal, NULL, grid, block[0], block[1], block[2], shared_mem, &is_run_good, opargs));
451+
452+
// Restore input arrays
453+
for (CeedInt i = 0; i < num_input_fields; i++) {
454+
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
455+
if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
456+
} else {
457+
bool is_active;
458+
CeedVector vec;
459+
460+
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
461+
is_active = vec == CEED_VECTOR_ACTIVE;
462+
if (!is_active) CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
463+
CeedCallBackend(CeedVectorDestroy(&vec));
464+
}
465+
}
466+
467+
// Restore point coordinates
468+
{
469+
CeedVector vec;
470+
471+
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
472+
CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
473+
CeedCallBackend(CeedVectorDestroy(&vec));
474+
}
475+
476+
// Restore context data
477+
CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
478+
479+
// Restore assembly array
480+
CeedCallBackend(CeedVectorRestoreArray(assembled, &assembled_array));
481+
482+
// Cleanup
483+
CeedCallBackend(CeedQFunctionDestroy(&qf));
484+
if (!is_run_good) data->use_assembly_fallback = true;
485+
}
486+
CeedCallBackend(CeedDestroy(&ceed));
487+
488+
// Fallback, if needed
489+
if (data->use_assembly_fallback) {
490+
CeedOperator op_fallback;
491+
492+
CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/cuda/ref CeedOperator");
493+
CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
494+
CeedCallBackend(CeedOperatorLinearAssembleAddDiagonal(op_fallback, assembled, request));
495+
return CEED_ERROR_SUCCESS;
496+
}
497+
return CEED_ERROR_SUCCESS;
498+
}
499+
336500
//------------------------------------------------------------------------------
337501
// Create operator
338502
//------------------------------------------------------------------------------
339503
int CeedOperatorCreate_Cuda_gen(CeedOperator op) {
340-
bool is_composite;
504+
bool is_composite, is_at_points;
341505
Ceed ceed;
342506
CeedOperator_Cuda_gen *impl;
343507

@@ -350,6 +514,11 @@ int CeedOperatorCreate_Cuda_gen(CeedOperator op) {
350514
} else {
351515
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda_gen));
352516
}
517+
CeedCall(CeedOperatorIsAtPoints(op, &is_at_points));
518+
if (is_at_points) {
519+
CeedCallBackend(
520+
CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda_gen));
521+
}
353522
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda_gen));
354523
CeedCallBackend(CeedDestroy(&ceed));
355524
return CEED_ERROR_SUCCESS;

backends/cuda-gen/ceed-cuda-gen.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
#include <cuda.h>
1313

1414
typedef struct {
15-
bool use_fallback;
15+
bool use_fallback, use_assembly_fallback;
1616
CeedInt dim;
1717
CeedInt Q, Q_1d;
1818
CeedInt max_P_1d;
1919
CeedInt thread_1d;
20-
CUmodule module;
21-
CUfunction op;
20+
CUmodule module, module_assemble_full, module_assemble_diagonal;
21+
CUfunction op, assemble_full, assemble_diagonal;
2222
FieldsInt_Cuda indices;
2323
Fields_Cuda fields;
2424
Fields_Cuda B;

0 commit comments

Comments
 (0)