Skip to content

Commit fefb99f

Browse files
committed
Add a DSL for Black-Scholes computation (specialized constants support)
1 parent 3c9ffe2 commit fefb99f

1 file changed

Lines changed: 116 additions & 0 deletions

File tree

tests/test_dsl_syntax.c

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,6 +1479,121 @@ static int test_black_scholes_dsl_kernel_support(void) {
14791479
return 0;
14801480
}
14811481

1482+
static int test_black_scholes_dsl_kernel_specialized_constants_support(void) {
1483+
printf("\n=== DSL Test 10c: Black-Scholes specialized-constants kernel support ===\n");
1484+
1485+
const char *src =
1486+
"def kernel(S, X, T):\n"
1487+
" A1 = 0.31938153\n"
1488+
" A2 = -0.356563782\n"
1489+
" A3 = 1.781477937\n"
1490+
" A4 = -1.821255978\n"
1491+
" A5 = 1.330274429\n"
1492+
" RSQRT2PI = 0.39894228040143267793994605993438\n"
1493+
"\n"
1494+
" sqrtT = sqrt(T)\n"
1495+
" d1 = (log(S / X) + (0.1 + 0.5 * 1.0 * 1.0) * T) / (1.0 * sqrtT)\n"
1496+
" d2 = d1 - 1.0 * sqrtT\n"
1497+
" K = 1.0 / (1.0 + 0.2316419 * abs(d1))\n"
1498+
"\n"
1499+
" ret_val = (RSQRT2PI * exp(-0.5 * d1 * d1) * (K * (A1 + K * (A2 + K * (A3 + K * (A4 + K * A5))))))\n"
1500+
" cndd1 = ret_val\n"
1501+
" if d1 > 0:\n"
1502+
" cndd1 = 1.0 - ret_val\n"
1503+
" else:\n"
1504+
" cndd1 = ret_val\n"
1505+
"\n"
1506+
" K = 1.0 / (1.0 + 0.2316419 * abs(d2))\n"
1507+
" ret_val = (RSQRT2PI * exp(-0.5 * d2 * d2) * (K * (A1 + K * (A2 + K * (A3 + K * (A4 + K * A5))))))\n"
1508+
" if d2 > 0:\n"
1509+
" cndd2 = 1.0 - ret_val\n"
1510+
" else:\n"
1511+
" cndd2 = ret_val\n"
1512+
"\n"
1513+
" expRT = exp((-1.0 * 0.1) * T)\n"
1514+
" callResult = (S * cndd1 - X * expRT * cndd2)\n"
1515+
" return callResult\n";
1516+
1517+
const int n = 8;
1518+
double S[8] = {100.0, 105.0, 110.0, 95.0, 120.0, 80.0, 150.0, 60.0};
1519+
double X[8] = {100.0, 100.0, 115.0, 90.0, 110.0, 85.0, 140.0, 65.0};
1520+
double T[8] = {1.0, 0.5, 2.0, 1.5, 0.25, 3.0, 0.75, 1.2};
1521+
const void *inputs[] = {S, X, T};
1522+
1523+
double expected[8];
1524+
for (int i = 0; i < n; i++) {
1525+
double A1 = 0.31938153;
1526+
double A2 = -0.356563782;
1527+
double A3 = 1.781477937;
1528+
double A4 = -1.821255978;
1529+
double A5 = 1.330274429;
1530+
double RSQRT2PI = 0.39894228040143267793994605993438;
1531+
1532+
double sqrtT = sqrt(T[i]);
1533+
double d1 = (log(S[i] / X[i]) + (0.1 + 0.5 * 1.0 * 1.0) * T[i]) / (1.0 * sqrtT);
1534+
double d2 = d1 - 1.0 * sqrtT;
1535+
double K = 1.0 / (1.0 + 0.2316419 * fabs(d1));
1536+
double ret_val = RSQRT2PI * exp(-0.5 * d1 * d1) *
1537+
(K * (A1 + K * (A2 + K * (A3 + K * (A4 + K * A5)))));
1538+
double cndd1 = (d1 > 0.0) ? (1.0 - ret_val) : ret_val;
1539+
1540+
K = 1.0 / (1.0 + 0.2316419 * fabs(d2));
1541+
ret_val = RSQRT2PI * exp(-0.5 * d2 * d2) *
1542+
(K * (A1 + K * (A2 + K * (A3 + K * (A4 + K * A5)))));
1543+
double cndd2 = (d2 > 0.0) ? (1.0 - ret_val) : ret_val;
1544+
1545+
double expRT = exp((-1.0 * 0.1) * T[i]);
1546+
expected[i] = S[i] * cndd1 - X[i] * expRT * cndd2;
1547+
}
1548+
1549+
me_variable vars[] = {
1550+
{"S", ME_FLOAT64},
1551+
{"X", ME_FLOAT64},
1552+
{"T", ME_FLOAT64}
1553+
};
1554+
double out_interp[8] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
1555+
double out_jit[8] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
1556+
char *saved_jit = dup_env_value("ME_DSL_JIT");
1557+
1558+
if (setenv("ME_DSL_JIT", "0", 1) != 0) {
1559+
printf(" ❌ FAILED: setenv ME_DSL_JIT=0 failed\n");
1560+
free(saved_jit);
1561+
return 1;
1562+
}
1563+
if (compile_eval_double(src, vars, 3, inputs, n, out_interp) != 0) {
1564+
restore_env_value("ME_DSL_JIT", saved_jit);
1565+
free(saved_jit);
1566+
return 1;
1567+
}
1568+
1569+
if (setenv("ME_DSL_JIT", "1", 1) != 0) {
1570+
printf(" ❌ FAILED: setenv ME_DSL_JIT=1 failed\n");
1571+
restore_env_value("ME_DSL_JIT", saved_jit);
1572+
free(saved_jit);
1573+
return 1;
1574+
}
1575+
if (compile_eval_double(src, vars, 3, inputs, n, out_jit) != 0) {
1576+
restore_env_value("ME_DSL_JIT", saved_jit);
1577+
free(saved_jit);
1578+
return 1;
1579+
}
1580+
1581+
restore_env_value("ME_DSL_JIT", saved_jit);
1582+
free(saved_jit);
1583+
1584+
if (check_all_close(out_interp, expected, n, 1e-6) != 0) {
1585+
printf(" ❌ FAILED: unexpected interpreter output\n");
1586+
return 1;
1587+
}
1588+
if (check_all_close(out_jit, out_interp, n, 1e-10) != 0) {
1589+
printf(" ❌ FAILED: interpreter/JIT mismatch\n");
1590+
return 1;
1591+
}
1592+
1593+
printf(" ✅ PASSED\n");
1594+
return 0;
1595+
}
1596+
14821597
static int test_loop_condition_policy(void) {
14831598
printf("\n=== DSL Test 11: loop condition policy ===\n");
14841599

@@ -2259,6 +2374,7 @@ int main(void) {
22592374
fail |= test_break_any_condition();
22602375
fail |= test_dsl_function_calls();
22612376
fail |= test_black_scholes_dsl_kernel_support();
2377+
fail |= test_black_scholes_dsl_kernel_specialized_constants_support();
22622378
fail |= test_loop_condition_policy();
22632379
fail |= test_elementwise_break();
22642380
fail |= test_reduction_any_remains_global();

0 commit comments

Comments
 (0)