@@ -308,6 +308,7 @@ def __init__(
308308 first_profile : int = 10 ,
309309 every_steps : Optional [int ] = None ,
310310 every_secs : Optional [float ] = 3600.0 ,
311+ on_steps : Optional [Iterable [int ]] = None ,
311312 artifact_name : str = "[{step}] Profile" ,
312313 ):
313314 """Initializes a new periodic profiler action.
@@ -322,12 +323,15 @@ def __init__(
322323 first_profile: First step at which a profile is started.
323324 every_steps: See `PeriodicAction.__init__()`.
324325 every_secs: See `PeriodicAction.__init__()`.
326+ on_steps: See `PeriodicAction.__init__()`.
325327 artifact_name: Name of the artifact to record.
326328 """
327329 if not num_profile_steps and not profile_duration_ms :
328330 raise ValueError (
329331 "Must specify num_profile_steps and/or profile_duration_ms." )
330- super ().__init__ (every_steps = every_steps , every_secs = every_secs )
332+ super ().__init__ (
333+ every_steps = every_steps , every_secs = every_secs , on_steps = on_steps
334+ )
331335 self ._num_profile_steps = num_profile_steps
332336 self ._first_profile = first_profile
333337 self ._profile_duration_ms = profile_duration_ms
@@ -383,7 +387,8 @@ def __init__(self,
383387 profile_duration_ms : int = 3_000 ,
384388 first_profile : int = 10 ,
385389 every_steps : Optional [int ] = None ,
386- every_secs : Optional [float ] = 3600.0 ):
390+ every_secs : Optional [float ] = 3600.0 ,
391+ on_steps : Optional [Iterable [int ]] = None ):
387392 """Initializes a new periodic profiler action.
388393
389394 Args:
@@ -394,8 +399,11 @@ def __init__(self,
394399 first_profile: First step at which a profile is started.
395400 every_steps: See `PeriodicAction.__init__()`.
396401 every_secs: See `PeriodicAction.__init__()`.
402+ on_steps: See `PeriodicAction.__init__()`.
397403 """
398- super ().__init__ (every_steps = every_steps , every_secs = every_secs )
404+ super ().__init__ (
405+ every_steps = every_steps , every_secs = every_secs , on_steps = on_steps
406+ )
399407 self ._hosts = hosts
400408 self ._first_profile = first_profile
401409 self ._profile_duration_ms = profile_duration_ms
0 commit comments