Skip to content

Commit b3030c7

Browse files
committed
Skip more torch
1 parent 8b97751 commit b3030c7

File tree

9 files changed

+139
-47
lines changed

9 files changed

+139
-47
lines changed

tests/models/nn/test_item_net.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,38 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617

1718
import numpy as np
1819
import pandas as pd
1920
import pytest
20-
import torch
21-
from pytorch_lightning import seed_everything
21+
22+
try:
23+
import torch
24+
from pytorch_lightning import seed_everything
25+
except ImportError:
26+
pass
2227

2328
from rectools.columns import Columns
2429
from rectools.dataset import Dataset
2530
from rectools.dataset.dataset import DatasetSchema, EntitySchema
26-
from rectools.models.nn.item_net import (
27-
CatFeaturesItemNet,
28-
IdEmbeddingsItemNet,
29-
ItemNetBase,
30-
ItemNetConstructorBase,
31-
SumOfEmbeddingsConstructor,
32-
)
31+
32+
try:
33+
from rectools.models.nn.item_net import (
34+
CatFeaturesItemNet,
35+
IdEmbeddingsItemNet,
36+
ItemNetBase,
37+
ItemNetConstructorBase,
38+
SumOfEmbeddingsConstructor,
39+
)
40+
except ImportError:
41+
pass
3342

3443
from ..data import DATASET, INTERACTIONS
3544

45+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
46+
3647

3748
class TestIdEmbeddingsItemNet:
3849
def setup_method(self) -> None:

tests/models/nn/transformers/test_base.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,38 @@
1313
# limitations under the License.
1414

1515
import os
16+
import sys
1617
import typing as tp
1718
from tempfile import NamedTemporaryFile
1819

1920
import pandas as pd
2021
import pytest
21-
import torch
2222
from pytest import FixtureRequest
23-
from pytorch_lightning import Trainer, seed_everything
24-
from pytorch_lightning.loggers import CSVLogger
23+
24+
try:
25+
import torch
26+
from pytorch_lightning import Trainer, seed_everything
27+
from pytorch_lightning.loggers import CSVLogger
28+
29+
except ImportError:
30+
pass
2531

2632
from rectools import Columns
2733
from rectools.dataset import Dataset
28-
from rectools.models import BERT4RecModel, SASRecModel, load_model
29-
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet
30-
from rectools.models.nn.transformers.base import TransformerModelBase
34+
35+
try:
36+
from rectools.models import BERT4RecModel, SASRecModel, load_model
37+
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet
38+
from rectools.models.nn.transformers.base import TransformerModelBase
39+
except ImportError:
40+
pass
3141
from tests.models.data import INTERACTIONS
3242
from tests.models.utils import assert_save_load_do_not_change_model
3343

3444
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
3545

46+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
47+
3648

3749
class TestTransformerModelBase:
3850
def setup_method(self) -> None:

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,36 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617
from functools import partial
1718

1819
import numpy as np
1920
import pandas as pd
2021
import pytest
21-
import torch
22-
from pytorch_lightning import Trainer, seed_everything
22+
23+
try:
24+
import torch
25+
from pytorch_lightning import Trainer, seed_everything
26+
except ImportError:
27+
pass
2328

2429
from rectools import ExternalIds
2530
from rectools.columns import Columns
2631
from rectools.dataset import Dataset
27-
from rectools.models import BERT4RecModel
28-
from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
29-
from rectools.models.nn.transformers.base import (
30-
LearnableInversePositionalEncoding,
31-
PreLNTransformerLayers,
32-
TrainerCallable,
33-
TransformerLightningModule,
34-
)
35-
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
32+
33+
try:
34+
from rectools.models import BERT4RecModel
35+
from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
36+
from rectools.models.nn.transformers.base import (
37+
LearnableInversePositionalEncoding,
38+
PreLNTransformerLayers,
39+
TrainerCallable,
40+
TransformerLightningModule,
41+
)
42+
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
43+
except ImportError:
44+
pass
3645
from tests.models.data import DATASET
3746
from tests.models.utils import (
3847
assert_default_config_and_default_model_params_are_the_same,
@@ -41,6 +50,8 @@
4150

4251
from .utils import custom_trainer, leave_one_out_mask
4352

53+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
54+
4455

4556
class TestBERT4RecModel:
4657
def setup_method(self) -> None:

tests/models/nn/transformers/test_data_preparator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617

1718
import numpy as np
@@ -21,9 +22,15 @@
2122
from rectools.columns import Columns
2223
from rectools.dataset import Dataset, IdMap, Interactions
2324
from rectools.dataset.features import DenseFeatures
24-
from rectools.models.nn.transformers.data_preparator import SequenceDataset, TransformerDataPreparatorBase
25+
26+
try:
27+
from rectools.models.nn.transformers.data_preparator import SequenceDataset, TransformerDataPreparatorBase
28+
except ImportError:
29+
pass
2530
from tests.testing_utils import assert_feature_set_equal, assert_id_map_equal, assert_interactions_set_equal
2631

32+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
33+
2734

2835
class TestSequenceDataset:
2936

tests/models/nn/transformers/test_sasrec.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,36 @@
1414

1515
# pylint: disable=too-many-lines
1616

17+
import sys
1718
import typing as tp
1819
from functools import partial
1920

2021
import numpy as np
2122
import pandas as pd
2223
import pytest
23-
import torch
24-
from pytorch_lightning import Trainer, seed_everything
24+
25+
try:
26+
import torch
27+
from pytorch_lightning import Trainer, seed_everything
28+
except ImportError:
29+
pass
2530

2631
from rectools import ExternalIds
2732
from rectools.columns import Columns
2833
from rectools.dataset import Dataset, IdMap, Interactions
29-
from rectools.models import SASRecModel
30-
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
31-
from rectools.models.nn.transformers.base import (
32-
LearnableInversePositionalEncoding,
33-
TrainerCallable,
34-
TransformerLightningModule,
35-
TransformerTorchBackbone,
36-
)
37-
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
34+
35+
try:
36+
from rectools.models import SASRecModel
37+
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
38+
from rectools.models.nn.transformers.base import (
39+
LearnableInversePositionalEncoding,
40+
TrainerCallable,
41+
TransformerLightningModule,
42+
TransformerTorchBackbone,
43+
)
44+
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
45+
except ImportError:
46+
pass
3847
from tests.models.data import DATASET
3948
from tests.models.utils import (
4049
assert_default_config_and_default_model_params_are_the_same,
@@ -44,6 +53,8 @@
4453

4554
from .utils import custom_trainer, leave_one_out_mask
4655

56+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
57+
4758

4859
class TestSASRecModel:
4960
def setup_method(self) -> None:

tests/models/nn/transformers/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
16+
1517
import pandas as pd
16-
from pytorch_lightning import Trainer
17-
from pytorch_lightning.callbacks import ModelCheckpoint
18+
import pytest
19+
20+
try:
21+
from pytorch_lightning import Trainer
22+
from pytorch_lightning.callbacks import ModelCheckpoint
23+
except ImportError:
24+
pass
1825

1926
from rectools import Columns
2027

28+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
29+
2130

2231
def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series:
2332
rank = (

tests/models/rank/test_rank.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617
from itertools import product
1718

1819
import numpy as np
1920
import pytest
20-
import torch
21+
22+
try:
23+
import torch
24+
except ImportError:
25+
pass
2126
from scipy import sparse
2227

23-
from rectools.models.rank import Distance, ImplicitRanker, Ranker, TorchRanker
28+
from rectools.models.rank import Distance, ImplicitRanker, Ranker
29+
30+
try:
31+
from rectools.models.rank import TorchRanker
32+
except ImportError:
33+
TorchRanker = object # it's ok in case we're skipping the tests
34+
2435

2536
T = tp.TypeVar("T")
2637
EPS_DIGITS = 5
@@ -45,7 +56,10 @@ def gen_rankers() -> tp.List[tp.Tuple[tp.Any, tp.Dict[str, tp.Any]]]:
4556
)
4657
implicit_ranker_args = [(ImplicitRanker, dict(zip(implicit_keys, v))) for v in implicit_vals]
4758

48-
return [*torch_ranker_args, *implicit_ranker_args]
59+
if sys.version_info >= (3, 13):
60+
return implicit_ranker_args
61+
else:
62+
return [*torch_ranker_args, *implicit_ranker_args]
4963

5064

5165
class TestRanker: # pylint: disable=protected-access

tests/models/rank/test_rank_torch.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,28 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617
from itertools import product
1718

1819
import numpy as np
1920
import pytest
20-
import torch
21+
22+
try:
23+
import torch
24+
except ImportError:
25+
pass
2126
from scipy import sparse
2227

23-
from rectools.models.rank import Distance, Ranker, TorchRanker
28+
try:
29+
from rectools.models.rank import Distance, Ranker, TorchRanker
30+
except ImportError:
31+
pass
2432

2533
T = tp.TypeVar("T")
2634
EPS_DIGITS = 5
2735
pytestmark = pytest.mark.filterwarnings("ignore:invalid value encountered in true_divide")
36+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
2837

2938

3039
def gen_rankers() -> tp.List[tp.Tuple[tp.Any, tp.Dict[str, tp.Any]]]:

tests/models/test_serialization.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@
2929
LightFM = object # it's ok in case we're skipping the tests
3030

3131
from rectools.metrics import NDCG
32+
33+
try:
34+
from rectools.models import DSSMModel
35+
except ImportError:
36+
DMMSModel = object # it's ok in case we're skipping the tests
3237
from rectools.models import (
33-
DSSMModel,
3438
EASEModel,
3539
ImplicitALSWrapperModel,
3640
ImplicitBPRWrapperModel,
@@ -44,7 +48,11 @@
4448
serialization,
4549
)
4650
from rectools.models.base import ModelBase, ModelConfig
47-
from rectools.models.nn.transformers.base import TransformerModelBase
51+
52+
try:
53+
from rectools.models.nn.transformers.base import TransformerModelBase
54+
except ImportError:
55+
TransformerModelBase = object
4856
from rectools.models.vector import VectorModel
4957
from rectools.utils.config import BaseConfig
5058

0 commit comments

Comments
 (0)