61 changes: 61 additions & 0 deletions rtdetr_pose/tests/test_train_minimal_grad_accum_amp.py
@@ -0,0 +1,61 @@
+import importlib.util
+import unittest
+from pathlib import Path
+
+try:
+    import torch
+except ImportError:  # pragma: no cover
+    torch = None
+
+
+def _load_train_minimal_module():
+    repo_root = Path(__file__).resolve().parents[2]
+    script_path = repo_root / "rtdetr_pose" / "tools" / "train_minimal.py"
+    spec = importlib.util.spec_from_file_location("rtdetr_pose_tools_train_minimal", script_path)
+    assert spec is not None and spec.loader is not None
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
+
+
+@unittest.skipIf(torch is None, "torch not installed")
+class TestTrainMinimalGradAccumAMP(unittest.TestCase):
+    def test_gradient_accumulation_argument(self):
+        """Test that the gradient accumulation steps argument is parsed correctly."""
+        mod = _load_train_minimal_module()
+
+        # Test default value
+        args = mod.parse_args([])
+        self.assertEqual(args.gradient_accumulation_steps, 1)
+
+        # Test custom value
+        args = mod.parse_args(["--gradient-accumulation-steps", "4"])
+        self.assertEqual(args.gradient_accumulation_steps, 4)
+
+    def test_amp_argument(self):
+        """Test that the AMP argument is parsed correctly."""
+        mod = _load_train_minimal_module()
+
+        # Test default value (False)
+        args = mod.parse_args([])
+        self.assertFalse(args.use_amp)
+
+        # Test with flag enabled
+        args = mod.parse_args(["--use-amp"])
+        self.assertTrue(args.use_amp)
+
+    def test_clip_grad_norm_exists(self):
+        """Test that the gradient clipping argument exists (already implemented)."""
+        mod = _load_train_minimal_module()
+
+        # Test default value
+        args = mod.parse_args([])
+        self.assertEqual(args.clip_grad_norm, 0.0)
+
+        # Test custom value
+        args = mod.parse_args(["--clip-grad-norm", "1.0"])
+        self.assertEqual(args.clip_grad_norm, 1.0)
+
+
+if __name__ == "__main__":
+    unittest.main()
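
For reference, _load_train_minimal_module uses the standard importlib recipe for importing a standalone script by file path, which is needed here because tools/train_minimal.py is a script rather than an installed package module. A generic sketch of the same pattern (the helper name is illustrative, not part of this diff):

    import importlib.util
    from pathlib import Path


    def load_module_from_path(name: str, path: Path):
        # Build a module spec directly from a file location, bypassing sys.path.
        spec = importlib.util.spec_from_file_location(name, path)
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot load module from {path}")
        module = importlib.util.module_from_spec(spec)
        # Execute the module body; any top-level code in the script runs here.
        spec.loader.exec_module(module)
        return module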
120 changes: 120 additions & 0 deletions rtdetr_pose/tests/test_train_minimal_integration.py
@@ -0,0 +1,120 @@
+import importlib.util
+import json
+import tempfile
+import unittest
+from pathlib import Path
+
+try:
+    import torch
+except ImportError:  # pragma: no cover
+    torch = None
+
+
+def _load_train_minimal_module():
+    repo_root = Path(__file__).resolve().parents[2]
+    script_path = repo_root / "rtdetr_pose" / "tools" / "train_minimal.py"
+    spec = importlib.util.spec_from_file_location("rtdetr_pose_tools_train_minimal", script_path)
+    assert spec is not None and spec.loader is not None
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
+
+
+@unittest.skipIf(torch is None, "torch not installed")
+class TestTrainMinimalIntegration(unittest.TestCase):
+    def setUp(self):
+        self.repo_root = Path(__file__).resolve().parents[2]
+        self.data_dir = self.repo_root / "data" / "coco128"
+        if not self.data_dir.is_dir():
+            self.data_dir = self.repo_root.parent / "data" / "coco128"
+
+    def test_gradient_accumulation_integration(self):
+        """Test that training works with gradient accumulation."""
+        if not self.data_dir.is_dir():
+            self.skipTest("coco128 missing; run: bash tools/fetch_coco128.sh")
+
+        mod = _load_train_minimal_module()
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            args = [
+                "--dataset-root", str(self.data_dir),
+                "--split", "train2017",
+                "--epochs", "1",
+                "--batch-size", "2",
+                "--max-steps", "3",
+                "--image-size", "64",
+                "--device", "cpu",
+                "--gradient-accumulation-steps", "2",
+                "--metrics-jsonl", str(Path(tmpdir) / "metrics.jsonl"),
+                "--no-export-onnx",
+            ]
+
+            result = mod.main(args)
+            self.assertEqual(result, 0, "Training should complete successfully")
+
+            # Check that the metrics file exists
+            metrics_file = Path(tmpdir) / "metrics.jsonl"
+            self.assertTrue(metrics_file.exists(), "Metrics file should be created")
+
+            # Verify metrics were written
+            with open(metrics_file) as f:
+                lines = f.readlines()
+            self.assertGreater(len(lines), 0, "Metrics should be logged")
+
+    def test_amp_on_cpu_warning(self):
+        """Test that AMP on a CPU device warns and training still completes."""
+        if not self.data_dir.is_dir():
+            self.skipTest("coco128 missing; run: bash tools/fetch_coco128.sh")
+
+        mod = _load_train_minimal_module()
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            args = [
+                "--dataset-root", str(self.data_dir),
+                "--split", "train2017",
+                "--epochs", "1",
+                "--batch-size", "2",
+                "--max-steps", "2",
+                "--image-size", "64",
+                "--device", "cpu",
+                "--use-amp",
+                "--metrics-jsonl", str(Path(tmpdir) / "metrics.jsonl"),
+                "--no-export-onnx",
+            ]
+
+            # This should complete but print a warning about AMP requiring CUDA
+            result = mod.main(args)
+            self.assertEqual(result, 0, "Training should complete successfully even with AMP on CPU")
+
+    def test_combined_features(self):
+        """Test that gradient clipping and gradient accumulation work together."""
+        if not self.data_dir.is_dir():
+            self.skipTest("coco128 missing; run: bash tools/fetch_coco128.sh")
+
+        mod = _load_train_minimal_module()
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            args = [
+                "--dataset-root", str(self.data_dir),
+                "--split", "train2017",
+                "--epochs", "1",
+                "--batch-size", "2",
+                "--max-steps", "4",
+                "--image-size", "64",
+                "--device", "cpu",
+                "--clip-grad-norm", "1.0",
+                "--gradient-accumulation-steps", "2",
+                "--metrics-jsonl", str(Path(tmpdir) / "metrics.jsonl"),
+                "--no-export-onnx",
+            ]
+
+            result = mod.main(args)
+            self.assertEqual(result, 0, "Training should complete successfully with combined features")
+
+            # Check that the metrics file exists
+            metrics_file = Path(tmpdir) / "metrics.jsonl"
+            self.assertTrue(metrics_file.exists(), "Metrics file should be created")
+
+
+if __name__ == "__main__":
+    unittest.main()
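
A note on the arithmetic behind the settings in the first test: with --batch-size 2 and --gradient-accumulation-steps 2, the effective batch size per optimizer update is 4, and with --max-steps 3 only one update fires (after the second micro-step), so the gradients from the third micro-step are accumulated but never applied before the run ends. A quick check of that reasoning:

    # Values from test_gradient_accumulation_integration above.
    batch_size = 2
    accum_steps = 2
    max_steps = 3

    effective_batch = batch_size * accum_steps    # 4 samples contribute to each update
    optimizer_updates = max_steps // accum_steps  # 1 update; step 3 leaves pending gradients
    print(effective_batch, optimizer_updates)     # -> 4 1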
72 changes: 64 additions & 8 deletions rtdetr_pose/tools/train_minimal.py
@@ -124,6 +124,17 @@ def build_parser() -> argparse.ArgumentParser:
         default=0.0,
         help="If >0, clip gradients to this max norm before optimizer step.",
     )
+    parser.add_argument(
+        "--gradient-accumulation-steps",
+        type=int,
+        default=1,
+        help="Number of steps to accumulate gradients before each optimizer update (default: 1, no accumulation).",
+    )
+    parser.add_argument(
+        "--use-amp",
+        action="store_true",
+        help="Enable Automatic Mixed Precision (AMP) training with torch.cuda.amp.",
+    )
     parser.add_argument(
         "--lr-warmup-steps",
         type=int,
@@ -1697,6 +1708,15 @@ def main(argv: list[str] | None = None) -> int:
         weight_decay=float(args.weight_decay),
     )

+    # Initialize GradScaler for AMP if enabled
+    scaler = None
+    if args.use_amp:
+        if device.startswith("cuda"):
+            scaler = torch.cuda.amp.GradScaler()
+            print("amp_enabled=True device=cuda")
+        else:
+            print("amp_warning: --use-amp requires CUDA device; AMP disabled")
+
     start_epoch = 0
     global_step = 0
     if args.resume_from:
@@ -1755,7 +1775,14 @@ def main(argv: list[str] | None = None) -> int:
                 mim_ratio = float(targets["mim_mask_ratio"].mean().detach().cpu())
             except Exception:
                 mim_ratio = None
-            out = model(images)
+
+            # Forward pass with optional AMP autocast
+            if scaler is not None:
+                with torch.cuda.amp.autocast():
+                    out = model(images)
+            else:
+                out = model(images)
+
             mim_loss = None
             if args.mim_teacher and float(mim_weight) > 0 and isinstance(targets, dict):
                 image_raw = targets.get("image_raw")
@@ -1766,7 +1793,11 @@ def main(argv: list[str] | None = None) -> int:
                 if was_training:
                     model.eval()
                 with torch.no_grad():
-                    teacher_out = model(image_raw.to(device))
+                    if scaler is not None:
+                        with torch.cuda.amp.autocast():
+                            teacher_out = model(image_raw.to(device))
+                    else:
+                        teacher_out = model(image_raw.to(device))
                 if was_training:
                     model.train()
                 loss_items = []
@@ -1873,11 +1904,36 @@ def main(argv: list[str] | None = None) -> int:
                 }
                 print("loss_breakdown", " ".join(f"{k}={v:.6g}" for k, v in sorted(printable.items())))

-            optim.zero_grad(set_to_none=True)
-            loss.backward()
-            if args.clip_grad_norm and float(args.clip_grad_norm) > 0:
-                torch.nn.utils.clip_grad_norm_(model.parameters(), float(args.clip_grad_norm))
-            optim.step()
+            # Store unscaled loss for logging
+            loss_for_logging = loss.detach().cpu()
+
+            # Gradient accumulation: scale loss by accumulation steps
+            accum_steps = int(args.gradient_accumulation_steps)
+            if accum_steps > 1:
+                loss = loss / accum_steps
+
+            # Backward pass with optional AMP scaling
+            if scaler is not None:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
+
+            # Perform optimizer step only at accumulation boundaries.
+            # steps is 0-indexed within each epoch, so we use (steps + 1) for the check.
+            if (steps + 1) % accum_steps == 0:
+                if scaler is not None:
+                    # Unscale gradients before clipping
+                    if args.clip_grad_norm and float(args.clip_grad_norm) > 0:
+                        scaler.unscale_(optim)
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), float(args.clip_grad_norm))
+                    scaler.step(optim)
+                    scaler.update()
+                    optim.zero_grad(set_to_none=True)
+                else:
+                    if args.clip_grad_norm and float(args.clip_grad_norm) > 0:
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), float(args.clip_grad_norm))
+                    optim.step()
+                    optim.zero_grad(set_to_none=True)

             if args.lr_warmup_steps and int(args.lr_warmup_steps) > 0:
                 lr_now = compute_warmup_lr(
@@ -1897,7 +1953,7 @@ def main(argv: list[str] | None = None) -> int:
                 for group in optim.param_groups:
                     group["lr"] = lr_now

-            running += float(loss.detach().cpu())
+            running += float(loss_for_logging)
             steps += 1
             global_step += 1
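
Taken together, the training-loop changes implement the usual accumulate / scale / unscale / clip / step ordering. The sketch below distills that pattern on a toy model so the control flow is visible in one place; it is illustrative only (the Linear model, SGD optimizer, and loop bounds are stand-ins, not code from this repository) and mirrors the torch.cuda.amp API used in the diff, which newer PyTorch releases also expose as torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda").

    import torch
    import torch.nn.functional as F

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(8, 1).to(device)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    # As in the diff, the scaler exists only on CUDA; on CPU we fall back to fp32.
    scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None
    accum_steps, clip_norm = 2, 1.0

    for step in range(4):
        x = torch.randn(4, 8, device=device)
        y = torch.randn(4, 1, device=device)
        if scaler is not None:
            with torch.cuda.amp.autocast():
                loss = F.mse_loss(model(x), y)
        else:
            loss = F.mse_loss(model(x), y)
        loss_for_logging = float(loss.detach())  # keep the unscaled loss for logging
        loss = loss / accum_steps                # average gradients over the window
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        if (step + 1) % accum_steps == 0:        # update only at accumulation boundaries
            if scaler is not None:
                scaler.unscale_(optim)           # unscale so clipping sees true magnitudes
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                scaler.step(optim)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                optim.step()
            optim.zero_grad(set_to_none=True)

As in the diff, the loss is recorded before division by the accumulation count, and gradients are unscaled before clipping so the max-norm applies to the true gradient magnitudes; note that any micro-steps left over after the last full window end the run with pending, unapplied gradients.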
7 changes: 7 additions & 0 deletions train_setting.yaml
@@ -14,3 +14,10 @@ metrics_csv: reports/train_metrics.csv
 tensorboard_logdir: reports/tb
 export_onnx: true
 onnx_out: reports/rtdetr_pose.onnx
+# Gradient clipping (already implemented)
+# clip_grad_norm: 1.0
+# Gradient accumulation (new feature)
+# gradient_accumulation_steps: 1
+# AMP (Automatic Mixed Precision) - requires CUDA device (new feature)
+# use_amp: false
+
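
For completeness, the commented keys above correspond one-to-one to the new CLI flags. A minimal sketch of enabling them through the script's parser (whether train_minimal.py reads train_setting.yaml directly is not shown in this diff, so the YAML-to-flag mapping here is an assumption):

    # Illustrative only: the CLI flags matching the commented YAML keys above.
    mod = _load_train_minimal_module()  # helper from the tests earlier in this diff
    args = mod.parse_args([
        "--clip-grad-norm", "1.0",
        "--gradient-accumulation-steps", "2",
        "--use-amp",  # on a non-CUDA device this warns and AMP stays disabled
    ])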