|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import os |
| 8 | +import warnings |
8 | 9 | import weakref |
9 | 10 | from dataclasses import dataclass |
10 | 11 | from typing import TYPE_CHECKING, Optional, Tuple, Union |
@@ -87,9 +88,19 @@ def _init(obj=None, *, options: Optional[StreamOptions] = None): |
87 | 88 | if obj is not None and options is not None: |
88 | 89 | raise ValueError("obj and options cannot be both specified") |
89 | 90 | if obj is not None: |
90 | | - if not hasattr(obj, "__cuda_stream__"): |
91 | | - raise ValueError |
92 | | - info = obj.__cuda_stream__ |
| 91 | + try: |
| 92 | + info = obj.__cuda_stream__() |
| 93 | + except AttributeError as e: |
| 94 | + raise TypeError(f"{type(obj)} object does not have a '__cuda_stream__' method") from e |
| 95 | + except TypeError: |
| 96 | + info = obj.__cuda_stream__ |
| 97 | + warnings.simplefilter("once", DeprecationWarning) |
| 98 | + warnings.warn( |
| 99 | + "Implementing __cuda_stream__ as an attribute is deprecated; it must be implemented as a method", |
| 100 | + stacklevel=3, |
| 101 | + category=DeprecationWarning, |
| 102 | + ) |
| 103 | + |
93 | 104 | assert info[0] == 0 |
94 | 105 | self._mnff.handle = cuda.CUstream(info[1]) |
95 | 106 | # TODO: check if obj is created under the current context/device |
@@ -132,7 +143,6 @@ def close(self): |
132 | 143 | """ |
133 | 144 | self._mnff.close() |
134 | 145 |
|
135 | | - @property |
136 | 146 | def __cuda_stream__(self) -> Tuple[int, int]: |
137 | 147 | """Return an instance of a __cuda_stream__ protocol.""" |
138 | 148 | return (0, self.handle) |
@@ -279,7 +289,6 @@ def from_handle(handle: int) -> Stream: |
279 | 289 | """ |
280 | 290 |
|
281 | 291 | class _stream_holder: |
282 | | - @property |
283 | 292 | def __cuda_stream__(self): |
284 | 293 | return (0, handle) |
285 | 294 |
|
|
0 commit comments