Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Dec 11, 2024
1 parent 16898a3 commit 368f613
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 10 deletions.
57 changes: 47 additions & 10 deletions backends/cuda-gen/ceed-cuda-gen-operator-build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ static int CeedOperatorBuildKernelFieldData_Cuda_gen(std::ostringstream &code, C
code << " const CeedInt " << P_name << " = " << (basis == CEED_BASIS_NONE ? Q_1d : P_1d) << ";\n";
code << " const CeedInt num_comp" << var_suffix << " = " << num_comp << ";\n";
}
CeedCallBackend(CeedBasisDestroy(&basis));

// Load basis data
code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
Expand Down Expand Up @@ -240,6 +239,7 @@ static int CeedOperatorBuildKernelFieldData_Cuda_gen(std::ostringstream &code, C
break; // TODO: Not implemented
// LCOV_EXCL_STOP
}
CeedCallBackend(CeedBasisDestroy(&basis));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -319,10 +319,21 @@ static int CeedOperatorBuildKernelRestriction_Cuda_gen(std::ostringstream &code,
<< strides[2] << ">(data, elem, d" << var_suffix << ", r_e" << var_suffix << ");\n";
break;
}
case CEED_RESTRICTION_POINTS: {
CeedInt comp_stride;

CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
code << " const CeedInt l_size" << var_suffix << " = " << l_size << ";\n";
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
code << " // CompStride: " << comp_stride << "\n";
data->indices.inputs[i] = (CeedInt *)rstr_data->d_offsets;
code << " ReadLVecStandard" << dim << "d<num_comp" << var_suffix << ", " << comp_stride << ", " << P_name << ">(data, l_size"
<< var_suffix << ", elem, indices.inputs[" << i << "], d" << var_suffix << ", r_e" << var_suffix << ");\n";
break;
}
// LCOV_EXCL_START
case CEED_RESTRICTION_ORIENTED:
case CEED_RESTRICTION_CURL_ORIENTED:
case CEED_RESTRICTION_POINTS:
break; // TODO: Not implemented
// LCOV_EXCL_STOP
}
Expand Down Expand Up @@ -358,10 +369,21 @@ static int CeedOperatorBuildKernelRestriction_Cuda_gen(std::ostringstream &code,
<< strides[2] << ">(data, elem, r_e" << var_suffix << ", d" << var_suffix << ");\n";
break;
}
case CEED_RESTRICTION_POINTS: {
CeedInt comp_stride;

CeedCallBackend(CeedElemRestrictionGetLVectorSize(elem_rstr, &l_size));
code << " const CeedInt l_size" << var_suffix << " = " << l_size << ";\n";
CeedCallBackend(CeedElemRestrictionGetCompStride(elem_rstr, &comp_stride));
code << " // CompStride: " << comp_stride << "\n";
data->indices.outputs[i] = (CeedInt *)rstr_data->d_offsets;
code << " WriteLVecAtPoints" << dim << "d<num_comp" << var_suffix << ", " << comp_stride << ", " << P_name << ">(data, l_size" << var_suffix
<< ", elem, indices.outputs[" << i << "], points.num_per_elem, r_e" << var_suffix << ", d" << var_suffix << ");\n";
break;
}
// LCOV_EXCL_START
case CEED_RESTRICTION_ORIENTED:
case CEED_RESTRICTION_CURL_ORIENTED:
case CEED_RESTRICTION_POINTS:
break; // TODO: Not implemented
// LCOV_EXCL_STOP
}
Expand Down Expand Up @@ -406,20 +428,35 @@ static int CeedOperatorBuildKernelBasis_Cuda_gen(std::ostringstream &code, CeedO
}
break;
case CEED_EVAL_INTERP:
code << " CeedScalar r_q" << var_suffix << "[num_comp" << var_suffix << "*" << Q_name << "];\n";
code << " Interp" << (dim > 1 ? "Tensor" : "") << dim << "d<num_comp" << var_suffix << ", P_1d" << var_suffix << ", " << Q_name
<< ">(data, r_e" << var_suffix << ", s_B" << var_suffix << ", r_q" << var_suffix << ");\n";
if (is_at_points) {
code << " CeedScalar r_q" << var_suffix << "[num_comp" << var_suffix << "*" << Q_name << "];\n";
code << " Interp" << (dim > 1 ? "Tensor" : "") << dim << "d<num_comp" << var_suffix << ", P_1d" << var_suffix << ", " << Q_name
<< ">(data, r_e" << var_suffix << ", s_B" << var_suffix << ", r_q" << var_suffix << ");\n";
} else {
code << " CeedScalar r_q" << var_suffix << "[num_comp" << var_suffix << "*max_num_points];\n";
code << " Interp" << (dim > 1 ? "Tensor" : "") << dim << "d<num_comp" << var_suffix << ", P_1d" << var_suffix << ", " << Q_name
<< ">(data, r_e" << var_suffix << ", s_B" << var_suffix << ", r_q" << var_suffix << ");\n";

code << " InterpAtPoints<" << dim << ", num_comp" << var_suffix << ", max_num_points, " << Q_name
<< ">(data, r_e" << var_suffix << ", s_B" << var_suffix << ", r_q" << var_suffix << ");\n";
}
break;
case CEED_EVAL_GRAD:
if (use_3d_slices) {
code << " CeedScalar r_q" << var_suffix << "[num_comp" << var_suffix << "*" << Q_name << "];\n";
code << " Interp" << (dim > 1 ? "Tensor" : "") << dim << "d<num_comp" << var_suffix << ", P_1d" << var_suffix << ", " << Q_name
<< ">(data, r_e" << var_suffix << ", s_B" << var_suffix << ", r_q" << var_suffix << ");\n";
} else {
code << " CeedScalar r_q" << var_suffix << "[num_comp" << var_suffix << "*dim*" << Q_name << "];\n";
code << " Grad" << (dim > 1 ? "Tensor" : "") << (dim == 3 && Q_1d >= P_1d ? "Collocated" : "") << dim << "d<num_comp" << var_suffix
<< ", P_1d" << var_suffix << ", " << Q_name << ">(data, r_e" << var_suffix << ", s_B" << var_suffix << ", s_G" << var_suffix << ", r_q"
<< var_suffix << ");\n";
if (is_at_points) {
} else {
code << " CeedScalar r_q" << var_suffix << "[num_comp" << var_suffix << "*dim*max_num_points];\n";
code << " Interp" << (dim > 1 ? "Tensor" : "") << dim << "d<num_comp" << var_suffix << ", P_1d" << var_suffix << ", " << Q_name
<< ">(data, r_e" << var_suffix << ", s_B" << var_suffix << ", r_q" << var_suffix << ");\n";

code << " Grad" << (dim > 1 ? "Tensor" : "") << (dim == 3 && Q_1d >= P_1d ? "Collocated" : "") << dim << "d<num_comp" << var_suffix
<< ", P_1d" << var_suffix << ", " << Q_name << ">(data, r_e" << var_suffix << ", s_B" << var_suffix << ", s_G" << var_suffix << ", r_q"
<< var_suffix << ");\n";
}
}
break;
case CEED_EVAL_WEIGHT: {
Expand Down
1 change: 1 addition & 0 deletions backends/cuda-gen/ceed-cuda-gen.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ static int CeedInit_Cuda_gen(const char *resource, Ceed ceed) {

CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "QFunctionCreate", CeedQFunctionCreate_Cuda_gen));
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "OperatorCreate", CeedOperatorCreate_Cuda_gen));
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "OperatorCreateAtPoints", CeedOperatorCreate_Cuda_gen));
CeedCallBackend(CeedSetBackendFunction(ceed, "Ceed", ceed, "Destroy", CeedDestroy_Cuda));
return CEED_ERROR_SUCCESS;
}
Expand Down
49 changes: 49 additions & 0 deletions include/ceed/jit-source/cuda/cuda-gen-templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ inline __device__ void WriteLVecStandard1d(SharedData_Cuda &data, const CeedInt
}
}

//------------------------------------------------------------------------------
// E-vector -> L-vector, AtPoints
//------------------------------------------------------------------------------
template <int NUM_COMP, int COMP_STRIDE, int P_1d>
inline __device__ void writeDofsAtPoints1d(SharedData_Cuda &data, const CeedInt num_nodes, const CeedInt elem, const CeedInt *__restrict__ indices,
const CeedInt *__restrict__ points_per_elem, const CeedScalar *__restrict__ r_v, CeedScalar *__restrict__ d_v) {
if (data.t_id_x < P_1d) {
const CeedInt node = data.t_id_x;
const CeedInt ind = indices[node + elem * P_1d];

if (node < points_per_elem[elem]) {
for (CeedInt comp = 0; comp < NUM_COMP; comp++) atomicAdd(&d_v[ind + COMP_STRIDE * comp], r_v[comp]);
}
}
}

//------------------------------------------------------------------------------
// E-vector -> L-vector, strided
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -123,6 +139,22 @@ inline __device__ void WriteLVecStandard2d(SharedData_Cuda &data, const CeedInt
}
}

//------------------------------------------------------------------------------
// E-vector -> L-vector, AtPoints
//------------------------------------------------------------------------------
template <int NUM_COMP, int COMP_STRIDE, int P_1d>
inline __device__ void writeDofsAtPoints2d(SharedData_Cuda &data, const CeedInt num_nodes, const CeedInt elem, const CeedInt *__restrict__ indices,
const CeedInt *__restrict__ points_per_elem, const CeedScalar *__restrict__ r_v, CeedScalar *__restrict__ d_v) {
if (data.t_id_x < P_1d && data.t_id_y < P_1d) {
const CeedInt node = data.t_id_x + data.t_id_y * P_1d;
const CeedInt ind = indices[node + elem * P_1d * P_1d];

if (node < points_per_elem[elem]) {
for (CeedInt comp = 0; comp < NUM_COMP; comp++) atomicAdd(&d_v[ind + COMP_STRIDE * comp], r_v[comp]);
}
}
}

//------------------------------------------------------------------------------
// E-vector -> L-vector, strided
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -215,6 +247,23 @@ inline __device__ void WriteLVecStandard3d(SharedData_Cuda &data, const CeedInt
}
}

//------------------------------------------------------------------------------
// E-vector -> L-vector, AtPoints
//------------------------------------------------------------------------------
template <int NUM_COMP, int COMP_STRIDE, int P_1d>
inline __device__ void writeDofsAtPoints3d(SharedData_Cuda &data, const CeedInt num_nodes, const CeedInt elem, const CeedInt *__restrict__ indices,
const CeedInt *__restrict__ points_per_elem, const CeedScalar *__restrict__ r_v, CeedScalar *__restrict__ d_v) {
if (data.t_id_x < P_1d && data.t_id_y < P_1d)
for (CeedInt z = 0; z < P_1d; z++) {
const CeedInt node = data.t_id_x + data.t_id_y * P_1d + z * P_1d * P_1d;
const CeedInt ind = indices[node + elem * P_1d * P_1d * P_1d];

if (node < points_per_elem[elem]) {
for (CeedInt comp = 0; comp < NUM_COMP; comp++) atomicAdd(&d_v[ind + COMP_STRIDE * comp], r_v[z + comp * P_1d]);
}
}
}

//------------------------------------------------------------------------------
// E-vector -> L-vector, strided
//------------------------------------------------------------------------------
Expand Down

0 comments on commit 368f613

Please sign in to comment.