Skip to content

Commit d596436

Browse files
committed
Support PyTorch 2.6
1 parent 0381da0 commit d596436

File tree

13 files changed

+136
-137
lines changed

13 files changed

+136
-137
lines changed

poetry.lock

Lines changed: 90 additions & 33 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ nmslib-metabrainz = {version = "^2.1.3", python = ">=3.11, <3.13", optional = tr
8686
torch = [
8787
{version = ">=1.6.0, <2.3.0", python = "<3.13", markers = "sys_platform == 'darwin' and platform_machine == 'x86_64'", optional = true},
8888
{version = ">=1.6.0, <3.0.0", python = "<3.13", optional = true},
89+
{version = ">=2.6.0, <3.0.0", python = ">=3.13", optional = true},
8990
]
9091
pytorch-lightning = {version = ">=1.6.0, <3.0.0", python = "<3.13", optional = true}
9192

tests/dataset/test_torch_dataset.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,15 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=attribute-defined-outside-init,consider-using-enumerate
16-
import sys
17-
1816
import numpy as np
1917
import pandas as pd
2018
import pytest
21-
22-
try:
23-
import torch
24-
except ImportError:
25-
pass
26-
19+
import torch
2720
from scipy import sparse
2821

2922
from rectools.columns import Columns
3023
from rectools.dataset import Dataset
31-
32-
try:
33-
from rectools.dataset.torch_datasets import DSSMItemDataset, DSSMTrainDataset, DSSMUserDataset
34-
except ModuleNotFoundError:
35-
pass
36-
37-
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
24+
from rectools.dataset.torch_datasets import DSSMItemDataset, DSSMTrainDataset, DSSMUserDataset
3825

3926

4027
class WithFixtures:

tests/models/nn/test_dssm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def filter_warnings_decorator(func): # type: ignore
5050

5151
from ..data import INTERACTIONS
5252

53-
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
53+
pytestmark = pytest.mark.skipif(
54+
sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13"
55+
)
5456

5557

5658
@filter_warnings_decorator

tests/models/nn/test_item_net.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,35 +18,29 @@
1818
import numpy as np
1919
import pandas as pd
2020
import pytest
21+
import torch
2122

2223
try:
23-
import torch
2424
from pytorch_lightning import seed_everything
2525
except ImportError:
2626
pass
2727

2828
from rectools.columns import Columns
2929
from rectools.dataset import Dataset
3030
from rectools.dataset.dataset import DatasetSchema, EntitySchema
31-
32-
try:
33-
from rectools.models.nn.item_net import (
34-
CatFeaturesItemNet,
35-
IdEmbeddingsItemNet,
36-
ItemNetBase,
37-
ItemNetConstructorBase,
38-
SumOfEmbeddingsConstructor,
39-
)
40-
except ImportError:
41-
CatFeaturesItemNet = object # type: ignore
42-
IdEmbeddingsItemNet = object # type: ignore
43-
ItemNetBase = object # type: ignore
44-
ItemNetConstructorBase = object # type: ignore
45-
SumOfEmbeddingsConstructor = object # type: ignore
31+
from rectools.models.nn.item_net import (
32+
CatFeaturesItemNet,
33+
IdEmbeddingsItemNet,
34+
ItemNetBase,
35+
ItemNetConstructorBase,
36+
SumOfEmbeddingsConstructor,
37+
)
4638

4739
from ..data import DATASET, INTERACTIONS
4840

49-
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
41+
pytestmark = pytest.mark.skipif(
42+
sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13"
43+
)
5044

5145

5246
class TestIdEmbeddingsItemNet:

tests/models/nn/transformers/test_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919

2020
import pandas as pd
2121
import pytest
22+
import torch
2223
from pytest import FixtureRequest
2324

2425
try:
25-
import torch
2626
from pytorch_lightning import Trainer, seed_everything
2727
from pytorch_lightning.loggers import CSVLogger
2828

@@ -46,7 +46,9 @@
4646
except NameError:
4747
pass
4848

49-
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
49+
pytestmark = pytest.mark.skipif(
50+
sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13"
51+
)
5052

5153

5254
class TestTransformerModelBase:

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,19 @@
1313
# limitations under the License.
1414

1515
import sys
16-
import types
1716
import typing as tp
1817
from functools import partial
1918

2019
import numpy as np
2120
import pandas as pd
2221
import pytest
22+
import torch
2323

2424
try:
25-
import torch
2625
from pytorch_lightning import Trainer, seed_everything
2726
except ImportError:
28-
torch = types.ModuleType("torch")
29-
torch.Tensor = object # type: ignore
30-
torch.float = object # type: ignore
3127
Trainer = object # type: ignore
3228

33-
def tensor(*args: tp.Any, **kwargs: tp.Any) -> tp.Any:
34-
return object()
35-
36-
torch.tensor = tensor
37-
3829
from rectools import ExternalIds
3930
from rectools.columns import Columns
4031
from rectools.dataset import Dataset
@@ -64,7 +55,7 @@ def tensor(*args: tp.Any, **kwargs: tp.Any) -> tp.Any:
6455
pass
6556

6657

67-
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
58+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13")
6859
class TestBERT4RecModel:
6960
def setup_method(self) -> None:
7061
self._seed_everything()
@@ -642,7 +633,7 @@ def _collate_fn_train(
642633
)
643634

644635

645-
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
636+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13")
646637
class TestBERT4RecDataPreparator:
647638

648639
def setup_method(self) -> None:
@@ -822,7 +813,7 @@ def test_get_dataloader_val(
822813
assert torch.equal(value, val_batch[key])
823814

824815

825-
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
816+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13")
826817
class TestBERT4RecModelConfiguration:
827818
def setup_method(self) -> None:
828819
self._seed_everything()

tests/models/nn/transformers/test_data_preparator.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import sys
1615
import typing as tp
1716

1817
import numpy as np
@@ -22,15 +21,9 @@
2221
from rectools.columns import Columns
2322
from rectools.dataset import Dataset, IdMap, Interactions
2423
from rectools.dataset.features import DenseFeatures
25-
26-
try:
27-
from rectools.models.nn.transformers.data_preparator import SequenceDataset, TransformerDataPreparatorBase
28-
except ImportError:
29-
TransformerDataPreparatorBase = object # type: ignore
24+
from rectools.models.nn.transformers.data_preparator import SequenceDataset, TransformerDataPreparatorBase
3025
from tests.testing_utils import assert_feature_set_equal, assert_id_map_equal, assert_interactions_set_equal
3126

32-
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
33-
3427

3528
class TestSequenceDataset:
3629

tests/models/nn/transformers/test_sasrec.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,17 @@
1515
# pylint: disable=too-many-lines
1616

1717
import sys
18-
import types
1918
import typing as tp
2019
from functools import partial
2120

2221
import numpy as np
2322
import pandas as pd
2423
import pytest
24+
import torch
2525

2626
try:
27-
import torch
2827
from pytorch_lightning import Trainer, seed_everything
2928
except ImportError:
30-
torch = types.ModuleType("torch")
31-
torch.tensor = lambda x: None # type: ignore
32-
torch.Tensor = object # type: ignore
3329
Trainer = object # type: ignore
3430

3531
from rectools import ExternalIds
@@ -61,7 +57,9 @@
6157
except NameError:
6258
pass
6359

64-
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
60+
pytestmark = pytest.mark.skipif(
61+
sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13"
62+
)
6563

6664

6765
class TestSASRecModel:

0 commit comments

Comments
 (0)