|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import sys |
15 | 16 | import typing as tp |
16 | 17 | from functools import partial |
17 | 18 |
|
18 | 19 | import numpy as np |
19 | 20 | import pandas as pd |
20 | 21 | import pytest |
21 | | -import torch |
22 | | -from pytorch_lightning import Trainer, seed_everything |
| 22 | + |
| 23 | +try: |
| 24 | + import torch |
| 25 | + from pytorch_lightning import Trainer, seed_everything |
| 26 | +except ImportError: |
| 27 | + pass |
23 | 28 |
|
24 | 29 | from rectools import ExternalIds |
25 | 30 | from rectools.columns import Columns |
26 | 31 | from rectools.dataset import Dataset |
27 | | -from rectools.models import BERT4RecModel |
28 | | -from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor |
29 | | -from rectools.models.nn.transformers.base import ( |
30 | | - LearnableInversePositionalEncoding, |
31 | | - PreLNTransformerLayers, |
32 | | - TrainerCallable, |
33 | | - TransformerLightningModule, |
34 | | -) |
35 | | -from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable |
| 32 | + |
| 33 | +try: |
| 34 | + from rectools.models import BERT4RecModel |
| 35 | + from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor |
| 36 | + from rectools.models.nn.transformers.base import ( |
| 37 | + LearnableInversePositionalEncoding, |
| 38 | + PreLNTransformerLayers, |
| 39 | + TrainerCallable, |
| 40 | + TransformerLightningModule, |
| 41 | + ) |
| 42 | + from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable |
| 43 | +except ImportError: |
| 44 | + pass |
36 | 45 | from tests.models.data import DATASET |
37 | 46 | from tests.models.utils import ( |
38 | 47 | assert_default_config_and_default_model_params_are_the_same, |
|
41 | 50 |
|
42 | 51 | from .utils import custom_trainer, leave_one_out_mask |
43 | 52 |
|
| 53 | +pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13") |
| 54 | + |
44 | 55 |
|
45 | 56 | class TestBERT4RecModel: |
46 | 57 | def setup_method(self) -> None: |
|
0 commit comments