Skip to content

Commit

Permalink
Vector API compliance for CUDA backends
Browse files Browse the repository at this point in the history
  • Loading branch information
zatkins-dev committed Jan 27, 2025
1 parent 28c41a4 commit ff4acf1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 54 deletions.
58 changes: 18 additions & 40 deletions backends/cuda-ref/ceed-cuda-ref-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,12 @@ static int CeedBasisApplyCore_Cuda(CeedBasis basis, bool apply_add, const CeedIn
// Get read/write access to u, v
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes, num_qpts;
CeedSize length;

CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
if (apply_add) {
CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
} else {
// Clear v for transpose operation
if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0));
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
}
CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
CeedCallBackend(CeedBasisGetDimension(basis, &dim));
Expand Down Expand Up @@ -203,20 +195,12 @@ static int CeedBasisApplyAtPointsCore_Cuda(CeedBasis basis, bool apply_add, cons
CeedCallBackend(CeedVectorGetArrayRead(x_ref, CEED_MEM_DEVICE, &d_x));
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes;
CeedSize length;

CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
length =
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
if (apply_add) {
CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
} else {
// Clear v for transpose operation
if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0));
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
}

// Basis action
Expand Down Expand Up @@ -287,18 +271,12 @@ static int CeedBasisApplyNonTensorCore_Cuda(CeedBasis basis, bool apply_add, con
// Get read/write access to u, v
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp;
CeedSize length;

CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
if (apply_add) {
CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
} else {
// Clear v for transpose operation
if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0));
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
}

// Apply basis operation
Expand Down
20 changes: 6 additions & 14 deletions backends/cuda-shared/ceed-cuda-shared-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -297,20 +297,12 @@ static int CeedBasisApplyAtPointsCore_Cuda_shared(CeedBasis basis, bool apply_ad
CeedCallBackend(CeedVectorGetArrayRead(x_ref, CEED_MEM_DEVICE, &d_x));
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes;
CeedSize length;

CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
length =
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
if (apply_add) {
CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
} else {
// Clear v for transpose operation
if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0));
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
}

// Basis action
Expand Down

0 comments on commit ff4acf1

Please sign in to comment.