From 3c856eff940cc5bc5b336409b9d383742f3d920d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 25 Aug 2025 08:57:23 +0000 Subject: [PATCH 1/2] fix perf issue Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_perf.py | 11 +++++++---- transformer_engine/pytorch/module/base.py | 8 +++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py index 2d4b62b23f..ad40c31c02 100644 --- a/tests/pytorch/debug/test_perf.py +++ b/tests/pytorch/debug/test_perf.py @@ -28,13 +28,15 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs) model = torch.nn.Sequential( te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2") ).cuda() - NUM_ITERS = 18000 + NUM_ITERS = 1800 elif layer == "transformer": model = torch.nn.Sequential( te.TransformerLayer(1, 1, 1, name="transformer1"), te.TransformerLayer(1, 1, 1, name="transformer2"), ).cuda() - NUM_ITERS = 2000 + NUM_ITERS = 200 + + NUM_INVOCATIONS_PER_ITER = 10 x = torch.randn(1, 1, 1).cuda() @@ -45,8 +47,9 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs) time_start = time.time() for i in range(NUM_ITERS): - y = model(x) - y.sum().backward() + for _ in range(NUM_INVOCATIONS_PER_ITER): + y = model(x) + y.sum().backward() if debug_tools_initialized: debug_api.step() torch.cuda.synchronize() diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 5d04b29f71..75afd65459 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1460,7 +1460,13 @@ def is_debug_iter(self) -> bool: debug = False else: debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run - self.debug_last_iteration = TEDebugState.get_iteration() + self.debug_last_iteration = TEDebugState.get_iteration() + self.debug_enabled_in_last_iteration = debug + else: + # If this is the same iteration 
as the previous invocation of the module, + # we use the debug value from the first invocation in the iteration. + debug = self.debug_enabled_in_last_iteration + return debug def no_debug_features_active(self, quantizers): From 930ce73d62a8a7f3e7641318503280289490d8c9 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 2 Sep 2025 14:05:52 +0000 Subject: [PATCH 2/2] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 75afd65459..1128a1a29e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1461,11 +1461,11 @@ def is_debug_iter(self) -> bool: else: debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run self.debug_last_iteration = TEDebugState.get_iteration() - self.debug_enabled_in_last_iteration = debug + self.debug_enabled_in_this_iteration = debug else: # If this is the same iteration as the previous invocation of the module, # we use the debug value from the first invocation in the iteration. - debug = self.debug_enabled_in_last_iteration + debug = self.debug_enabled_in_this_iteration return debug