Skip to content

Commit 7bb0d67

Browse files
committed
Added explicit handling for conversion marker nodes in scalar evaluation
1 parent fefb99f commit 7bb0d67

2 files changed

Lines changed: 105 additions & 0 deletions

File tree

src/functions.c

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3784,6 +3784,29 @@ static double me_eval_scalar(const me_expr* n) {
37843784
case ME_FUNCTION5:
37853785
case ME_FUNCTION6:
37863786
case ME_FUNCTION7:
3787+
if (ARITY(n->type) == 1 && n->function == NULL) {
3788+
/* Internal conversion marker node created by apply_type_promotion(). */
3789+
double v = M(0);
3790+
switch (n->dtype) {
3791+
case ME_AUTO:
3792+
case ME_FLOAT64: return v;
3793+
case ME_BOOL: return ((bool)v) ? 1.0 : 0.0;
3794+
case ME_INT8: return (double)(int8_t)v;
3795+
case ME_INT16: return (double)(int16_t)v;
3796+
case ME_INT32: return (double)(int32_t)v;
3797+
case ME_INT64: return (double)(int64_t)v;
3798+
case ME_UINT8: return (double)(uint8_t)v;
3799+
case ME_UINT16: return (double)(uint16_t)v;
3800+
case ME_UINT32: return (double)(uint32_t)v;
3801+
case ME_UINT64: return (double)(uint64_t)v;
3802+
case ME_FLOAT32: return (double)(float)v;
3803+
case ME_COMPLEX64:
3804+
case ME_COMPLEX128:
3805+
case ME_STRING:
3806+
default:
3807+
return NAN;
3808+
}
3809+
}
37873810
switch (ARITY(n->type)) {
37883811
case 0: return ME_FUN(void)();
37893812
case 1: return ME_FUN(double)(M(0));

tests/test_nd.c

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,83 @@ static int test_nd_python_blosc2_contract_full_int_cast_ramp_with_input_jit_mode
698698
return status;
699699
}
700700

701+
static int test_nd_python_blosc2_contract_constant_subexpr_float64_to_float32_jit_modes(void) {
702+
int status = 0;
703+
int err = 0;
704+
int64_t shape[2] = {3, 5};
705+
int32_t chunkshape[2] = {2, 4};
706+
int32_t blockshape[2] = {2, 3};
707+
me_variable vars[] = {{"x", ME_FLOAT64}};
708+
const char *source =
709+
"def kernel_const_subexpr(x):\n"
710+
" return (0.1 + 0.5) * x\n";
711+
const struct {
712+
int compile_jit_mode;
713+
me_jit_mode eval_jit_mode;
714+
const char *label;
715+
} jit_modes[] = {
716+
{ME_JIT_DEFAULT, ME_JIT_DEFAULT, "default"},
717+
{ME_JIT_OFF, ME_JIT_OFF, "jit_off"},
718+
{ME_JIT_ON, ME_JIT_ON, "jit_on"},
719+
};
720+
721+
for (int mode = 0; mode < (int)(sizeof(jit_modes) / sizeof(jit_modes[0])); mode++) {
722+
me_expr *expr = NULL;
723+
int rc = me_compile_nd_jit(source, vars, 1, ME_FLOAT32, 2,
724+
shape, chunkshape, blockshape,
725+
jit_modes[mode].compile_jit_mode, &err, &expr);
726+
if (rc != ME_COMPILE_SUCCESS) {
727+
printf("FAILED python-blosc2 constant-subexpr float64->float32 me_compile_nd_jit (%s): rc=%d err=%d\n",
728+
jit_modes[mode].label, rc, err);
729+
status = 1;
730+
goto cleanup;
731+
}
732+
733+
int64_t valid = -1;
734+
rc = me_nd_valid_nitems(expr, 1, 0, &valid);
735+
if (rc != ME_EVAL_SUCCESS || valid != 2) {
736+
printf("FAILED python-blosc2 constant-subexpr float64->float32 me_nd_valid_nitems (%s): rc=%d valid=%lld\n",
737+
jit_modes[mode].label, rc, (long long)valid);
738+
me_free(expr);
739+
status = 1;
740+
goto cleanup;
741+
}
742+
743+
double xblock[6] = {0.8, 999.0, 999.0, 1.2, 999.0, 999.0};
744+
const void *inputs[] = {xblock};
745+
float out[6] = {-1.f, -1.f, -1.f, -1.f, -1.f, -1.f};
746+
const float expected[6] = {0.48f, 0.f, 0.f, 0.72f, 0.f, 0.f};
747+
748+
me_eval_params eval_params = {0};
749+
eval_params.disable_simd = false;
750+
eval_params.simd_ulp_mode = ME_SIMD_ULP_3_5;
751+
eval_params.jit_mode = jit_modes[mode].eval_jit_mode;
752+
rc = me_eval_nd(expr, inputs, 1, out, 6, 1, 0, &eval_params);
753+
if (rc != ME_EVAL_SUCCESS) {
754+
printf("FAILED python-blosc2 constant-subexpr float64->float32 me_eval_nd (%s): rc=%d\n",
755+
jit_modes[mode].label, rc);
756+
me_free(expr);
757+
status = 1;
758+
goto cleanup;
759+
}
760+
761+
for (int i = 0; i < 6; i++) {
762+
if (fabsf(out[i] - expected[i]) > 1e-6f) {
763+
printf("FAILED python-blosc2 constant-subexpr float64->float32 mismatch (%s): idx=%d got=%g exp=%g\n",
764+
jit_modes[mode].label, i, (double)out[i], (double)expected[i]);
765+
me_free(expr);
766+
status = 1;
767+
goto cleanup;
768+
}
769+
}
770+
771+
me_free(expr);
772+
}
773+
774+
cleanup:
775+
return status;
776+
}
777+
701778
static int run_nd_compile_failure_reason_case(const char *case_name,
702779
const char *source,
703780
const me_variable *vars,
@@ -1727,6 +1804,11 @@ int main(void) {
17271804
failed |= t24;
17281805
printf("Result: %s\n\n", t24 ? "FAIL" : "PASS");
17291806

1807+
printf("Test 25: python-blosc2 constant-subexpr float64->float32 contract (jit default/off/on)\n");
1808+
int t25 = test_nd_python_blosc2_contract_constant_subexpr_float64_to_float32_jit_modes();
1809+
failed |= t25;
1810+
printf("Result: %s\n\n", t25 ? "FAIL" : "PASS");
1811+
17301812
printf("=====================\n");
17311813
printf("Summary: %s\n", failed ? "FAIL" : "PASS");
17321814
return failed ? 1 : 0;

0 commit comments

Comments
 (0)