@@ -13,11 +13,7 @@ from libcpp cimport vector
1313
1414import ctypes
1515
16- # this might be an unnecessary assumption that NumPy does not exist...
17- try :
18- import numpy
19- except ImportError :
20- numpy = None
16+ import numpy
2117
2218from cuda.core._memory import Buffer
2319
@@ -43,7 +39,32 @@ ctypedef fused supported_type:
4339 cpp_double_complex
4440
4541
46- # TODO: cache ctypes/numpy type objects to avoid attribute access
42+ # cache ctypes/numpy type objects to avoid attribute access
43+ cdef object ctypes_bool = ctypes.c_bool
44+ cdef object ctypes_int8 = ctypes.c_int8
45+ cdef object ctypes_int16 = ctypes.c_int16
46+ cdef object ctypes_int32 = ctypes.c_int32
47+ cdef object ctypes_int64 = ctypes.c_int64
48+ cdef object ctypes_uint8 = ctypes.c_uint8
49+ cdef object ctypes_uint16 = ctypes.c_uint16
50+ cdef object ctypes_uint32 = ctypes.c_uint32
51+ cdef object ctypes_uint64 = ctypes.c_uint64
52+ cdef object ctypes_float = ctypes.c_float
53+ cdef object ctypes_double = ctypes.c_double
54+ cdef object numpy_bool = numpy.bool_
55+ cdef object numpy_int8 = numpy.int8
56+ cdef object numpy_int16 = numpy.int16
57+ cdef object numpy_int32 = numpy.int32
58+ cdef object numpy_int64 = numpy.int64
59+ cdef object numpy_uint8 = numpy.uint8
60+ cdef object numpy_uint16 = numpy.uint16
61+ cdef object numpy_uint32 = numpy.uint32
62+ cdef object numpy_uint64 = numpy.uint64
63+ cdef object numpy_float16 = numpy.float16
64+ cdef object numpy_float32 = numpy.float32
65+ cdef object numpy_float64 = numpy.float64
66+ cdef object numpy_complex64 = numpy.complex64
67+ cdef object numpy_complex128 = numpy.complex128
4768
4869
4970# limitation due to cython/cython#534
@@ -76,27 +97,27 @@ cdef inline int prepare_ctypes_arg(
7697 vector.vector[void * ]& data_addresses,
7798 arg,
7899 const size_t idx) except - 1 :
79- if isinstance (arg, ctypes.c_bool ):
100+ if isinstance (arg, ctypes_bool ):
80101 return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
81- elif isinstance (arg, ctypes.c_int8 ):
102+ elif isinstance (arg, ctypes_int8 ):
82103 return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
83- elif isinstance (arg, ctypes.c_int16 ):
104+ elif isinstance (arg, ctypes_int16 ):
84105 return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
85- elif isinstance (arg, ctypes.c_int32 ):
106+ elif isinstance (arg, ctypes_int32 ):
86107 return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
87- elif isinstance (arg, ctypes.c_int64 ):
108+ elif isinstance (arg, ctypes_int64 ):
88109 return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
89- elif isinstance (arg, ctypes.c_uint8 ):
110+ elif isinstance (arg, ctypes_uint8 ):
90111 return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
91- elif isinstance (arg, ctypes.c_uint16 ):
112+ elif isinstance (arg, ctypes_uint16 ):
92113 return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
93- elif isinstance (arg, ctypes.c_uint32 ):
114+ elif isinstance (arg, ctypes_uint32 ):
94115 return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
95- elif isinstance (arg, ctypes.c_uint64 ):
116+ elif isinstance (arg, ctypes_uint64 ):
96117 return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
97- elif isinstance (arg, ctypes.c_float ):
118+ elif isinstance (arg, ctypes_float ):
98119 return prepare_arg[float ](data, data_addresses, arg.value, idx)
99- elif isinstance (arg, ctypes.c_double ):
120+ elif isinstance (arg, ctypes_double ):
100121 return prepare_arg[double ](data, data_addresses, arg.value, idx)
101122 else :
102123 return 1
@@ -107,37 +128,34 @@ cdef inline int prepare_numpy_arg(
107128 vector.vector[void * ]& data_addresses,
108129 arg,
109130 const size_t idx) except - 1 :
110- if not numpy:
111- return 1
112-
113- if isinstance (arg, numpy.bool_):
131+ if isinstance (arg, numpy_bool):
114132 return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
115- elif isinstance (arg, numpy.int8 ):
133+ elif isinstance (arg, numpy_int8 ):
116134 return prepare_arg[int8_t](data, data_addresses, arg, idx)
117- elif isinstance (arg, numpy.int16 ):
135+ elif isinstance (arg, numpy_int16 ):
118136 return prepare_arg[int16_t](data, data_addresses, arg, idx)
119- elif isinstance (arg, numpy.int32 ):
137+ elif isinstance (arg, numpy_int32 ):
120138 return prepare_arg[int32_t](data, data_addresses, arg, idx)
121- elif isinstance (arg, numpy.int64 ):
139+ elif isinstance (arg, numpy_int64 ):
122140 return prepare_arg[int64_t](data, data_addresses, arg, idx)
123- elif isinstance (arg, numpy.uint8 ):
141+ elif isinstance (arg, numpy_uint8 ):
124142 return prepare_arg[uint8_t](data, data_addresses, arg, idx)
125- elif isinstance (arg, numpy.uint16 ):
143+ elif isinstance (arg, numpy_uint16 ):
126144 return prepare_arg[uint16_t](data, data_addresses, arg, idx)
127- elif isinstance (arg, numpy.uint32 ):
145+ elif isinstance (arg, numpy_uint32 ):
128146 return prepare_arg[uint32_t](data, data_addresses, arg, idx)
129- elif isinstance (arg, numpy.uint64 ):
147+ elif isinstance (arg, numpy_uint64 ):
130148 return prepare_arg[uint64_t](data, data_addresses, arg, idx)
131- elif isinstance (arg, numpy.float16 ):
149+ elif isinstance (arg, numpy_float16 ):
132150 # use int16 as a proxy
133151 return prepare_arg[int16_t](data, data_addresses, arg, idx)
134- elif isinstance (arg, numpy.float32 ):
152+ elif isinstance (arg, numpy_float32 ):
135153 return prepare_arg[float ](data, data_addresses, arg, idx)
136- elif isinstance (arg, numpy.float64 ):
154+ elif isinstance (arg, numpy_float64 ):
137155 return prepare_arg[double ](data, data_addresses, arg, idx)
138- elif isinstance (arg, numpy.complex64 ):
156+ elif isinstance (arg, numpy_complex64 ):
139157 return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
140- elif isinstance (arg, numpy.complex128 ):
158+ elif isinstance (arg, numpy_complex128 ):
141159 return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
142160 else :
143161 return 1
@@ -185,9 +203,9 @@ cdef class ParamHolder:
185203 continue
186204
187205 not_prepared = prepare_numpy_arg(self .data, self .data_addresses, arg, i)
188- if not_prepared ! = 0 :
206+ if not_prepared:
189207 not_prepared = prepare_ctypes_arg(self .data, self .data_addresses, arg, i)
190- if not_prepared ! = 0 :
208+ if not_prepared:
191209 # TODO: support ctypes/numpy struct
192210 raise TypeError
193211
0 commit comments