@@ -25,35 +25,100 @@ import numpy as np
2525
2626from libc.stdlib cimport malloc, free
2727
28- cimport cpython.pycapsule
29-
3028cnp.import_umath()
3129
30+
3231ctypedef struct function_info:
33- cnp.PyUFuncGenericFunction np_function
34- cnp.PyUFuncGenericFunction mkl_function
32+ cnp.PyUFuncGenericFunction original_function
33+ cnp.PyUFuncGenericFunction patch_function
3534 int * signature
3635
37- ctypedef struct functions_struct:
38- int count
39- function_info* functions
40-
41-
42- cdef const char * capsule_name = " functions_cache"
43-
44-
45- cdef void _capsule_destructor(object caps):
46- cdef functions_struct* fs
47-
48- if (caps is None ):
49- print (" Nothing to destroy" )
50- return
51- fs = < functions_struct * > cpython.pycapsule.PyCapsule_GetPointer(caps, capsule_name)
52- for i in range (fs[0 ].count):
53- free(fs[0 ].functions[i].signature)
54- free(fs[0 ].functions)
55- free(fs)
5636
37+ cdef class patch:
38+ cdef int functions_count
39+ cdef function_info* functions
40+ cdef bint _is_patched
41+
42+ functions_dict = dict ()
43+
44+ def __cinit__ (self ):
45+ cdef int pi, oi
46+
47+ self ._is_patched = False
48+
49+ umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
50+ self .functions_count = 0
51+ for umath in umaths:
52+ mkl_umath = getattr (mu, umath)
53+ self .functions_count = self .functions_count + mkl_umath.ntypes
54+
55+ self .functions = < function_info * > malloc(self .functions_count * sizeof(function_info))
56+
57+ func_number = 0
58+ for umath in umaths:
59+ patch_umath = getattr (mu, umath)
60+ c_patch_umath = < cnp.ufunc> patch_umath
61+ c_orig_umath = < cnp.ufunc> getattr (nu, umath)
62+ nargs = c_patch_umath.nargs
63+ for pi in range (c_patch_umath.ntypes):
64+ oi = 0
65+ while oi < c_orig_umath.ntypes:
66+ found = True
67+ for i in range (c_patch_umath.nargs):
68+ if c_patch_umath.types[pi * nargs + i] != c_orig_umath.types[oi * nargs + i]:
69+ found = False
70+ break
71+ if found == True :
72+ break
73+ oi = oi + 1
74+ if oi < c_orig_umath.ntypes:
75+ self .functions[func_number].original_function = c_orig_umath.functions[oi]
76+ self .functions[func_number].patch_function = c_patch_umath.functions[pi]
77+ self .functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
78+ for i in range (nargs):
79+ self .functions[func_number].signature[i] = c_patch_umath.types[pi * nargs + i]
80+ self .functions_dict[(umath, patch_umath.types[pi])] = func_number
81+ func_number = func_number + 1
82+ else :
83+ raise RuntimeError (" Unable to find original function for: " + umath + " " + patch_umath.types[pi])
84+
85+ def __dealloc__ (self ):
86+ for i in range (self .functions_count):
87+ free(self .functions[i].signature)
88+ free(self .functions)
89+
90+ def do_patch (self ):
91+ cdef int res
92+ cdef cnp.PyUFuncGenericFunction temp
93+ cdef cnp.PyUFuncGenericFunction function
94+ cdef int * signature
95+
96+ for func in self .functions_dict:
97+ np_umath = getattr (nu, func[0 ])
98+ index = self .functions_dict[func]
99+ function = self .functions[index].patch_function
100+ signature = self .functions[index].signature
101+ res = cnp.PyUFunc_ReplaceLoopBySignature(< cnp.ufunc> np_umath, function, signature, & temp)
102+
103+ self ._is_patched = True
104+
105+ def do_unpatch (self ):
106+ cdef int res
107+ cdef cnp.PyUFuncGenericFunction temp
108+ cdef cnp.PyUFuncGenericFunction function
109+ cdef int * signature
110+
111+ for func in self .functions_dict:
112+ np_umath = getattr (nu, func[0 ])
113+ index = self .functions_dict[func]
114+ function = self .functions[index].original_function
115+ signature = self .functions[index].signature
116+ res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
117+
118+ self ._is_patched = False
119+
120+ def is_patched (self ):
121+ return self ._is_patched
57122
58123from threading import local as threading_local
59124_tls = threading_local()
@@ -64,103 +129,43 @@ def _is_tls_initialized():
64129
65130
66131def _initialize_tls ():
67- cdef functions_struct* fs
68- cdef int funcs_count
69-
70- _tls.functions_dict = {}
71-
72- umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
73- funcs_count = 0
74- for umath in umaths:
75- mkl_umath = getattr (mu, umath)
76- funcs_count = funcs_count + mkl_umath.ntypes
77-
78- fs = < functions_struct * > malloc(sizeof(functions_struct))
79- fs[0 ].count = funcs_count
80- fs[0 ].functions = < function_info * > malloc(funcs_count * sizeof(function_info))
81-
82- func_number = 0
83- for umath in umaths:
84- mkl_umath = getattr (mu, umath)
85- np_umath = getattr (nu, umath)
86- c_mkl_umath = < cnp.ufunc> mkl_umath
87- c_np_umath = < cnp.ufunc> np_umath
88- for type in mkl_umath.types:
89- np_index = np_umath.types.index(type )
90- fs[0 ].functions[func_number].np_function = c_np_umath.functions[np_index]
91- mkl_index = mkl_umath.types.index(type )
92- fs[0 ].functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
93-
94- nargs = c_mkl_umath.nargs
95- fs[0 ].functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
96- for i in range (nargs):
97- fs[0 ].functions[func_number].signature[i] = c_mkl_umath.types[mkl_index* nargs + i]
98-
99- _tls.functions_dict[(umath, type )] = func_number
100- func_number = func_number + 1
101-
102- _tls.functions_capsule = cpython.pycapsule.PyCapsule_New(< void * > fs, capsule_name, & _capsule_destructor)
103-
132+ _tls.patch = patch()
104133 _tls.initialized = True
105134
106135
107- def _get_func_dict ():
136+ def use_in_numpy ():
137+ '''
138+ Enables using of mkl_umath in Numpy.
139+ '''
108140 if not _is_tls_initialized():
109141 _initialize_tls()
110- return _tls.functions_dict
142+ _tls.patch.do_patch()
111143
112144
113- cdef function_info * _get_functions ():
114- cdef function_info * functions
115- cdef functions_struct * fs
116-
145+ def restore ():
146+ '''
147+ Disables using of mkl_umath in Numpy.
148+ '''
117149 if not _is_tls_initialized():
118150 _initialize_tls()
151+ _tls.patch.do_unpatch()
119152
120- capsule = _tls.functions_capsule
121- if (not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name)):
122- raise ValueError (" Internal Error: invalid capsule stored in TLS" )
123- fs = < functions_struct * > cpython.pycapsule.PyCapsule_GetPointer(capsule, capsule_name)
124- return fs[0 ].functions
125-
126-
127- cdef void c_do_patch():
128- cdef int res
129- cdef cnp.PyUFuncGenericFunction temp
130- cdef cnp.PyUFuncGenericFunction function
131- cdef int * signature
132-
133- funcs_dict = _get_func_dict()
134- functions = _get_functions()
135-
136- for func in funcs_dict:
137- np_umath = getattr (nu, func[0 ])
138- index = funcs_dict[func]
139- function = functions[index].mkl_function
140- signature = functions[index].signature
141- res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
142-
143-
144- cdef void c_do_unpatch():
145- cdef int res
146- cdef cnp.PyUFuncGenericFunction temp
147- cdef cnp.PyUFuncGenericFunction function
148- cdef int * signature
149-
150- funcs_dict = _get_func_dict()
151- functions = _get_functions()
152-
153- for func in funcs_dict:
154- np_umath = getattr (nu, func[0 ])
155- index = funcs_dict[func]
156- function = functions[index].np_function
157- signature = functions[index].signature
158- res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
159153
154+ def is_patched ():
155+ '''
156+ Returns whether Numpy has been patched with mkl_umath.
157+ '''
158+ if not _is_tls_initialized():
159+ _initialize_tls()
160+ _tls.patch.is_patched()
160161
161- def do_patch ():
162- c_do_patch()
162+ from contextlib import ContextDecorator
163163
164+ class mkl_umath (ContextDecorator ):
165+ def __enter__ (self ):
166+ use_in_numpy()
167+ return self
164168
165- def do_unpatch ():
166- c_do_unpatch()
169+ def __exit__ (self , *exc ):
170+ restore()
171+ return False
0 commit comments