|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
17 |
| -import numbers |
18 |
| - |
19 |
| -import numpy as np |
20 |
| - |
21 | 17 | import dpctl
|
22 |
| -import dpctl.memory as dpm |
23 | 18 | import dpctl.tensor as dpt
|
24 | 19 | import dpctl.tensor._tensor_impl as ti
|
25 | 20 | from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
|
26 |
| -from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer |
27 | 21 | from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
|
28 | 22 |
|
29 | 23 | from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
|
| 24 | +from ._scalar_utils import ( |
| 25 | + _get_dtype, |
| 26 | + _get_queue_usm_type, |
| 27 | + _get_shape, |
| 28 | + _validate_dtype, |
| 29 | +) |
30 | 30 | from ._type_utils import (
|
31 |
| - WeakBooleanType, |
32 |
| - WeakComplexType, |
33 |
| - WeakFloatingType, |
34 |
| - WeakIntegralType, |
35 | 31 | _acceptance_fn_default_binary,
|
36 | 32 | _acceptance_fn_default_unary,
|
37 | 33 | _all_data_types,
|
38 | 34 | _find_buf_dtype,
|
39 | 35 | _find_buf_dtype2,
|
40 | 36 | _find_buf_dtype_in_place_op,
|
41 | 37 | _resolve_weak_types,
|
42 |
| - _to_device_supported_dtype, |
43 | 38 | )
|
44 | 39 |
|
45 | 40 |
|
@@ -289,78 +284,6 @@ def __call__(self, x, /, *, out=None, order="K"):
|
289 | 284 | return out
|
290 | 285 |
|
291 | 286 |
|
292 |
| -def _get_queue_usm_type(o): |
293 |
| - """Return SYCL device where object `o` allocated memory, or None.""" |
294 |
| - if isinstance(o, dpt.usm_ndarray): |
295 |
| - return o.sycl_queue, o.usm_type |
296 |
| - elif hasattr(o, "__sycl_usm_array_interface__"): |
297 |
| - try: |
298 |
| - m = dpm.as_usm_memory(o) |
299 |
| - return m.sycl_queue, m.get_usm_type() |
300 |
| - except Exception: |
301 |
| - return None, None |
302 |
| - return None, None |
303 |
| - |
304 |
| - |
305 |
| -def _get_dtype(o, dev): |
306 |
| - if isinstance(o, dpt.usm_ndarray): |
307 |
| - return o.dtype |
308 |
| - if hasattr(o, "__sycl_usm_array_interface__"): |
309 |
| - return dpt.asarray(o).dtype |
310 |
| - if _is_buffer(o): |
311 |
| - host_dt = np.array(o).dtype |
312 |
| - dev_dt = _to_device_supported_dtype(host_dt, dev) |
313 |
| - return dev_dt |
314 |
| - if hasattr(o, "dtype"): |
315 |
| - dev_dt = _to_device_supported_dtype(o.dtype, dev) |
316 |
| - return dev_dt |
317 |
| - if isinstance(o, bool): |
318 |
| - return WeakBooleanType(o) |
319 |
| - if isinstance(o, int): |
320 |
| - return WeakIntegralType(o) |
321 |
| - if isinstance(o, float): |
322 |
| - return WeakFloatingType(o) |
323 |
| - if isinstance(o, complex): |
324 |
| - return WeakComplexType(o) |
325 |
| - return np.object_ |
326 |
| - |
327 |
| - |
328 |
| -def _validate_dtype(dt) -> bool: |
329 |
| - return isinstance( |
330 |
| - dt, |
331 |
| - (WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType), |
332 |
| - ) or ( |
333 |
| - isinstance(dt, dpt.dtype) |
334 |
| - and dt |
335 |
| - in [ |
336 |
| - dpt.bool, |
337 |
| - dpt.int8, |
338 |
| - dpt.uint8, |
339 |
| - dpt.int16, |
340 |
| - dpt.uint16, |
341 |
| - dpt.int32, |
342 |
| - dpt.uint32, |
343 |
| - dpt.int64, |
344 |
| - dpt.uint64, |
345 |
| - dpt.float16, |
346 |
| - dpt.float32, |
347 |
| - dpt.float64, |
348 |
| - dpt.complex64, |
349 |
| - dpt.complex128, |
350 |
| - ] |
351 |
| - ) |
352 |
| - |
353 |
| - |
354 |
| -def _get_shape(o): |
355 |
| - if isinstance(o, dpt.usm_ndarray): |
356 |
| - return o.shape |
357 |
| - if _is_buffer(o): |
358 |
| - return memoryview(o).shape |
359 |
| - if isinstance(o, numbers.Number): |
360 |
| - return tuple() |
361 |
| - return getattr(o, "shape", tuple()) |
362 |
| - |
363 |
| - |
364 | 287 | class BinaryElementwiseFunc:
|
365 | 288 | """
|
366 | 289 | Class that implements binary element-wise functions.
|
|
0 commit comments