From ff4acf143d4c458a12847fadc627afb733f0bc3c Mon Sep 17 00:00:00 2001 From: Zach Atkins Date: Mon, 27 Jan 2025 09:57:50 -0700 Subject: [PATCH] Vector API compliance for CUDA backends --- backends/cuda-ref/ceed-cuda-ref-basis.c | 58 ++++++------------- backends/cuda-shared/ceed-cuda-shared-basis.c | 20 ++----- 2 files changed, 24 insertions(+), 54 deletions(-) diff --git a/backends/cuda-ref/ceed-cuda-ref-basis.c b/backends/cuda-ref/ceed-cuda-ref-basis.c index 544a5cb188..b21466f33c 100644 --- a/backends/cuda-ref/ceed-cuda-ref-basis.c +++ b/backends/cuda-ref/ceed-cuda-ref-basis.c @@ -35,20 +35,12 @@ static int CeedBasisApplyCore_Cuda(CeedBasis basis, bool apply_add, const CeedIn // Get read/write access to u, v if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode"); - if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); - else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); - - // Clear v for transpose operation - if (is_transpose && !apply_add) { - CeedInt num_comp, q_comp, num_nodes, num_qpts; - CeedSize length; - - CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); - CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); - CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); - CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts)); - length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp)); - CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar))); + if (apply_add) { + CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); + } else { + // Clear v for transpose operation + if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0)); + CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); } CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d)); CeedCallBackend(CeedBasisGetDimension(basis, &dim)); @@ -203,20 +195,12 @@ static int CeedBasisApplyAtPointsCore_Cuda(CeedBasis basis, bool apply_add, cons CeedCallBackend(CeedVectorGetArrayRead(x_ref, CEED_MEM_DEVICE, &d_x)); if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode"); - if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); - else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); - - // Clear v for transpose operation - if (is_transpose && !apply_add) { - CeedInt num_comp, q_comp, num_nodes; - CeedSize length; - - CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); - CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); - CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); - length = - (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp)); - CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar))); + if (apply_add) { + CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); + } else { + // Clear v for transpose operation + if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0)); + CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); } // Basis action @@ -287,18 +271,12 @@ static int CeedBasisApplyNonTensorCore_Cuda(CeedBasis basis, bool apply_add, con // Get read/write access to u, v if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode"); - if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); - else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); - - // Clear v for transpose operation - if (is_transpose && !apply_add) { - CeedInt num_comp, q_comp; - CeedSize length; - - CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); - CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); - length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp)); - CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar))); + if (apply_add) { + CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); + } else { + // Clear v for transpose operation + if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0)); + CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); } // Apply basis operation diff --git a/backends/cuda-shared/ceed-cuda-shared-basis.c b/backends/cuda-shared/ceed-cuda-shared-basis.c index f3a73b54ba..06fd102f11 100644 --- a/backends/cuda-shared/ceed-cuda-shared-basis.c +++ b/backends/cuda-shared/ceed-cuda-shared-basis.c @@ -297,20 +297,12 @@ static int CeedBasisApplyAtPointsCore_Cuda_shared(CeedBasis basis, bool apply_ad CeedCallBackend(CeedVectorGetArrayRead(x_ref, CEED_MEM_DEVICE, &d_x)); if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode"); - if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); - else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); - - // Clear v for transpose operation - if (is_transpose && !apply_add) { - CeedInt num_comp, q_comp, num_nodes; - CeedSize length; - - CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); - CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); - CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); - length = - (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp)); - CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar))); + if (apply_add) { + CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); + } else { + // Clear v for transpose operation + if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0)); + CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); } // Basis action