
Conversation

@jpdunc23 (Member) commented on Jan 23, 2026

Add TrainStepper, which implements train_on_batch for training Stepper modules. This intermediate PR makes no changes to training configuration. In future PRs, we plan to remove all training-specific stepper config attributes from StepperConfig and use TrainStepperConfig at the top level of TrainConfig.

Changes:

  • Added StepperConfig.get_train_stepper_config, a temporary helper to create a TrainStepperConfig from the training-specific config attributes on StepperConfig (see the sketch after this list).

  • Updated tests.
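
For reference, a minimal sketch of what this helper could look like. The field names (loss, n_ensemble, parameter_init) come from attributes discussed later in this review; the types and layout are assumptions for illustration, not the merged fme implementation.

from dataclasses import dataclass

@dataclass
class TrainStepperConfig:
    # Training-specific options being split out of StepperConfig.
    # Field types are placeholders for the real fme config classes.
    loss: "StepLossConfig"
    n_ensemble: int
    parameter_init: "ParameterInitializationConfig"

@dataclass
class StepperConfig:
    # ... other, non-training stepper attributes elided ...
    loss: "StepLossConfig"
    n_ensemble: int
    parameter_init: "ParameterInitializationConfig"

    def get_train_stepper_config(self) -> TrainStepperConfig:
        # Temporary bridge: forward the training-specific attributes
        # until they move to the top level of TrainConfig in a later PR.
        return TrainStepperConfig(
            loss=self.loss,
            n_ensemble=self.n_ensemble,
            parameter_init=self.parameter_init,
        )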

@mcgibbon (Contributor) left a comment:

In this PR let's put back the existing StepperConfig and have it construct these new smaller configs, so that the coupled code doesn't have to change. Then you can remove it in the next PR updating the coupled code.

n_ensemble: The number of ensemble members evaluated for each training
batch member. Default is 2 if the loss type is EnsembleLoss, otherwise
the default is 1. Must be 2 for EnsembleLoss to be valid.
parameter_init: The parameter initialization configuration.

Contributor:

parameter_init is something we only do at the start of training, so we should be able to move it.

Member Author:

Agreed; I plan to handle this in a future PR. I think it will require some significant refactoring in fme.coupled.

Contributor:

I don't mind this being in another PR, but doesn't it "just work" in fme.coupled? We're keeping the APIs the coupled code is exposed to (the StepperConfig for example) stable in this PR.

If we do this refactor after this PR, let's make sure we do it before we change the config yaml API.
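
Aside: the n_ensemble defaulting rule quoted at the top of this thread boils down to the following check. This is an illustrative sketch; the function and the string-based loss-type test are assumptions, not the fme implementation.

from typing import Optional

def resolve_n_ensemble(loss_type: str, n_ensemble: Optional[int]) -> int:
    # Default is 2 if the loss type is EnsembleLoss, otherwise 1.
    if n_ensemble is None:
        n_ensemble = 2 if loss_type == "EnsembleLoss" else 1
    # EnsembleLoss is only valid with exactly 2 ensemble members.
    if loss_type == "EnsembleLoss" and n_ensemble != 2:
        raise ValueError("n_ensemble must be 2 for EnsembleLoss to be valid.")
    return n_ensemble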

)


class TrainStepper(

Member Author:

I originally had TrainStepperConfig and TrainStepper in a separate file, but decided to put them here to reduce the diff somewhat.

Contributor:

Could move them in a follow-on PR if you like.

@jpdunc23 marked this pull request as ready for review on January 26, 2026, 16:47

@mcgibbon (Contributor) left a comment:

Could you please do a once-over of the refactored objects, and make sure any methods or attributes that aren't used are deleted or made private?


Comment on lines +1406 to +1412
def set_eval(self) -> None:
    for module in self.modules:
        module.eval()

def set_train(self) -> None:
    for module in self.modules:
        module.train()

Member Author:

These were previously implemented on TrainStepperABC. They are required on Stepper because CoupledStepper uses them.

Contributor:

nit: if you removed these from the Stepper, CoupledStepper and TrainStepper would still work, because their existing implementations of set_train and set_eval (from the generic class) access the .modules attribute on Stepper (which is what's really needed) through their own implementations of .modules (which is what the TrainStepperABC code was accessing).

However, that version of the code is also a little confusing (which, for example, led to the impression this wouldn't work, or is giving me the false impression that it would work, if I'm just wrong). I'm not sure which is better, since this one does expose more API, which might be confusing for different reasons.

@jpdunc23 (Member Author) commented on Feb 2, 2026:
I agree this should work, but I personally prefer this more verbose version. Let me know if you'd prefer I bring back the TrainStepperABC method, in which case I will make it @final (previously it was overridden by CoupledStepper).
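
For context, the generic-base-class version described in the nit above would look roughly like the sketch below. The class and method names follow the thread; the body is an assumption about, not a copy of, the removed TrainStepperABC code.

from abc import ABC, abstractmethod
from torch import nn

class TrainStepperABC(ABC):
    @property
    @abstractmethod
    def modules(self) -> list[nn.Module]:
        # Subclasses (e.g. a stepper that wraps another stepper) expose
        # their torch modules here; set_eval/set_train need nothing else.
        ...

    def set_eval(self) -> None:
        for module in self.modules:
            module.eval()

    def set_train(self) -> None:
        for module in self.modules:
            module.train()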

@jpdunc23 (Member Author):

> Could you please do a once-over of the refactored objects, and make sure any methods or attributes that aren't used are deleted or made private?

I've done so. As mentioned in #767, Stepper.predict could be made private, but we should discuss it at the next ACE technical sync.


return self._loss_normalizer

@property
def loss_obj(self) -> StepLoss:

Contributor:

Issue: loss_obj is training-specific and should be on the training stepper; we don't need it for inference. The coupled code uses TrainStepper, so it should be agnostic to this choice.

Same with effective_loss_scaling.

If for some reason we really did want to keep this information with the inference stepper, it would go on the training history, since multiple losses can get used.

@jpdunc23 (Member Author) commented on Jan 30, 2026:

This is planned for a future PR. The coupled code uses Stepper, not TrainStepper. This is intentional, since CoupledStepper doesn't use TrainStepper.train_on_batch. For the time being I'm leaving it on Stepper to avoid breaking the current CoupledStepper implementation.

Tasks I have in mind for future PR(s) are:

  1. Add CoupledTrainStepper, analogous to TrainStepper (started by you in #754, "Decouple coupled stepper training and inference")
  2. Update the coupled TrainConfig to support direct configuration of each component's loss and parameter_init, rather than relying on these attributes from StepperConfig, as CoupledStepper currently does.
  3. With this done, I plan to simultaneously remove loss_obj from Stepper and loss from StepperConfig.

There is a way to do 3 without first doing 1 and 2 by building each component's StepLoss directly from the loss: StepLossConfig attribute on StepperConfig, but I'd prefer to avoid this temporary workaround since it mostly won't survive in later stages of refactoring.

That said, task 2 will also involve building the component loss objects in the coupled stepper code, so if you feel strongly then I'm willing to give it a shot and hopefully some of the effort will be worthwhile later.

Contributor:

> This is planned for a future PR. The coupled code uses Stepper, not TrainStepper.

I think this is the correct final state, but it’s not necessary for this PR and is getting in the way of making the changes needed in ace. The coupled code currently depends on the stepper in ace that supports training. I don’t think that should be changed in this PR. We should make that change after more of 1-3 is complete.

Member Author:

I'm trying to scope this out but meeting a lot of resistance. I don't see a simple way to have CoupledStepper use TrainStepper without significant refactors in fme.coupled that I think are better left to the next PR.

  • While it's true that CoupledStepper was using a "training" Stepper before, TrainStepper is a very different object from what Stepper was. On the other hand, aside from no longer having a train_on_batch method, Stepper is basically unchanged so I think CoupledStepper should still use it for the time being.
  • To do what you want, CoupledStepperConfig would already have to use TrainStepperConfig in this PR. This is planned but involves significant refactoring in fme.coupled. I think we should wait until a later PR to do this refactoring.

Contributor:

After looking at the code a bit I'm realizing this boils down to the choice to update the load_stepper function to return the more minimal inference-stepping class within this PR. I don't think at this point we should revert that change, so we can move forward with the way you had planned to do it.

You would have had more freedom to define the separation between the two classes in this PR if the load_stepper update weren't combined with it in the same PR (i.e. if load_stepper returned an object with the same API as it does in main), but we can move forward with the path you're currently on in this PR.
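
For readers following along: the load_stepper change being referenced is roughly the shape sketched below. The signature, checkpoint key, and from_state constructor are all assumptions for illustration, not the actual fme API.

import torch

def load_stepper(checkpoint_path: str) -> "Stepper":
    # After this PR, loading a checkpoint yields the minimal
    # inference-stepping Stepper; training concerns (optimizer, loss
    # configuration) live on TrainStepper instead.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Hypothetical constructor; the real checkpoint layout may differ.
    return Stepper.from_state(checkpoint["stepper"])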


@mcgibbon (Contributor) left a comment:

LGTM!

@jpdunc23 enabled auto-merge (squash) on February 2, 2026, 23:47
@jpdunc23 merged commit faabca9 into main on Feb 3, 2026
7 checks passed
@jpdunc23 deleted the refactor-stepper branch on February 3, 2026, 00:02