|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import sys |
16 | | -import types |
17 | 16 | import typing as tp |
18 | 17 | from functools import partial |
19 | 18 |
|
20 | 19 | import numpy as np |
21 | 20 | import pandas as pd |
22 | 21 | import pytest |
| 22 | +import torch |
23 | 23 |
|
24 | 24 | try: |
25 | | - import torch |
26 | 25 | from pytorch_lightning import Trainer, seed_everything |
27 | 26 | except ImportError: |
28 | | - torch = types.ModuleType("torch") |
29 | | - torch.Tensor = object # type: ignore |
30 | | - torch.float = object # type: ignore |
31 | 27 | Trainer = object # type: ignore |
32 | 28 |
|
33 | | - def tensor(*args: tp.Any, **kwargs: tp.Any) -> tp.Any: |
34 | | - return object() |
35 | | - |
36 | | - torch.tensor = tensor |
37 | | - |
38 | 29 | from rectools import ExternalIds |
39 | 30 | from rectools.columns import Columns |
40 | 31 | from rectools.dataset import Dataset |
@@ -64,7 +55,7 @@ def tensor(*args: tp.Any, **kwargs: tp.Any) -> tp.Any: |
64 | 55 | pass |
65 | 56 |
|
66 | 57 |
|
67 | | -@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13") |
| 58 | +@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13") |
68 | 59 | class TestBERT4RecModel: |
69 | 60 | def setup_method(self) -> None: |
70 | 61 | self._seed_everything() |
@@ -642,7 +633,7 @@ def _collate_fn_train( |
642 | 633 | ) |
643 | 634 |
|
644 | 635 |
|
645 | | -@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13") |
| 636 | +@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13") |
646 | 637 | class TestBERT4RecDataPreparator: |
647 | 638 |
|
648 | 639 | def setup_method(self) -> None: |
@@ -822,7 +813,7 @@ def test_get_dataloader_val( |
822 | 813 | assert torch.equal(value, val_batch[key]) |
823 | 814 |
|
824 | 815 |
|
825 | | -@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13") |
| 816 | +@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13") |
826 | 817 | class TestBERT4RecModelConfiguration: |
827 | 818 | def setup_method(self) -> None: |
828 | 819 | self._seed_everything() |
|
0 commit comments