examples/cifar10/main.py (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@
 import torch.optim as optim
 import utils
 from torch.amp import autocast
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler

 import ignite
 import ignite.distributed as idist

@@ -1,7 +1,7 @@
 import fire
 import torch
 from torch.amp import autocast
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn import CrossEntropyLoss
 from torch.optim import SGD
 from torchvision.models import wide_resnet50_2

examples/cifar10_qat/main.py (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@
 import torch.optim as optim
 import utils
 from torch.amp import autocast
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler

 import ignite
 import ignite.distributed as idist

examples/transformers/main.py (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@
 import torch.optim as optim
 import utils
 from torch.amp import autocast
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler

 import ignite
 import ignite.distributed as idist
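
Note: all four example scripts now import GradScaler from the device-agnostic torch.amp namespace instead of torch.cuda.amp. A minimal sketch of the AMP training-step pattern these examples rely on (the model, optimizer, and data below are illustrative placeholders, not code from this PR):

import torch
from torch.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

# With enabled=False every scaler method is a pass-through, so the same
# step runs unchanged on CPU-only machines.
scaler = GradScaler(enabled=torch.cuda.is_available())

def train_step(x: torch.Tensor, y: torch.Tensor) -> float:
    optimizer.zero_grad()
    # torch.amp.autocast takes an explicit device_type string.
    with autocast(device_type=device):
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next iteration
    return loss.item()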

ignite/engine/__init__.py (12 changes: 6 additions & 6 deletions)
@@ -2,7 +2,7 @@
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

 import torch
-
+from torch.amp import GradScaler
 import ignite.distributed as idist
 from ignite.engine.deterministic import DeterministicEngine
 from ignite.engine.engine import Engine
@@ -133,7 +133,7 @@ def supervised_training_step_amp(
     prepare_batch: Callable = _prepare_batch,
     model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
-    scaler: Optional["torch.cuda.amp.GradScaler"] = None,
+    scaler: Optional["torch.amp.GradScaler"] = None,
     gradient_accumulation_steps: int = 1,
     model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
 ) -> Callable:
@@ -393,8 +393,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to


 def _check_arg(
-    on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
-) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
+    on_tpu: bool, on_mps: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.amp.GradScaler"]]
+) -> Tuple[Optional[str], Optional["torch.amp.GradScaler"]]:
     """Checking tpu, mps, amp and GradScaler instance combinations."""
     if on_mps and amp_mode:
         raise ValueError("amp_mode cannot be used with mps device. Consider using amp_mode=None or device='cuda'.")
@@ -410,7 +410,7 @@ def _check_arg(
         raise ValueError(f"scaler argument is {scaler}, but amp_mode is {amp_mode}. Consider using amp_mode='amp'.")
     elif amp_mode == "amp" and isinstance(scaler, bool):
         try:
-            from torch.cuda.amp import GradScaler
+            from torch.amp import GradScaler
         except ImportError:
             raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
         scaler = GradScaler(enabled=True)
@@ -434,7 +434,7 @@ def create_supervised_trainer(
     output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
     deterministic: bool = False,
     amp_mode: Optional[str] = None,
-    scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
+    scaler: Union[bool, "torch.amp.GradScaler"] = False,
     gradient_accumulation_steps: int = 1,
     model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
 ) -> Engine:
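
Note: the public API is unchanged; create_supervised_trainer still accepts either a bool or a GradScaler instance for scaler, only the annotated type moves to torch.amp.GradScaler. A hedged sketch of a caller passing an explicit scaler (model, optimizer, and loss_fn are placeholders, not taken from this PR):

import torch
from torch.amp import GradScaler
from ignite.engine import create_supervised_trainer

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

use_amp = torch.cuda.is_available()

trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_fn,
    device="cuda" if use_amp else "cpu",
    # _check_arg raises ValueError if a scaler is given without amp_mode="amp".
    amp_mode="amp" if use_amp else None,
    # scaler=True would let ignite build GradScaler(enabled=True) itself;
    # here a preconfigured instance is handed in instead.
    scaler=GradScaler(enabled=True) if use_amp else False,
)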

tests/ignite/conftest.py (4 changes: 2 additions & 2 deletions)
@@ -48,7 +48,7 @@ def pytest_configure(config):
     config.addinivalue_line("markers", "distributed: run distributed")
     config.addinivalue_line("markers", "multinode_distributed: distributed")
     config.addinivalue_line("markers", "tpu: run on tpu")
-    if config.option.treat_unrun_as_failed:
+    if getattr(config.option, "treat_unrun_as_failed", False):
         unrun_tracker = UnrunTracker()
         config.pluginmanager.register(unrun_tracker, "unrun_tracker_plugin")

@@ -611,6 +611,6 @@ def pytest_sessionfinish(session, exitstatus):
     run finished, right before returning the exit status to the system.
     """
     # If requested by the user, track all unrun tests and add them to the lastfailed cache
-    if session.config.option.treat_unrun_as_failed:
+    if getattr(session.config.option, "treat_unrun_as_failed", False):
         unrun_tracker = session.config.pluginmanager.get_plugin("unrun_tracker_plugin")
         unrun_tracker.record_unrun_as_failed(session, exitstatus)
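
Note: the getattr guard matters when this conftest is loaded in a run where the option was never registered, in which case config.option has no such attribute and the old plain access raised AttributeError. A minimal sketch of the intended interplay (the pytest_addoption hook below is a hypothetical illustration, not part of this diff):

def pytest_addoption(parser):
    # Hypothetical registration; argparse derives the attribute name
    # treat_unrun_as_failed from the flag by replacing dashes with underscores.
    parser.addoption("--treat-unrun-as-failed", action="store_true", default=False)

def pytest_configure(config):
    # getattr degrades to False instead of raising AttributeError when the
    # option was never added (e.g. the hook above did not run).
    if getattr(config.option, "treat_unrun_as_failed", False):
        ...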

tests/ignite/engine/test_create_supervised.py (14 changes: 7 additions & 7 deletions)
@@ -48,7 +48,7 @@ def _default_create_supervised_trainer(
     trainer_device: Optional[str] = None,
     trace: bool = False,
     amp_mode: str = None,
-    scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
+    scaler: Union[bool, "torch.amp.GradScaler"] = False,
     with_model_transform: bool = False,
     with_model_fn: bool = False,
 ):
@@ -104,7 +104,7 @@ def _test_create_supervised_trainer(
     trainer_device: Optional[str] = None,
     trace: bool = False,
     amp_mode: str = None,
-    scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
+    scaler: Union[bool, "torch.amp.GradScaler"] = False,
     with_model_transform: bool = False,
     with_model_fn: bool = False,
 ):
@@ -170,18 +170,18 @@ def _():
 @pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
 def test_create_supervised_training_scalar_assignment():
     with mock.patch("ignite.engine._check_arg") as check_arg_mock:
-        check_arg_mock.return_value = None, torch.cuda.amp.GradScaler(enabled=False)
+        check_arg_mock.return_value = None, torch.amp.GradScaler(enabled=False)
         trainer, _ = _default_create_supervised_trainer(model_device="cpu", trainer_device="cpu", scaler=True)
         assert hasattr(trainer.state, "scaler")
-        assert isinstance(trainer.state.scaler, torch.cuda.amp.GradScaler)
+        assert isinstance(trainer.state.scaler, torch.amp.GradScaler)


 def _test_create_mocked_supervised_trainer(
     model_device: Optional[str] = None,
     trainer_device: Optional[str] = None,
     trace: bool = False,
     amp_mode: str = None,
-    scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
+    scaler: Union[bool, "torch.amp.GradScaler"] = False,
 ):
     with mock.patch("ignite.engine.supervised_training_step_amp") as training_step_amp_mock:
         with mock.patch("ignite.engine.supervised_training_step_apex") as training_step_apex_mock:
@@ -462,7 +462,7 @@ def test_create_supervised_trainer_amp_error(mock_torch_cuda_amp_module):

 @pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
 def test_create_supervised_trainer_scaler_not_amp():
-    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
+    scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())

     with pytest.raises(ValueError, match=f"scaler argument is {scaler}, but amp_mode is None."):
         _test_create_supervised_trainer(amp_mode=None, scaler=scaler)
@@ -540,7 +540,7 @@ def test_create_supervised_trainer_on_cuda_amp_scaler():
     _test_create_mocked_supervised_trainer(
         model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=True
     )
-    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
+    scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
     _test_create_supervised_trainer(
         gradient_accumulation_steps=1,
         model_device=model_device,
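
Note: the updated assertions check isinstance against torch.amp.GradScaler. To the best of my knowledge, recent PyTorch releases (2.4 and later) keep torch.cuda.amp.GradScaler as a deprecated subclass of torch.amp.GradScaler, so scalers built through the old path should still pass the new checks. A quick local sanity check, assuming that subclass relation holds:

import torch
from torch.amp import GradScaler

new_scaler = GradScaler("cuda", enabled=torch.cuda.is_available())
assert isinstance(new_scaler, torch.amp.GradScaler)

# Emits a FutureWarning on PyTorch 2.4+; expected to still satisfy the
# isinstance check if torch.cuda.amp.GradScaler subclasses torch.amp.GradScaler.
old_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
assert isinstance(old_scaler, torch.amp.GradScaler)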