Skip to content

Commit 7c22ddf

Browse files
CLU Authorscopybara-github
authored andcommitted
Allow running profiler on specific steps.
PiperOrigin-RevId: 662014698
1 parent b64aa29 commit 7c22ddf

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

clu/periodic_actions.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def __init__(
308308
first_profile: int = 10,
309309
every_steps: Optional[int] = None,
310310
every_secs: Optional[float] = 3600.0,
311+
on_steps: Optional[Iterable[int]] = None,
311312
artifact_name: str = "[{step}] Profile",
312313
):
313314
"""Initializes a new periodic profiler action.
@@ -322,12 +323,15 @@ def __init__(
322323
first_profile: First step at which a profile is started.
323324
every_steps: See `PeriodicAction.__init__()`.
324325
every_secs: See `PeriodicAction.__init__()`.
326+
on_steps: See `PeriodicAction.__init__()`.
325327
artifact_name: Name of the artifact to record.
326328
"""
327329
if not num_profile_steps and not profile_duration_ms:
328330
raise ValueError(
329331
"Must specify num_profile_steps and/or profile_duration_ms.")
330-
super().__init__(every_steps=every_steps, every_secs=every_secs)
332+
super().__init__(
333+
every_steps=every_steps, every_secs=every_secs, on_steps=on_steps
334+
)
331335
self._num_profile_steps = num_profile_steps
332336
self._first_profile = first_profile
333337
self._profile_duration_ms = profile_duration_ms
@@ -383,7 +387,8 @@ def __init__(self,
383387
profile_duration_ms: int = 3_000,
384388
first_profile: int = 10,
385389
every_steps: Optional[int] = None,
386-
every_secs: Optional[float] = 3600.0):
390+
every_secs: Optional[float] = 3600.0,
391+
on_steps: Optional[Iterable[int]] = None):
387392
"""Initializes a new periodic profiler action.
388393
389394
Args:
@@ -394,8 +399,11 @@ def __init__(self,
394399
first_profile: First step at which a profile is started.
395400
every_steps: See `PeriodicAction.__init__()`.
396401
every_secs: See `PeriodicAction.__init__()`.
402+
on_steps: See `PeriodicAction.__init__()`.
397403
"""
398-
super().__init__(every_steps=every_steps, every_secs=every_secs)
404+
super().__init__(
405+
every_steps=every_steps, every_secs=every_secs, on_steps=on_steps
406+
)
399407
self._hosts = hosts
400408
self._first_profile = first_profile
401409
self._profile_duration_ms = profile_duration_ms

0 commit comments

Comments
 (0)