@@ -98,6 +98,8 @@ template<typename T NBL_STRUCT_CONSTRAINABLE>
9898struct lgamma_helper;
9999template<typename T NBL_STRUCT_CONSTRAINABLE>
100100struct beta_helper;
101+ template<typename T NBL_STRUCT_CONSTRAINABLE>
102+ struct gamma_helper;
101103
102104#ifdef __HLSL_VERSION
103105
@@ -606,6 +608,88 @@ struct beta_helper<T NBL_PARTIAL_REQ_BOT(concepts::FloatingPointScalar<T>) >
606608 }
607609};
608610
611+ // incomplete gamma function
612+ template<typename T>
613+ NBL_PARTIAL_REQ_TOP (concepts::FloatingPointScalar<T>)
614+ struct gamma_helper<T NBL_PARTIAL_REQ_BOT (concepts::FloatingPointScalar<T>) >
615+ {
616+ NBL_CONSTEXPR_STATIC_INLINE T epsilon = 1e-15 ;
617+ NBL_CONSTEXPR_STATIC_INLINE T big = 4503599627370496.0 ;
618+ NBL_CONSTEXPR_STATIC_INLINE T bigInv = 2. 22044604925031308085e-16 ;
619+
620+ static T __call (T a, T x)
621+ {
622+ assert (a >= T (0.0 ) && x >= T (0.0 ));
623+
624+ if (x == T (0.0 ))
625+ return T (0.0 );
626+
627+ T ax = (a * log_helper<T>::__call (x)) - x - lgamma_helper<T>::__call (a);
628+ if (ax < T (-709.78271289338399 ))
629+ return hlsl::mix (T (0.0 ), T (1.0 ), a < x);
630+
631+ if (x <= T (1.0 ) || x <= a)
632+ {
633+ T r2 = a;
634+ T c2 = T (1.0 );
635+ T ans2 = T (1.0 );
636+
637+ do {
638+ r2 = r2 + T (1.0 );
639+ c2 = c2 * x / r2;
640+ ans2 += c2;
641+ } while ((c2 / ans2) > epsilon);
642+
643+ return exp_helper<T>::__call (ax) * ans2 / a;
644+ }
645+
646+ int c = 0 ;
647+ T y = T (1.0 ) - a;
648+ T z = x + y + T (1.0 );
649+ T p3 = T (1.0 );
650+ T q3 = x;
651+ T p2 = x + T (1.0 );
652+ T q2 = z * x;
653+ T ans = p2 / q2;
654+ T error;
655+
656+ do {
657+ c++;
658+ y += T (1.0 );
659+ z += T (2.0 );
660+ T yc = y * c;
661+ T p = (p2 * z) - (p3 * yc);
662+ T q = (q2 * z) - (q3 * yc);
663+
664+ if (q != T (0.0 ))
665+ {
666+ T nextans = p / q;
667+ error = abs_helper<T>::__call ((ans - nextans) / nextans);
668+ ans = nextans;
669+ }
670+ else
671+ {
672+ error = 1 ;
673+ }
674+
675+ p3 = p2;
676+ p2 = p;
677+ q3 = q2;
678+ q2 = q;
679+
680+ if (abs_helper<T>::__call (p) > big)
681+ {
682+ p3 *= bigInv;
683+ p2 *= bigInv;
684+ q3 *= bigInv;
685+ q2 *= bigInv;
686+ }
687+ } while (error > epsilon);
688+
689+ return T (1.0 ) - (exp_helper<T>::__call (ax) * ans);
690+ }
691+ };
692+
609693#ifdef __HLSL_VERSION
610694// SPIR-V already defines specializations for builtin vector types
611695#define VECTOR_SPECIALIZATION_CONCEPT concepts::Vectorial<T> && !is_vector_v<T>
0 commit comments