2929import numpy as np
3030from numba .extending import get_cython_function_address
3131
32+ # TODO: these are reasonable defaults, but could
33+ # be made settable via a control dict
3234_HYP2F1_TOL = np .sqrt (np .finfo (np .float64 ).eps )
33- _HYP2F1_CHECK = np .sqrt (_HYP2F1_TOL )
3435_HYP2F1_MAXTERM = int (1e6 )
3536
3637_PTR = ctypes .POINTER
@@ -115,7 +116,7 @@ def _is_valid_2f1(f1, f2, a, b, c, z):
115116 See Eq. 6 in https://doi.org/10.1016/j.cpc.2007.11.007
116117 """
117118 if z == 0.0 :
118- return np .abs (f1 - a * b / c ) < _HYP2F1_CHECK
119+ return np .abs (f1 - a * b / c ) < _HYP2F1_TOL
119120 u = c - (a + b + 1 ) * z
120121 v = a * b
121122 w = z * (1 - z )
@@ -124,7 +125,7 @@ def _is_valid_2f1(f1, f2, a, b, c, z):
124125 numer = np .abs (u * f1 - v )
125126 else :
126127 numer = np .abs (f2 + u / w * f1 - v / w )
127- return numer / denom < _HYP2F1_CHECK
128+ return numer / denom < _HYP2F1_TOL
128129
129130
130131@numba .njit ("UniTuple(float64, 7)(float64, float64, float64, float64)" )
@@ -255,7 +256,7 @@ def _hyp2f1_recurrence(a, b, c, z):
255256
256257
257258@numba .njit (
258- "UniTuple(float64, 6 )(float64, float64, float64, float64, float64, float64)"
259+ "UniTuple(float64, 7 )(float64, float64, float64, float64, float64, float64)"
259260)
260261def _hyp2f1_dlmf1583_first (a_i , b_i , a_j , b_j , y , mu ):
261262 """
@@ -287,21 +288,26 @@ def _hyp2f1_dlmf1583_first(a_i, b_i, a_j, b_j, y, mu):
287288 )
288289
289290 # 2F1(a, -y; c; z) via backwards recurrence
290- val , sign , da , _ , dc , dz , _ = _hyp2f1_recurrence (a , y , c , z )
291+ val , sign , da , _ , dc , dz , d2z = _hyp2f1_recurrence (a , y , c , z )
291292
292293 # map gradient to parameters
293294 da_i = dc - _digamma (a_i + a_j ) + _digamma (a_i )
294295 da_j = da + dc - np .log (s ) + _digamma (a_j + y + 1 ) - _digamma (a_i + a_j )
295296 db_i = dz / (b_j - mu ) + a_j / (mu + b_i )
296297 db_j = dz * (1 - z ) / (b_j - mu ) - a_j / s / (mu + b_i )
297298
299+ # needed to verify result
300+ d2b_j = (1 - z ) / (b_j - mu ) ** 2 * (d2z * (1 - z ) - 2 * dz * (1 + a_j )) + (
301+ 1 + a_j
302+ ) * a_j / (b_j - mu ) ** 2
303+
298304 val += scale
299305
300- return val , sign , da_i , db_i , da_j , db_j
306+ return val , sign , da_i , db_i , da_j , db_j , d2b_j
301307
302308
303309@numba .njit (
304- "UniTuple(float64, 6 )(float64, float64, float64, float64, float64, float64)"
310+ "UniTuple(float64, 7 )(float64, float64, float64, float64, float64, float64)"
305311)
306312def _hyp2f1_dlmf1583_second (a_i , b_i , a_j , b_j , y , mu ):
307313 """
@@ -320,18 +326,24 @@ def _hyp2f1_dlmf1583_second(a_i, b_i, a_j, b_j, y, mu):
320326 )
321327
322328 # 2F1(a, y+1; c; z) via series expansion
323- val , sign , da , _ , dc , dz , _ = _hyp2f1_taylor_series (a , y + 1 , c , z )
329+ val , sign , da , _ , dc , dz , d2z = _hyp2f1_taylor_series (a , y + 1 , c , z )
324330
325331 # map gradient to parameters
326332 da_i = da + np .log (z ) + dc + _digamma (a_i ) - _digamma (a_i + y + 1 )
327333 da_j = da + np .log (z ) + _digamma (a_j + y + 1 ) - _digamma (a_j )
328334 db_i = (1 - z ) * (dz + a / z ) / (b_i + b_j )
329335 db_j = - z * (dz + a / z ) / (b_i + b_j )
330336
337+ # needed to verify result
338+ d2b_j = (
339+ z / (b_i + b_j ) ** 2 * (d2z * z + 2 * dz * (1 + a ))
340+ + a * (1 + a ) / (b_i + b_j ) ** 2
341+ )
342+
331343 sign *= (- 1 ) ** (y + 1 )
332344 val += scale
333345
334- return val , sign , da_i , db_i , da_j , db_j
346+ return val , sign , da_i , db_i , da_j , db_j , d2b_j
335347
336348
337349@numba .njit (
@@ -345,18 +357,14 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
345357 assert 0 <= mu <= b_j
346358 assert y >= 0 and y % 1.0 == 0.0
347359
348- f_1 , s_1 , da_i_1 , db_i_1 , da_j_1 , db_j_1 = _hyp2f1_dlmf1583_first (
360+ f_1 , s_1 , da_i_1 , db_i_1 , da_j_1 , db_j_1 , d2b_j_1 = _hyp2f1_dlmf1583_first (
349361 a_i , b_i , a_j , b_j , y , mu
350362 )
351363
352- f_2 , s_2 , da_i_2 , db_i_2 , da_j_2 , db_j_2 = _hyp2f1_dlmf1583_second (
364+ f_2 , s_2 , da_i_2 , db_i_2 , da_j_2 , db_j_2 , d2b_j_2 = _hyp2f1_dlmf1583_second (
353365 a_i , b_i , a_j , b_j , y , mu
354366 )
355367
356- if np .abs (f_1 - f_2 ) < _HYP2F1_TOL :
357- # TODO: detect a priori if this will occur
358- raise Invalid2F1 ("Singular hypergeometric function" )
359-
360368 f_0 = max (f_1 , f_2 )
361369 f_1 = np .exp (f_1 - f_0 ) * s_1
362370 f_2 = np .exp (f_2 - f_0 ) * s_2
@@ -366,10 +374,22 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
366374 db_i = (db_i_1 * f_1 + db_i_2 * f_2 ) / f
367375 da_j = (da_j_1 * f_1 + da_j_2 * f_2 ) / f
368376 db_j = (db_j_1 * f_1 + db_j_2 * f_2 ) / f
377+ d2b_j = (d2b_j_1 * f_1 + d2b_j_2 * f_2 ) / f
369378
370379 sign = np .sign (f )
371380 val = np .log (np .abs (f )) + f_0
372381
382+ # use first/second derivatives to check that result is non-singular
383+ dz = - db_j * (mu + b_i )
384+ d2z = d2b_j * (mu + b_i ) ** 2
385+ if (
386+ not _is_valid_2f1 (
387+ dz , d2z , a_j , a_i + a_j + y , a_j + y + 1 , (mu - b_j ) / (mu + b_i )
388+ )
389+ or sign <= 0
390+ ):
391+ raise Invalid2F1 ("Hypergeometric series is singular" )
392+
373393 return val , sign , da_i , db_i , da_j , db_j
374394
375395
0 commit comments