Skip to content

Commit

Permalink
hip - nontensor gen operators
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Jan 30, 2025
1 parent dc007f0 commit 9123fb0
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 112 deletions.
6 changes: 3 additions & 3 deletions backends/cuda-gen/ceed-cuda-gen-operator-build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie

// Collect dim, P_1d, and Q_1d
CeedCallBackend(CeedBasisIsTensor(basis, &is_field_tensor));
*is_tensor = *is_tensor && is_field_tensor;
*is_tensor = *is_tensor && is_field_tensor;
if (is_field_tensor) CeedCallBackend(CeedBasisGetNumNodes1D(basis, &field_P_1d));
else CeedCallBackend(CeedBasisGetNumNodes(basis, &field_P_1d));
*max_P_1d = CeedIntMax(*max_P_1d, field_P_1d);
Expand All @@ -69,7 +69,7 @@ static int CeedOperatorBuildKernelData_Cuda_gen(Ceed ceed, CeedInt num_input_fie

// Collect dim, P_1d, and Q_1d
CeedCallBackend(CeedBasisIsTensor(basis, &is_field_tensor));
*is_tensor = *is_tensor && is_field_tensor;
*is_tensor = *is_tensor && is_field_tensor;
if (is_field_tensor) CeedCallBackend(CeedBasisGetNumNodes1D(basis, &field_P_1d));
else CeedCallBackend(CeedBasisGetNumNodes(basis, &field_P_1d));
*max_P_1d = CeedIntMax(*max_P_1d, field_P_1d);
Expand Down Expand Up @@ -1040,7 +1040,7 @@ extern "C" int CeedOperatorBuildKernel_Cuda_gen(CeedOperator op) {
code << "// d_[in,out]_i: CeedVector device array\n";
code << "// r_[in,out]_e_i: Element vector register\n";
code << "// r_[in,out]_q_i: Quadrature space vector register\n";
code << "// r_[in,out]_c_i: AtPoints Chebyshev coefficents register\n";
code << "// r_[in,out]_c_i: AtPoints Chebyshev coefficients register\n";
code << "// r_[in,out]_s_i: Quadrature space slice vector register\n";
code << "// \n";
code << "// s_B_[in,out]_i: Interpolation matrix, shared memory\n";
Expand Down
6 changes: 3 additions & 3 deletions backends/cuda-gen/ceed-cuda-gen-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ static int CeedOperatorApplyAdd_Cuda_gen(CeedOperator op, CeedVector input_vec,
Ceed basis_ceed;

CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor));
is_all_tensor &= is_tensor;
is_all_tensor &= is_tensor;
is_all_nontensor &= !is_tensor;
CeedCallBackend(CeedBasisGetCeed(basis, &basis_ceed));
CeedCallBackend(CeedGetResource(basis_ceed, &resource));
Expand All @@ -150,7 +150,7 @@ static int CeedOperatorApplyAdd_Cuda_gen(CeedOperator op, CeedVector input_vec,
Ceed basis_ceed;

CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor));
is_all_tensor &= is_tensor;
is_all_tensor &= is_tensor;
is_all_nontensor &= !is_tensor;

CeedCallBackend(CeedBasisGetCeed(basis, &basis_ceed));
Expand All @@ -166,7 +166,7 @@ static int CeedOperatorApplyAdd_Cuda_gen(CeedOperator op, CeedVector input_vec,
if (!has_shared_bases || (!is_all_tensor && !is_all_nontensor)) {
CeedOperator op_fallback;

CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/cuda/ref CeedOperator due to large non-tensor bases");
CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/cuda/ref CeedOperator due unsupported bases");
CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
return CEED_ERROR_SUCCESS;
Expand Down
Loading

0 comments on commit 9123fb0

Please sign in to comment.