@@ -25,29 +25,59 @@ import numpy as np
2525
2626from libc.stdlib cimport malloc, free
2727
28- cnp.import_umath()
28+ cimport cpython.pycapsule
2929
30- funcs_dict = {}
30+ cnp.import_umath()
3131
# One replaceable ufunc inner loop: the loop taken from the NumPy ufunc, the
# loop taken from the matching `mu` ufunc, and the type signature selecting it.
ctypedef struct function_info:
    cnp.PyUFuncGenericFunction np_function   # loop from the NumPy ufunc
    cnp.PyUFuncGenericFunction mkl_function  # loop from the `mu` ufunc
    int* signature                           # malloc'ed array, one int per ufunc argument
3636
37- cdef function_info* functions
# Heap-allocated table of function_info entries together with its length.
ctypedef struct functions_struct:
    int count                 # number of entries in `functions`
    function_info* functions  # malloc'ed array holding `count` entries
40+
41+
# Name given to the PyCapsule that wraps a functions_struct; the same name is
# used when the capsule is created, validated and queried.
cdef const char* capsule_name = " functions_cache"


cdef void _capsule_destructor(object caps):
    # Capsule destructor: release everything reachable from the wrapped
    # functions_struct -- each entry's signature array, then the entry array,
    # then the struct itself.
    cdef functions_struct* table
    cdef int idx

    if caps is None:
        print(" Nothing to destroy")
        return

    table = <functions_struct*>cpython.pycapsule.PyCapsule_GetPointer(
        caps, capsule_name
    )
    for idx in range(table.count):
        free(table.functions[idx].signature)
    free(table.functions)
    free(table)
56+
57+
from threading import local as threading_local

# Per-thread storage; each thread gets its own attributes on this object.
_tls = threading_local()


def _is_tls_initialized():
    """Return True if the calling thread's `_tls` is flagged as initialized.

    The flag is the `initialized` attribute of `_tls`.  A thread that has
    never set the attribute (or has set it to anything other than True) is
    reported as not initialized.

    The attribute is looked up by exactly the name under which it is stored
    (`_tls.initialized`); looking it up under a name with stray whitespace
    would never find it and would make every call report False.
    """
    return getattr(_tls, 'initialized', None) is True
3864
39- def fill_functions ():
40- global functions
65+
66+ def _initialize_tls ():
67+ cdef functions_struct* fs
68+ cdef int funcs_count
69+
70+ _tls.functions_dict = {}
4171
4272 umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
4373 funcs_count = 0
4474 for umath in umaths:
4575 mkl_umath = getattr (mu, umath)
46- types = mkl_umath.types
47- for type in types:
48- funcs_count = funcs_count + 1
76+ funcs_count = funcs_count + mkl_umath.ntypes
4977
50- functions = < function_info * > malloc(funcs_count * sizeof(function_info))
78+ fs = < functions_struct * > malloc(sizeof(functions_struct))
79+ fs[0 ].count = funcs_count
80+ fs[0 ].functions = < function_info * > malloc(funcs_count * sizeof(function_info))
5181
5282 func_number = 0
5383 for umath in umaths:
@@ -57,28 +87,51 @@ def fill_functions():
5787 c_np_umath = < cnp.ufunc> np_umath
5888 for type in mkl_umath.types:
5989 np_index = np_umath.types.index(type )
60- functions[func_number].np_function = c_np_umath.functions[np_index]
90+ fs[ 0 ]. functions[func_number].np_function = c_np_umath.functions[np_index]
6191 mkl_index = mkl_umath.types.index(type )
62- functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
92+ fs[ 0 ]. functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
6393
6494 nargs = c_mkl_umath.nargs
65- functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
95+ fs[ 0 ]. functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
6696 for i in range (nargs):
67- functions[func_number].signature[i] = c_mkl_umath.types[mkl_index* nargs + i]
97+ fs[ 0 ]. functions[func_number].signature[i] = c_mkl_umath.types[mkl_index* nargs + i]
6898
69- funcs_dict [(umath, type )] = func_number
99+ _tls.functions_dict [(umath, type )] = func_number
70100 func_number = func_number + 1
71101
102+ _tls.functions_capsule = cpython.pycapsule.PyCapsule_New(< void * > fs, capsule_name, & _capsule_destructor)
103+
104+ _tls.initialized = True
105+
106+
def _get_func_dict():
    """Return the calling thread's `functions_dict`.

    If this thread's TLS has not been initialized yet, it is initialized
    first, so the dict always exists by the time it is returned.
    """
    needs_setup = not _is_tls_initialized()
    if needs_setup:
        _initialize_tls()
    return _tls.functions_dict
72111
73- fill_functions()
74112
75- cdef c_do_patch():
cdef function_info* _get_functions() except? NULL:
    # Return the calling thread's array of function_info entries, initializing
    # the thread-local state first if that has not happened yet.
    #
    # Raises ValueError when the object stored in TLS is not a valid capsule
    # carrying `capsule_name`.
    #
    # The `except? NULL` clause is what allows that ValueError (and any error
    # raised during initialization or by PyCapsule_GetPointer) to reach the
    # caller; without an exception clause a cdef function returning a C
    # pointer cannot report Python exceptions.  The `?` form keeps a NULL
    # array pointer a legal return value when no exception is pending.
    cdef functions_struct* fs

    if not _is_tls_initialized():
        _initialize_tls()

    capsule = _tls.functions_capsule
    if not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name):
        raise ValueError(" Internal Error: invalid capsule stored in TLS")

    fs = <functions_struct*>cpython.pycapsule.PyCapsule_GetPointer(
        capsule, capsule_name
    )
    return fs.functions
125+
126+
127+ cdef void c_do_patch():
76128 cdef int res
77129 cdef cnp.PyUFuncGenericFunction temp
78130 cdef cnp.PyUFuncGenericFunction function
79131 cdef int * signature
80132
81- global functions
133+ funcs_dict = _get_func_dict()
134+ functions = _get_functions()
82135
83136 for func in funcs_dict:
84137 np_umath = getattr (nu, func[0 ])
@@ -88,13 +141,14 @@ cdef c_do_patch():
88141 res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
89142
90143
91- cdef c_do_unpatch():
144+ cdef void c_do_unpatch():
92145 cdef int res
93146 cdef cnp.PyUFuncGenericFunction temp
94147 cdef cnp.PyUFuncGenericFunction function
95148 cdef int * signature
96149
97- global functions
150+ funcs_dict = _get_func_dict()
151+ functions = _get_functions()
98152
99153 for func in funcs_dict:
100154 np_umath = getattr (nu, func[0 ])
@@ -107,5 +161,6 @@ cdef c_do_unpatch():
def do_patch():
    """Python-callable entry point that runs the C-level `c_do_patch`.

    `c_do_patch` walks the per-thread function dict and calls
    PyUFunc_ReplaceLoopBySignature on the corresponding NumPy ufuncs.
    """
    c_do_patch()
109163
164+
def do_unpatch():
    """Python-callable entry point that runs the C-level `c_do_unpatch`.

    NOTE(review): presumably this restores the loops replaced by `do_patch`;
    confirm against the full body of `c_do_unpatch`.
    """
    c_do_unpatch()
0 commit comments