|
14 | 14 |
|
15 | 15 | # pylint: disable=too-many-lines |
16 | 16 |
|
| 17 | +import sys |
17 | 18 | import typing as tp |
18 | 19 | from functools import partial |
19 | 20 |
|
20 | 21 | import numpy as np |
21 | 22 | import pandas as pd |
22 | 23 | import pytest |
23 | | -import torch |
24 | | -from pytorch_lightning import Trainer, seed_everything |
| 24 | + |
| 25 | +try: |
| 26 | + import torch |
| 27 | + from pytorch_lightning import Trainer, seed_everything |
| 28 | +except ImportError: |
| 29 | + Trainer: tp.Any = object |
25 | 30 |
|
26 | 31 | from rectools import ExternalIds |
27 | 32 | from rectools.columns import Columns |
28 | 33 | from rectools.dataset import Dataset, IdMap, Interactions |
29 | | -from rectools.models import SASRecModel |
30 | | -from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor |
31 | | -from rectools.models.nn.transformers.base import ( |
32 | | - LearnableInversePositionalEncoding, |
33 | | - TrainerCallable, |
34 | | - TransformerLightningModule, |
35 | | - TransformerTorchBackbone, |
36 | | -) |
37 | | -from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers |
| 34 | + |
| 35 | +try: |
| 36 | + from rectools.models import SASRecModel |
| 37 | + from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor |
| 38 | + from rectools.models.nn.transformers.base import ( |
| 39 | + LearnableInversePositionalEncoding, |
| 40 | + TrainerCallable, |
| 41 | + TransformerLightningModule, |
| 42 | + TransformerTorchBackbone, |
| 43 | + ) |
| 44 | + from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers |
| 45 | +except ImportError: |
| 46 | + TrainerCallable: tp.Any = object |
38 | 47 | from tests.models.data import DATASET |
39 | 48 | from tests.models.utils import ( |
40 | 49 | assert_default_config_and_default_model_params_are_the_same, |
41 | 50 | assert_second_fit_refits_model, |
42 | 51 | ) |
43 | 52 | from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal |
44 | 53 |
|
45 | | -from .utils import custom_trainer, leave_one_out_mask |
46 | | - |
| 54 | +try: |
| 55 | + from .utils import custom_trainer, leave_one_out_mask |
| 56 | +except NameError: |
| 57 | + pass |
47 | 58 |
|
| 59 | +@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13") |
48 | 60 | class TestSASRecModel: |
49 | 61 | def setup_method(self) -> None: |
50 | 62 | self._seed_everything() |
@@ -698,6 +710,7 @@ def test_torch_model(self, dataset: Dataset) -> None: |
698 | 710 | assert isinstance(model.torch_model, TransformerTorchBackbone) |
699 | 711 |
|
700 | 712 |
|
| 713 | +@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13") |
701 | 714 | class TestSASRecDataPreparator: |
702 | 715 |
|
703 | 716 | def setup_method(self) -> None: |
@@ -864,6 +877,7 @@ def test_get_dataloader_recommend( |
864 | 877 | assert torch.equal(value, recommend_batch[key]) |
865 | 878 |
|
866 | 879 |
|
| 880 | +@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13") |
867 | 881 | class TestSASRecModelConfiguration: |
868 | 882 | def setup_method(self) -> None: |
869 | 883 | self._seed_everything() |
|
0 commit comments