Skip to content

Commit 77fc88d

Browse files
committed
Skip more torch
1 parent 8b97751 commit 77fc88d

File tree

9 files changed

+169
-58
lines changed

9 files changed

+169
-58
lines changed

tests/models/nn/test_item_net.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,42 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617

1718
import numpy as np
1819
import pandas as pd
1920
import pytest
20-
import torch
21-
from pytorch_lightning import seed_everything
21+
22+
try:
23+
import torch
24+
from pytorch_lightning import seed_everything
25+
except ImportError:
26+
pass
2227

2328
from rectools.columns import Columns
2429
from rectools.dataset import Dataset
2530
from rectools.dataset.dataset import DatasetSchema, EntitySchema
26-
from rectools.models.nn.item_net import (
27-
CatFeaturesItemNet,
28-
IdEmbeddingsItemNet,
29-
ItemNetBase,
30-
ItemNetConstructorBase,
31-
SumOfEmbeddingsConstructor,
32-
)
31+
32+
try:
33+
from rectools.models.nn.item_net import (
34+
CatFeaturesItemNet,
35+
IdEmbeddingsItemNet,
36+
ItemNetBase,
37+
ItemNetConstructorBase,
38+
SumOfEmbeddingsConstructor,
39+
)
40+
except ImportError:
41+
CatFeaturesItemNet = object # type: ignore
42+
IdEmbeddingsItemNet = object # type: ignore
43+
ItemNetBase = object # type: ignore
44+
ItemNetConstructorBase = object # type: ignore
45+
SumOfEmbeddingsConstructor = object # type: ignore
3346

3447
from ..data import DATASET, INTERACTIONS
3548

49+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
50+
3651

3752
class TestIdEmbeddingsItemNet:
3853
def setup_method(self) -> None:

tests/models/nn/transformers/test_base.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,40 @@
1313
# limitations under the License.
1414

1515
import os
16+
import sys
1617
import typing as tp
1718
from tempfile import NamedTemporaryFile
1819

1920
import pandas as pd
2021
import pytest
21-
import torch
2222
from pytest import FixtureRequest
23-
from pytorch_lightning import Trainer, seed_everything
24-
from pytorch_lightning.loggers import CSVLogger
23+
24+
try:
25+
import torch
26+
from pytorch_lightning import Trainer, seed_everything
27+
from pytorch_lightning.loggers import CSVLogger
28+
29+
except ImportError:
30+
Trainer: tp.Any = object
2531

2632
from rectools import Columns
2733
from rectools.dataset import Dataset
28-
from rectools.models import BERT4RecModel, SASRecModel, load_model
29-
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet
30-
from rectools.models.nn.transformers.base import TransformerModelBase
34+
35+
try:
36+
from rectools.models import BERT4RecModel, SASRecModel, load_model
37+
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet
38+
from rectools.models.nn.transformers.base import TransformerModelBase
39+
except ImportError:
40+
TransformerModelBase: tp.Any = object
3141
from tests.models.data import INTERACTIONS
3242
from tests.models.utils import assert_save_load_do_not_change_model
3343

34-
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
44+
try:
45+
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
46+
except NameError:
47+
pass
48+
49+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
3550

3651

3752
class TestTransformerModelBase:

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,49 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617
from functools import partial
1718

1819
import numpy as np
1920
import pandas as pd
2021
import pytest
21-
import torch
22-
from pytorch_lightning import Trainer, seed_everything
22+
23+
try:
24+
import torch
25+
from pytorch_lightning import Trainer, seed_everything
26+
except ImportError:
27+
Trainer = object
2328

2429
from rectools import ExternalIds
2530
from rectools.columns import Columns
2631
from rectools.dataset import Dataset
27-
from rectools.models import BERT4RecModel
28-
from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
29-
from rectools.models.nn.transformers.base import (
30-
LearnableInversePositionalEncoding,
31-
PreLNTransformerLayers,
32-
TrainerCallable,
33-
TransformerLightningModule,
34-
)
35-
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
32+
33+
try:
34+
from rectools.models import BERT4RecModel
35+
from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
36+
from rectools.models.nn.transformers.base import (
37+
LearnableInversePositionalEncoding,
38+
PreLNTransformerLayers,
39+
TrainerCallable,
40+
TransformerLightningModule,
41+
)
42+
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
43+
except ImportError:
44+
TrainerCallable: tp.Any = object
3645
from tests.models.data import DATASET
3746
from tests.models.utils import (
3847
assert_default_config_and_default_model_params_are_the_same,
3948
assert_second_fit_refits_model,
4049
)
4150

42-
from .utils import custom_trainer, leave_one_out_mask
51+
try:
52+
from .utils import custom_trainer, leave_one_out_mask
53+
except NameError:
54+
pass
4355

4456

57+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
4558
class TestBERT4RecModel:
4659
def setup_method(self) -> None:
4760
self._seed_everything()
@@ -613,6 +626,7 @@ def _collate_fn_train(
613626
)
614627

615628

629+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
616630
class TestBERT4RecDataPreparator:
617631

618632
def setup_method(self) -> None:
@@ -792,6 +806,7 @@ def test_get_dataloader_val(
792806
assert torch.equal(value, val_batch[key])
793807

794808

809+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
795810
class TestBERT4RecModelConfiguration:
796811
def setup_method(self) -> None:
797812
self._seed_everything()

tests/models/nn/transformers/test_data_preparator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617

1718
import numpy as np
@@ -21,9 +22,15 @@
2122
from rectools.columns import Columns
2223
from rectools.dataset import Dataset, IdMap, Interactions
2324
from rectools.dataset.features import DenseFeatures
24-
from rectools.models.nn.transformers.data_preparator import SequenceDataset, TransformerDataPreparatorBase
25+
26+
try:
27+
from rectools.models.nn.transformers.data_preparator import SequenceDataset, TransformerDataPreparatorBase
28+
except ImportError:
29+
TransformerDataPreparatorBase: tp.Any = object # it's ok in case we're skipping the tests
2530
from tests.testing_utils import assert_feature_set_equal, assert_id_map_equal, assert_interactions_set_equal
2631

32+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
33+
2734

2835
class TestSequenceDataset:
2936

tests/models/nn/transformers/test_sasrec.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,37 +14,49 @@
1414

1515
# pylint: disable=too-many-lines
1616

17+
import sys
1718
import typing as tp
1819
from functools import partial
1920

2021
import numpy as np
2122
import pandas as pd
2223
import pytest
23-
import torch
24-
from pytorch_lightning import Trainer, seed_everything
24+
25+
try:
26+
import torch
27+
from pytorch_lightning import Trainer, seed_everything
28+
except ImportError:
29+
Trainer: tp.Any = object
2530

2631
from rectools import ExternalIds
2732
from rectools.columns import Columns
2833
from rectools.dataset import Dataset, IdMap, Interactions
29-
from rectools.models import SASRecModel
30-
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
31-
from rectools.models.nn.transformers.base import (
32-
LearnableInversePositionalEncoding,
33-
TrainerCallable,
34-
TransformerLightningModule,
35-
TransformerTorchBackbone,
36-
)
37-
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
34+
35+
try:
36+
from rectools.models import SASRecModel
37+
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
38+
from rectools.models.nn.transformers.base import (
39+
LearnableInversePositionalEncoding,
40+
TrainerCallable,
41+
TransformerLightningModule,
42+
TransformerTorchBackbone,
43+
)
44+
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
45+
except ImportError:
46+
TrainerCallable: tp.Any = object
3847
from tests.models.data import DATASET
3948
from tests.models.utils import (
4049
assert_default_config_and_default_model_params_are_the_same,
4150
assert_second_fit_refits_model,
4251
)
4352
from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal
4453

45-
from .utils import custom_trainer, leave_one_out_mask
46-
54+
try:
55+
from .utils import custom_trainer, leave_one_out_mask
56+
except NameError:
57+
pass
4758

59+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
4860
class TestSASRecModel:
4961
def setup_method(self) -> None:
5062
self._seed_everything()
@@ -698,6 +710,7 @@ def test_torch_model(self, dataset: Dataset) -> None:
698710
assert isinstance(model.torch_model, TransformerTorchBackbone)
699711

700712

713+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
701714
class TestSASRecDataPreparator:
702715

703716
def setup_method(self) -> None:
@@ -864,6 +877,7 @@ def test_get_dataloader_recommend(
864877
assert torch.equal(value, recommend_batch[key])
865878

866879

880+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
867881
class TestSASRecModelConfiguration:
868882
def setup_method(self) -> None:
869883
self._seed_everything()

tests/models/nn/transformers/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
16+
1517
import pandas as pd
16-
from pytorch_lightning import Trainer
17-
from pytorch_lightning.callbacks import ModelCheckpoint
18+
import pytest
19+
20+
try:
21+
from pytorch_lightning import Trainer
22+
from pytorch_lightning.callbacks import ModelCheckpoint
23+
except ImportError:
24+
pass
1825

1926
from rectools import Columns
2027

28+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
29+
2130

2231
def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series:
2332
rank = (

tests/models/rank/test_rank.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,43 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617
from itertools import product
1718

1819
import numpy as np
1920
import pytest
20-
import torch
21+
22+
try:
23+
import torch
24+
except ImportError:
25+
pass
2126
from scipy import sparse
2227

23-
from rectools.models.rank import Distance, ImplicitRanker, Ranker, TorchRanker
28+
from rectools.models.rank import Distance, ImplicitRanker, Ranker
29+
30+
try:
31+
from rectools.models.rank import TorchRanker
32+
except ImportError:
33+
TorchRanker = object # it's ok in case we're skipping the tests
34+
2435

2536
T = tp.TypeVar("T")
2637
EPS_DIGITS = 5
2738
pytestmark = pytest.mark.filterwarnings("ignore:invalid value encountered in true_divide")
2839

2940

3041
def gen_rankers() -> tp.List[tp.Tuple[tp.Any, tp.Dict[str, tp.Any]]]:
31-
torch_keys = ["device", "batch_size"]
32-
torch_vals = list(
33-
product(
34-
["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"],
35-
[128, 1],
42+
torch_ranker_args = []
43+
if not sys.version_info >= (3, 13):
44+
torch_keys = ["device", "batch_size"]
45+
torch_vals = list(
46+
product(
47+
["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"],
48+
[128, 1],
49+
)
3650
)
37-
)
38-
torch_ranker_args = [(TorchRanker, dict(zip(torch_keys, v))) for v in torch_vals]
51+
torch_ranker_args = [(TorchRanker, dict(zip(torch_keys, v))) for v in torch_vals]
3952

4053
implicit_keys = ["use_gpu"]
4154
implicit_vals = list(

0 commit comments

Comments
 (0)