Skip to content

Commit 6924883

Browse files
committed
implements incomplete gamma function (RLGamma from bxdf chi2 test)
1 parent 01337d6 commit 6924883

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

include/nbl/builtin/hlsl/tgmath.hlsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,12 @@ inline T beta(NBL_CONST_REF_ARG(T) v1, NBL_CONST_REF_ARG(T) v2)
249249
return tgmath_impl::beta_helper<T>::__call(v1, v2)/tgmath_impl::beta_helper<T>::__call(T(1.0), T(1.0)); // ensure beta(1,1)==1
250250
}
251251

252+
template<typename T>
253+
inline T gamma(NBL_CONST_REF_ARG(T) a, NBL_CONST_REF_ARG(T) x)
254+
{
255+
return tgmath_impl::gamma_helper<T>::__call(a, x);
256+
}
257+
252258
}
253259
}
254260

include/nbl/builtin/hlsl/tgmath/impl.hlsl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ template<typename T NBL_STRUCT_CONSTRAINABLE>
9898
struct lgamma_helper;
9999
template<typename T NBL_STRUCT_CONSTRAINABLE>
100100
struct beta_helper;
101+
template<typename T NBL_STRUCT_CONSTRAINABLE>
102+
struct gamma_helper;
101103

102104
#ifdef __HLSL_VERSION
103105

@@ -606,6 +608,88 @@ struct beta_helper<T NBL_PARTIAL_REQ_BOT(concepts::FloatingPointScalar<T>) >
606608
}
607609
};
608610

611+
// incomplete gamma function
612+
template<typename T>
613+
NBL_PARTIAL_REQ_TOP(concepts::FloatingPointScalar<T>)
614+
struct gamma_helper<T NBL_PARTIAL_REQ_BOT(concepts::FloatingPointScalar<T>) >
615+
{
616+
NBL_CONSTEXPR_STATIC_INLINE T epsilon = 1e-15;
617+
NBL_CONSTEXPR_STATIC_INLINE T big = 4503599627370496.0;
618+
NBL_CONSTEXPR_STATIC_INLINE T bigInv = 2.22044604925031308085e-16;
619+
620+
static T __call(T a, T x)
621+
{
622+
assert(a >= T(0.0) && x >= T(0.0));
623+
624+
if (x == T(0.0))
625+
return T(0.0);
626+
627+
T ax = (a * log_helper<T>::__call(x)) - x - lgamma_helper<T>::__call(a);
628+
if (ax < T(-709.78271289338399))
629+
return hlsl::mix(T(0.0), T(1.0), a < x);
630+
631+
if (x <= T(1.0) || x <= a)
632+
{
633+
T r2 = a;
634+
T c2 = T(1.0);
635+
T ans2 = T(1.0);
636+
637+
do {
638+
r2 = r2 + T(1.0);
639+
c2 = c2 * x / r2;
640+
ans2 += c2;
641+
} while ((c2 / ans2) > epsilon);
642+
643+
return exp_helper<T>::__call(ax) * ans2 / a;
644+
}
645+
646+
int c = 0;
647+
T y = T(1.0) - a;
648+
T z = x + y + T(1.0);
649+
T p3 = T(1.0);
650+
T q3 = x;
651+
T p2 = x + T(1.0);
652+
T q2 = z * x;
653+
T ans = p2 / q2;
654+
T error;
655+
656+
do {
657+
c++;
658+
y += T(1.0);
659+
z += T(2.0);
660+
T yc = y * c;
661+
T p = (p2 * z) - (p3 * yc);
662+
T q = (q2 * z) - (q3 * yc);
663+
664+
if (q != T(0.0))
665+
{
666+
T nextans = p / q;
667+
error = abs_helper<T>::__call((ans - nextans) / nextans);
668+
ans = nextans;
669+
}
670+
else
671+
{
672+
error = 1;
673+
}
674+
675+
p3 = p2;
676+
p2 = p;
677+
q3 = q2;
678+
q2 = q;
679+
680+
if (abs_helper<T>::__call(p) > big)
681+
{
682+
p3 *= bigInv;
683+
p2 *= bigInv;
684+
q3 *= bigInv;
685+
q2 *= bigInv;
686+
}
687+
} while (error > epsilon);
688+
689+
return T(1.0) - (exp_helper<T>::__call(ax) * ans);
690+
}
691+
};
692+
609693
#ifdef __HLSL_VERSION
610694
// SPIR-V already defines specializations for builtin vector types
611695
#define VECTOR_SPECIALIZATION_CONCEPT concepts::Vectorial<T> && !is_vector_v<T>

0 commit comments

Comments
 (0)