Skip to content

Commit e036be4

Browse files
authored
Merge pull request #1699 from CEED/jeremy/set-jit-defines
Add CeedAddJitDefine
2 parents 1dc8b1e + 830fc37 commit e036be4

File tree

9 files changed

+173
-22
lines changed

9 files changed

+173
-22
lines changed

Diff for: backends/cuda/ceed-cuda-compile.cpp

+23-5
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const Ceed
3838
size_t ptx_size;
3939
char *ptx;
4040
const int num_opts = 4;
41-
CeedInt num_jit_source_dirs = 0;
41+
CeedInt num_jit_source_dirs = 0, num_jit_defines = 0;
4242
const char **opts;
4343
nvrtcProgram prog;
4444
struct cudaDeviceProp prop;
@@ -85,19 +85,34 @@ int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const Ceed
8585
opts[1] = arch_arg.c_str();
8686
opts[2] = "-Dint32_t=int";
8787
opts[3] = "-DCEED_RUNNING_JIT_PASS=1";
88+
// Additional include dirs
8889
{
8990
const char **jit_source_dirs;
9091

9192
CeedCallBackend(CeedGetJitSourceRoots(ceed, &num_jit_source_dirs, &jit_source_dirs));
9293
CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs, &opts));
9394
for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
94-
std::ostringstream include_dirs_arg;
95+
std::ostringstream include_dir_arg;
9596

96-
include_dirs_arg << "-I" << jit_source_dirs[i];
97-
CeedCallBackend(CeedStringAllocCopy(include_dirs_arg.str().c_str(), (char **)&opts[num_opts + i]));
97+
include_dir_arg << "-I" << jit_source_dirs[i];
98+
CeedCallBackend(CeedStringAllocCopy(include_dir_arg.str().c_str(), (char **)&opts[num_opts + i]));
9899
}
99100
CeedCallBackend(CeedRestoreJitSourceRoots(ceed, &jit_source_dirs));
100101
}
102+
// User defines
103+
{
104+
const char **jit_defines;
105+
106+
CeedCallBackend(CeedGetJitDefines(ceed, &num_jit_defines, &jit_defines));
107+
CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs + num_jit_defines, &opts));
108+
for (CeedInt i = 0; i < num_jit_defines; i++) {
109+
std::ostringstream define_arg;
110+
111+
define_arg << "-D" << jit_defines[i];
112+
CeedCallBackend(CeedStringAllocCopy(define_arg.str().c_str(), (char **)&opts[num_opts + num_jit_source_dirs + i]));
113+
}
114+
CeedCallBackend(CeedRestoreJitDefines(ceed, &jit_defines));
115+
}
101116

102117
// Add string source argument provided in call
103118
code << source;
@@ -106,11 +121,14 @@ int CeedCompile_Cuda(Ceed ceed, const char *source, CUmodule *module, const Ceed
106121
CeedCallNvrtc(ceed, nvrtcCreateProgram(&prog, code.str().c_str(), NULL, 0, NULL, NULL));
107122

108123
// Compile kernel
109-
nvrtcResult result = nvrtcCompileProgram(prog, num_opts + num_jit_source_dirs, opts);
124+
nvrtcResult result = nvrtcCompileProgram(prog, num_opts + num_jit_source_dirs + num_jit_defines, opts);
110125

111126
for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
112127
CeedCallBackend(CeedFree(&opts[num_opts + i]));
113128
}
129+
for (CeedInt i = 0; i < num_jit_defines; i++) {
130+
CeedCallBackend(CeedFree(&opts[num_opts + num_jit_source_dirs + i]));
131+
}
114132
CeedCallBackend(CeedFree(&opts));
115133
if (result != NVRTC_SUCCESS) {
116134
char *log;

Diff for: backends/hip/ceed-hip-compile.cpp

+23-5
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ int CeedCompile_Hip(Ceed ceed, const char *source, hipModule_t *module, const Ce
3737
size_t ptx_size;
3838
char *ptx;
3939
const int num_opts = 4;
40-
CeedInt num_jit_source_dirs = 0;
40+
CeedInt num_jit_source_dirs = 0, num_jit_defines = 0;
4141
const char **opts;
4242
int runtime_version;
4343
hiprtcProgram prog;
@@ -87,19 +87,34 @@ int CeedCompile_Hip(Ceed ceed, const char *source, hipModule_t *module, const Ce
8787
opts[1] = arch_arg.c_str();
8888
opts[2] = "-munsafe-fp-atomics";
8989
opts[3] = "-DCEED_RUNNING_JIT_PASS=1";
90+
// Additional include dirs
9091
{
9192
const char **jit_source_dirs;
9293

9394
CeedCallBackend(CeedGetJitSourceRoots(ceed, &num_jit_source_dirs, &jit_source_dirs));
9495
CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs, &opts));
9596
for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
96-
std::ostringstream include_dirs_arg;
97+
std::ostringstream include_dir_arg;
9798

98-
include_dirs_arg << "-I" << jit_source_dirs[i];
99-
CeedCallBackend(CeedStringAllocCopy(include_dirs_arg.str().c_str(), (char **)&opts[num_opts + i]));
99+
include_dir_arg << "-I" << jit_source_dirs[i];
100+
CeedCallBackend(CeedStringAllocCopy(include_dir_arg.str().c_str(), (char **)&opts[num_opts + i]));
100101
}
101102
CeedCallBackend(CeedRestoreJitSourceRoots(ceed, &jit_source_dirs));
102103
}
104+
// User defines
105+
{
106+
const char **jit_defines;
107+
108+
CeedCallBackend(CeedGetJitDefines(ceed, &num_jit_defines, &jit_defines));
109+
CeedCallBackend(CeedRealloc(num_opts + num_jit_source_dirs + num_jit_defines, &opts));
110+
for (CeedInt i = 0; i < num_jit_defines; i++) {
111+
std::ostringstream define_arg;
112+
113+
define_arg << "-D" << jit_defines[i];
114+
CeedCallBackend(CeedStringAllocCopy(define_arg.str().c_str(), (char **)&opts[num_opts + num_jit_source_dirs + i]));
115+
}
116+
CeedCallBackend(CeedRestoreJitDefines(ceed, &jit_defines));
117+
}
103118

104119
// Add string source argument provided in call
105120
code << source;
@@ -108,11 +123,14 @@ int CeedCompile_Hip(Ceed ceed, const char *source, hipModule_t *module, const Ce
108123
CeedCallHiprtc(ceed, hiprtcCreateProgram(&prog, code.str().c_str(), NULL, 0, NULL, NULL));
109124

110125
// Compile kernel
111-
hiprtcResult result = hiprtcCompileProgram(prog, num_opts + num_jit_source_dirs, opts);
126+
hiprtcResult result = hiprtcCompileProgram(prog, num_opts + num_jit_source_dirs + num_jit_defines, opts);
112127

113128
for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
114129
CeedCallBackend(CeedFree(&opts[num_opts + i]));
115130
}
131+
for (CeedInt i = 0; i < num_jit_defines; i++) {
132+
CeedCallBackend(CeedFree(&opts[num_opts + num_jit_source_dirs + i]));
133+
}
116134
CeedCallBackend(CeedFree(&opts));
117135
if (result != HIPRTC_SUCCESS) {
118136
size_t log_size;

Diff for: doc/sphinx/source/releasenotes.md

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ On this page we provide a summary of the main API changes, new features and exam
1919
- Add `CeedElemRestrictionGetLLayout` to provide L-vector layout for strided `CeedElemRestriction` created with `CEED_BACKEND_STRIDES`.
2020
- Add `CeedVectorReturnCeed` and similar when parent `Ceed` context for a libCEED object is only needed once in a calling scope.
2121
- Enable `#pragma once` for all JiT source; remove duplicate includes in JiT source string before compilation.
22+
- Allow user to set additional compiler options for CUDA and HIP JiT.
23+
Specifically, directories set with `CeedAddJitSourceRoot(ceed, "foo/bar")` will be used to set `-Ifoo/bar` and defines set with `CeedAddJitDefine(ceed, "foo=bar")` will be used to set `-Dfoo=bar`.
2224

2325
### Examples
2426

Diff for: include/ceed-impl.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ struct Ceed_private {
9999
Ceed op_fallback_ceed, op_fallback_parent;
100100
const char *op_fallback_resource;
101101
char **jit_source_roots;
102-
CeedInt num_jit_source_roots;
102+
CeedInt num_jit_source_roots, max_jit_source_roots, num_jit_source_roots_readers;
103+
char **jit_defines;
104+
CeedInt num_jit_defines, max_jit_defines, num_jit_defines_readers;
103105
int (*Error)(Ceed, const char *, int, const char *, int, const char *, va_list *);
104106
int (*SetStream)(Ceed, void *);
105107
int (*GetPreferredMemType)(CeedMemType *);

Diff for: include/ceed/backend.h

+2
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ CEED_EXTERN int CeedGetWorkVector(Ceed ceed, CeedSize len, CeedVector *vec);
256256
CEED_EXTERN int CeedRestoreWorkVector(Ceed ceed, CeedVector *vec);
257257
CEED_EXTERN int CeedGetJitSourceRoots(Ceed ceed, CeedInt *num_source_roots, const char ***jit_source_roots);
258258
CEED_EXTERN int CeedRestoreJitSourceRoots(Ceed ceed, const char ***jit_source_roots);
259+
CEED_EXTERN int CeedGetJitDefines(Ceed ceed, CeedInt *num_defines, const char ***jit_defines);
260+
CEED_EXTERN int CeedRestoreJitDefines(Ceed ceed, const char ***jit_defines);
259261

260262
CEED_EXTERN int CeedVectorHasValidArray(CeedVector vec, bool *has_valid_array);
261263
CEED_EXTERN int CeedVectorHasBorrowedArrayOfType(CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type);

Diff for: include/ceed/ceed.h

+1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ CEED_EXTERN int CeedReferenceCopy(Ceed ceed, Ceed *ceed_copy);
107107
CEED_EXTERN int CeedGetResource(Ceed ceed, const char **resource);
108108
CEED_EXTERN int CeedIsDeterministic(Ceed ceed, bool *is_deterministic);
109109
CEED_EXTERN int CeedAddJitSourceRoot(Ceed ceed, const char *jit_source_root);
110+
CEED_EXTERN int CeedAddJitDefine(Ceed ceed, const char *jit_define);
110111
CEED_EXTERN int CeedView(Ceed ceed, FILE *stream);
111112
CEED_EXTERN int CeedDestroy(Ceed *ceed);
112113
CEED_EXTERN int CeedErrorImpl(Ceed ceed, const char *filename, int lineno, const char *func, int ecode, const char *format, ...);

Diff for: interface/ceed.c

+110-8
Original file line numberDiff line numberDiff line change
@@ -659,14 +659,24 @@ int CeedGetOperatorFallbackCeed(Ceed ceed, Ceed *fallback_ceed) {
659659
fallback_ceed->Error = ceed->Error;
660660
ceed->op_fallback_ceed = fallback_ceed;
661661
{
662-
const char **jit_source_dirs;
663-
CeedInt num_jit_source_dirs = 0;
662+
const char **jit_source_roots;
663+
CeedInt num_jit_source_roots = 0;
664664

665-
CeedCall(CeedGetJitSourceRoots(ceed, &num_jit_source_dirs, &jit_source_dirs));
666-
for (CeedInt i = 0; i < num_jit_source_dirs; i++) {
667-
CeedCall(CeedAddJitSourceRoot(fallback_ceed, jit_source_dirs[i]));
665+
CeedCall(CeedGetJitSourceRoots(ceed, &num_jit_source_roots, &jit_source_roots));
666+
for (CeedInt i = 0; i < num_jit_source_roots; i++) {
667+
CeedCall(CeedAddJitSourceRoot(fallback_ceed, jit_source_roots[i]));
668668
}
669-
CeedCall(CeedRestoreJitSourceRoots(ceed, &jit_source_dirs));
669+
CeedCall(CeedRestoreJitSourceRoots(ceed, &jit_source_roots));
670+
}
671+
{
672+
const char **jit_defines;
673+
CeedInt num_jit_defines = 0;
674+
675+
CeedCall(CeedGetJitDefines(ceed, &num_jit_defines, &jit_defines));
676+
for (CeedInt i = 0; i < num_jit_defines; i++) {
677+
CeedCall(CeedAddJitSourceRoot(fallback_ceed, jit_defines[i]));
678+
}
679+
CeedCall(CeedRestoreJitDefines(ceed, &jit_defines));
670680
}
671681
}
672682
*fallback_ceed = ceed->op_fallback_ceed;
@@ -874,7 +884,7 @@ int CeedRestoreWorkVector(Ceed ceed, CeedVector *vec) {
874884
}
875885

876886
/**
877-
@brief Retrieve list ofadditional JiT source roots from `Ceed` context.
887+
@brief Retrieve list of additional JiT source roots from `Ceed` context.
878888
879889
Note: The caller is responsible for restoring `jit_source_roots` with @ref CeedRestoreJitSourceRoots().
880890
@@ -892,6 +902,7 @@ int CeedGetJitSourceRoots(Ceed ceed, CeedInt *num_source_roots, const char ***ji
892902
CeedCall(CeedGetParent(ceed, &ceed_parent));
893903
*num_source_roots = ceed_parent->num_jit_source_roots;
894904
*jit_source_roots = (const char **)ceed_parent->jit_source_roots;
905+
ceed_parent->num_jit_source_roots_readers++;
895906
return CEED_ERROR_SUCCESS;
896907
}
897908

@@ -906,7 +917,53 @@ int CeedGetJitSourceRoots(Ceed ceed, CeedInt *num_source_roots, const char ***ji
906917
@ref Backend
907918
**/
908919
int CeedRestoreJitSourceRoots(Ceed ceed, const char ***jit_source_roots) {
920+
Ceed ceed_parent;
921+
922+
CeedCall(CeedGetParent(ceed, &ceed_parent));
909923
*jit_source_roots = NULL;
924+
ceed_parent->num_jit_source_roots_readers--;
925+
return CEED_ERROR_SUCCESS;
926+
}
927+
928+
/**
929+
@brief Retrieve list of additional JiT defines from `Ceed` context.
930+
931+
Note: The caller is responsible for restoring `jit_defines` with @ref CeedRestoreJitDefines().
932+
933+
@param[in] ceed `Ceed` context
934+
@param[out] num_jit_defines Number of JiT defines
935+
@param[out] jit_defines Strings such as `foo=bar`, used as `-Dfoo=bar` in JiT
936+
937+
@return An error code: 0 - success, otherwise - failure
938+
939+
@ref Backend
940+
**/
941+
int CeedGetJitDefines(Ceed ceed, CeedInt *num_defines, const char ***jit_defines) {
942+
Ceed ceed_parent;
943+
944+
CeedCall(CeedGetParent(ceed, &ceed_parent));
945+
*num_defines = ceed_parent->num_jit_defines;
946+
*jit_defines = (const char **)ceed_parent->jit_defines;
947+
ceed_parent->num_jit_defines_readers++;
948+
return CEED_ERROR_SUCCESS;
949+
}
950+
951+
/**
952+
@brief Restore list of additional JiT defines from with @ref CeedGetJitDefines()
953+
954+
@param[in] ceed `Ceed` context
955+
@param[out] jit_defines String such as `foo=bar`, used as `-Dfoo=bar` in JiT
956+
957+
@return An error code: 0 - success, otherwise - failure
958+
959+
@ref Backend
960+
**/
961+
int CeedRestoreJitDefines(Ceed ceed, const char ***jit_defines) {
962+
Ceed ceed_parent;
963+
964+
CeedCall(CeedGetParent(ceed, &ceed_parent));
965+
*jit_defines = NULL;
966+
ceed_parent->num_jit_defines_readers--;
910967
return CEED_ERROR_SUCCESS;
911968
}
912969

@@ -1290,17 +1347,52 @@ int CeedAddJitSourceRoot(Ceed ceed, const char *jit_source_root) {
12901347
Ceed ceed_parent;
12911348

12921349
CeedCall(CeedGetParent(ceed, &ceed_parent));
1350+
CeedCheck(!ceed_parent->num_jit_source_roots_readers, ceed, CEED_ERROR_ACCESS, "Cannot add JiT source root, read access has not been restored");
12931351

12941352
CeedInt index = ceed_parent->num_jit_source_roots;
12951353
size_t path_length = strlen(jit_source_root);
12961354

1297-
CeedCall(CeedRealloc(index + 1, &ceed_parent->jit_source_roots));
1355+
if (ceed_parent->num_jit_source_roots == ceed_parent->max_jit_source_roots) {
1356+
if (ceed_parent->max_jit_source_roots == 0) ceed_parent->max_jit_source_roots = 1;
1357+
ceed_parent->max_jit_source_roots *= 2;
1358+
CeedCall(CeedRealloc(ceed_parent->max_jit_source_roots, &ceed_parent->jit_source_roots));
1359+
}
12981360
CeedCall(CeedCalloc(path_length + 1, &ceed_parent->jit_source_roots[index]));
12991361
memcpy(ceed_parent->jit_source_roots[index], jit_source_root, path_length);
13001362
ceed_parent->num_jit_source_roots++;
13011363
return CEED_ERROR_SUCCESS;
13021364
}
13031365

1366+
/**
1367+
@brief Set additional JiT compiler define for `Ceed` context
1368+
1369+
@param[in,out] ceed `Ceed` context
1370+
@param[in] jit_define String such as `foo=bar`, used as `-Dfoo=bar` in JiT
1371+
1372+
@return An error code: 0 - success, otherwise - failure
1373+
1374+
@ref User
1375+
**/
1376+
int CeedAddJitDefine(Ceed ceed, const char *jit_define) {
1377+
Ceed ceed_parent;
1378+
1379+
CeedCall(CeedGetParent(ceed, &ceed_parent));
1380+
CeedCheck(!ceed_parent->num_jit_defines_readers, ceed, CEED_ERROR_ACCESS, "Cannot add JiT define, read access has not been restored");
1381+
1382+
CeedInt index = ceed_parent->num_jit_defines;
1383+
size_t define_length = strlen(jit_define);
1384+
1385+
if (ceed_parent->num_jit_defines == ceed_parent->max_jit_defines) {
1386+
if (ceed_parent->max_jit_defines == 0) ceed_parent->max_jit_defines = 1;
1387+
ceed_parent->max_jit_defines *= 2;
1388+
CeedCall(CeedRealloc(ceed_parent->max_jit_defines, &ceed_parent->jit_defines));
1389+
}
1390+
CeedCall(CeedCalloc(define_length + 1, &ceed_parent->jit_defines[index]));
1391+
memcpy(ceed_parent->jit_defines[index], jit_define, define_length);
1392+
ceed_parent->num_jit_defines++;
1393+
return CEED_ERROR_SUCCESS;
1394+
}
1395+
13041396
/**
13051397
@brief View a `Ceed`
13061398
@@ -1338,6 +1430,11 @@ int CeedDestroy(Ceed *ceed) {
13381430
*ceed = NULL;
13391431
return CEED_ERROR_SUCCESS;
13401432
}
1433+
1434+
CeedCheck(!(*ceed)->num_jit_source_roots_readers, *ceed, CEED_ERROR_ACCESS,
1435+
"Cannot destroy ceed context, read access for JiT source roots has been granted");
1436+
CeedCheck(!(*ceed)->num_jit_defines_readers, *ceed, CEED_ERROR_ACCESS, "Cannot add JiT source root, read access for JiT defines has been granted");
1437+
13411438
if ((*ceed)->delegate) CeedCall(CeedDestroy(&(*ceed)->delegate));
13421439

13431440
if ((*ceed)->obj_delegate_count > 0) {
@@ -1355,6 +1452,11 @@ int CeedDestroy(Ceed *ceed) {
13551452
}
13561453
CeedCall(CeedFree(&(*ceed)->jit_source_roots));
13571454

1455+
for (CeedInt i = 0; i < (*ceed)->num_jit_defines; i++) {
1456+
CeedCall(CeedFree(&(*ceed)->jit_defines[i]));
1457+
}
1458+
CeedCall(CeedFree(&(*ceed)->jit_defines));
1459+
13581460
CeedCall(CeedFree(&(*ceed)->f_offsets));
13591461
CeedCall(CeedFree(&(*ceed)->resource));
13601462
CeedCall(CeedDestroy(&(*ceed)->op_fallback_ceed));

Diff for: tests/t406-qfunction.c

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ int main(int argc, char **argv) {
2424

2525
memcpy(&file_path[last_slash - file_path], "/test-include/", 15);
2626
CeedAddJitSourceRoot(ceed, file_path);
27+
CeedAddJitDefine(ceed, "COMPILER_DEFINED_SCALE=42");
2728
}
2829

2930
CeedVectorCreate(ceed, q, &w);
@@ -71,9 +72,9 @@ int main(int argc, char **argv) {
7172

7273
CeedVectorGetArrayRead(v, CEED_MEM_HOST, &v_array);
7374
for (CeedInt i = 0; i < q; i++) {
74-
if (fabs(5 * v_true[i] * sqrt(2.) - v_array[i]) > 1E3 * CEED_EPSILON) {
75+
if (fabs(5 * COMPILER_DEFINED_SCALE * v_true[i] * sqrt(2.) - v_array[i]) > 5E3 * CEED_EPSILON) {
7576
// LCOV_EXCL_START
76-
printf("[%" CeedInt_FMT "] v_true %f != v %f\n", i, 5 * v_true[i] * sqrt(2.), v_array[i]);
77+
printf("[%" CeedInt_FMT "] v_true %f != v %f\n", i, 5 * COMPILER_DEFINED_SCALE * v_true[i] * sqrt(2.), v_array[i]);
7778
// LCOV_EXCL_STOP
7879
}
7980
}

Diff for: tests/t406-qfunction.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
# include "t406-qfunction-scales.h"
2424
// clang-format on
2525

26+
// Extra define set via CeedAddJitDefine() during JiT
27+
#ifndef CEED_RUNNING_JIT_PASS
28+
#define COMPILER_DEFINED_SCALE 42
29+
#endif
30+
2631
CEED_QFUNCTION(setup)(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out) {
2732
const CeedScalar *w = in[0];
2833
CeedScalar *q_data = out[0];
@@ -36,7 +41,7 @@ CEED_QFUNCTION(mass)(void *ctx, const CeedInt Q, const CeedScalar *const *in, Ce
3641
const CeedScalar *q_data = in[0], *u = in[1];
3742
CeedScalar *v = out[0];
3843
for (CeedInt i = 0; i < Q; i++) {
39-
v[i] = q_data[i] * (times_two(u[i]) + times_three(u[i])) * sqrt(1.0 * SCALE_TWO);
44+
v[i] = q_data[i] * COMPILER_DEFINED_SCALE * (times_two(u[i]) + times_three(u[i])) * sqrt(1.0 * SCALE_TWO);
4045
}
4146
return 0;
4247
}

0 commit comments

Comments
 (0)