Skip to content

Commit 31bf682

Browse files
arthus701Max Schanner
and
Max Schanner
authored
Add betainc C implementation (#798)
Co-authored-by: Max Schanner <[email protected]>
1 parent afc1a6c commit 31bf682

File tree

2 files changed

+330
-2
lines changed

2 files changed

+330
-2
lines changed

pytensor/scalar/c_code/incbet.c

+311
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
/* adapted from file incbet.c, obtained from the Cephes library (MIT License)
2+
Cephes Math Library, Release 2.8: June, 2000
3+
Copyright 1984, 1995, 2000 by Stephen L. Moshier
4+
*/
5+
6+
//For GPU support
7+
#ifdef __CUDACC__
8+
#define DEVICE __device__
9+
#else
10+
#define DEVICE
11+
#endif
12+
13+
#include <float.h>
14+
#include <math.h>
15+
#include <stdio.h>
16+
#include <numpy/npy_math.h>
17+
18+
19+
#define MINLOG -170.0
20+
#define MAXLOG +170.0
21+
#define MAXGAM 171.624376956302725
22+
#define EPSILON 2.2204460492503131e-16
23+
24+
DEVICE static double pseries(double, double, double);
25+
DEVICE static double incbcf(double, double, double);
26+
DEVICE static double incbd(double, double, double);
27+
28+
static double big = 4.503599627370496e15;
29+
static double biginv = 2.22044604925031308085e-16;
30+
31+
32+
DEVICE double BetaInc(double a, double b, double x)
33+
{
34+
double xc, y, w, t;
35+
/* check function arguments */
36+
if (a <= 0.0) return NPY_NAN;
37+
if (b <= 0.0) return NPY_NAN;
38+
if (x < 0.0) return NPY_NAN;
39+
if (1.0 < x) return NPY_NAN;
40+
41+
/* some special cases */
42+
if (x == 0.0) return 0.0;
43+
if (x == 1.0) return 1.0;
44+
45+
if ( (b * x) <= 1.0 && x <= 0.95)
46+
{
47+
return pseries(a, b, x);
48+
}
49+
50+
xc = 1.0 - x;
51+
/* reverse a and b if x is greater than the mean */
52+
if (x > (a / (a + b)))
53+
{
54+
t = BetaInc(b, a, xc);
55+
if (t <= EPSILON) return 1.0 - EPSILON;
56+
return 1.0 - t;
57+
}
58+
59+
/* Choose expansion for better convergence. */
60+
y = x * (a+b-2.0) - (a-1.0);
61+
if( y < 0.0 )
62+
w = incbcf( a, b, x );
63+
else
64+
w = incbd( a, b, x ) / xc;
65+
66+
y = a * log(x);
67+
t = b * log(xc);
68+
if( (a+b) < MAXGAM && fabs(y) < MAXLOG && fabs(t) < MAXLOG )
69+
{
70+
t = pow(xc, b);
71+
t *= pow(x, a);
72+
t /= a;
73+
t *= w;
74+
t *= tgamma(a + b) / (tgamma(a) * tgamma(b));
75+
76+
return t;
77+
}
78+
79+
/* Resort to logarithms. */
80+
y += t + lgamma(a+b) - lgamma(a) - lgamma(b);
81+
y += log(w / a);
82+
if( y < MINLOG )
83+
t = 0.0;
84+
else
85+
t = exp(y);
86+
87+
return t;
88+
}
89+
90+
/* Continued fraction expansion #1
91+
* for incomplete beta integral
92+
*/
93+
94+
DEVICE static double incbcf(double a, double b, double x)
95+
{
96+
double xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
97+
double k1, k2, k3, k4, k5, k6, k7, k8;
98+
double r, t, ans, thresh;
99+
int n;
100+
101+
k1 = a;
102+
k2 = a + b;
103+
k3 = a;
104+
k4 = a + 1.0;
105+
k5 = 1.0;
106+
k6 = b - 1.0;
107+
k7 = k4;
108+
k8 = a + 2.0;
109+
110+
pkm2 = 0.0;
111+
qkm2 = 1.0;
112+
pkm1 = 1.0;
113+
qkm1 = 1.0;
114+
ans = 1.0;
115+
r = 1.0;
116+
n = 0;
117+
thresh = 3.0 * EPSILON;
118+
do
119+
{
120+
121+
xk = -( x * k1 * k2 ) / ( k3 * k4 );
122+
pk = pkm1 + pkm2 * xk;
123+
qk = qkm1 + qkm2 * xk;
124+
pkm2 = pkm1;
125+
pkm1 = pk;
126+
qkm2 = qkm1;
127+
qkm1 = qk;
128+
129+
xk = ( x * k5 * k6 ) / ( k7 * k8 );
130+
pk = pkm1 + pkm2 * xk;
131+
qk = qkm1 + qkm2 * xk;
132+
pkm2 = pkm1;
133+
pkm1 = pk;
134+
qkm2 = qkm1;
135+
qkm1 = qk;
136+
137+
if( qk != 0.0 )
138+
r = pk/qk;
139+
if( r != 0.0 )
140+
{
141+
t = fabs( (ans - r) / r );
142+
ans = r;
143+
}
144+
else
145+
t = 1.0;
146+
147+
if( t < thresh )
148+
break;
149+
150+
k1 += 1.0;
151+
k2 += 1.0;
152+
k3 += 2.0;
153+
k4 += 2.0;
154+
k5 += 1.0;
155+
k6 -= 1.0;
156+
k7 += 2.0;
157+
k8 += 2.0;
158+
159+
if( (fabs(qk) + fabs(pk)) > big )
160+
{
161+
pkm2 *= biginv;
162+
pkm1 *= biginv;
163+
qkm2 *= biginv;
164+
qkm1 *= biginv;
165+
}
166+
if( (fabs(qk) < biginv) || (fabs(pk) < biginv) )
167+
{
168+
pkm2 *= big;
169+
pkm1 *= big;
170+
qkm2 *= big;
171+
qkm1 *= big;
172+
}
173+
}
174+
while( ++n < 300 );
175+
176+
return ans;
177+
}
178+
179+
/* Continued fraction expansion #2
180+
* for incomplete beta integral
181+
*/
182+
183+
DEVICE static double incbd(double a, double b, double x)
184+
{
185+
double xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
186+
double k1, k2, k3, k4, k5, k6, k7, k8;
187+
double r, t, ans, z, thresh;
188+
int n;
189+
190+
k1 = a;
191+
k2 = b - 1.0;
192+
k3 = a;
193+
k4 = a + 1.0;
194+
k5 = 1.0;
195+
k6 = a + b;
196+
k7 = a + 1.0;;
197+
k8 = a + 2.0;
198+
199+
pkm2 = 0.0;
200+
qkm2 = 1.0;
201+
pkm1 = 1.0;
202+
qkm1 = 1.0;
203+
z = x / (1.0-x);
204+
ans = 1.0;
205+
r = 1.0;
206+
n = 0;
207+
thresh = 3.0 * EPSILON;
208+
do
209+
{
210+
211+
xk = -( z * k1 * k2 ) / ( k3 * k4 );
212+
pk = pkm1 + pkm2 * xk;
213+
qk = qkm1 + qkm2 * xk;
214+
pkm2 = pkm1;
215+
pkm1 = pk;
216+
qkm2 = qkm1;
217+
qkm1 = qk;
218+
219+
xk = ( z * k5 * k6 ) / ( k7 * k8 );
220+
pk = pkm1 + pkm2 * xk;
221+
qk = qkm1 + qkm2 * xk;
222+
pkm2 = pkm1;
223+
pkm1 = pk;
224+
qkm2 = qkm1;
225+
qkm1 = qk;
226+
227+
if( qk != 0 )
228+
r = pk/qk;
229+
if( r != 0 )
230+
{
231+
t = fabs( (ans - r) / r );
232+
ans = r;
233+
}
234+
else
235+
t = 1.0;
236+
237+
if( t < thresh )
238+
break;
239+
240+
k1 += 1.0;
241+
k2 -= 1.0;
242+
k3 += 2.0;
243+
k4 += 2.0;
244+
k5 += 1.0;
245+
k6 += 1.0;
246+
k7 += 2.0;
247+
k8 += 2.0;
248+
249+
if( (fabs(qk) + fabs(pk)) > big )
250+
{
251+
pkm2 *= biginv;
252+
pkm1 *= biginv;
253+
qkm2 *= biginv;
254+
qkm1 *= biginv;
255+
}
256+
if( (fabs(qk) < biginv) || (fabs(pk) < biginv) )
257+
{
258+
pkm2 *= big;
259+
pkm1 *= big;
260+
qkm2 *= big;
261+
qkm1 *= big;
262+
}
263+
}
264+
while( ++n < 300 );
265+
266+
return ans;
267+
}
268+
269+
270+
/* Power series for incomplete beta integral.
271+
Use when b*x is small and x not too close to 1. */
272+
273+
DEVICE static double pseries(double a, double b, double x)
274+
{
275+
double s, t, u, v, n, t1, z, ai;
276+
277+
ai = 1.0 / a;
278+
u = (1.0 - b) * x;
279+
v = u / (a + 1.0);
280+
t1 = v;
281+
t = u;
282+
n = 2.0;
283+
s = 0.0;
284+
z = EPSILON * ai;
285+
while( fabs(v) > z )
286+
{
287+
u = (n - b) * x / n;
288+
t *= u;
289+
v = t / (a + n);
290+
s += v;
291+
n += 1.0;
292+
}
293+
s += t1;
294+
s += ai;
295+
296+
u = a * log(x);
297+
if( (a+b) < MAXGAM && fabs(u) < MAXLOG )
298+
{
299+
t = tgamma(a + b) / (tgamma(a) * tgamma(b));
300+
s = s * t * pow(x,a);
301+
}
302+
else
303+
{
304+
t = lgamma(a + b) - lgamma(a) - lgamma(b) + u + log(s);
305+
if( t < MINLOG )
306+
s = 0.0;
307+
else
308+
s = exp(t);
309+
}
310+
return s;
311+
}

pytensor/scalar/math.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1495,8 +1495,25 @@ def grad(self, inp, grads):
14951495
),
14961496
]
14971497

1498-
def c_code(self, *args, **kwargs):
1499-
raise NotImplementedError()
1498+
def c_support_code(self, **kwargs):
1499+
with open(os.path.join(os.path.dirname(__file__), "c_code", "incbet.c")) as f:
1500+
raw = f.read()
1501+
return raw
1502+
1503+
def c_code(self, node, name, inp, out, sub):
1504+
(a, b, x) = inp
1505+
(z,) = out
1506+
if (
1507+
node.inputs[0].type in float_types
1508+
and node.inputs[1].type in float_types
1509+
and node.inputs[2].type in float_types
1510+
):
1511+
return f"""{z} = BetaInc({a}, {b}, {x});"""
1512+
1513+
raise NotImplementedError("type not supported", type)
1514+
1515+
def c_code_cache_version(self):
1516+
return (1,)
15001517

15011518

15021519
betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")

0 commit comments

Comments
 (0)