Add TrainStepper
#755
Conversation
mcgibbon left a comment:
In this PR let's put back the existing StepperConfig and have it construct these new smaller configs, so that the coupled code doesn't have to change. Then you can remove it in the next PR updating the coupled code.
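Concretely, this could look something like the following minimal sketch. The helper name `get_train_stepper_config` and the `loss`, `n_ensemble`, and `parameter_init` attributes are taken from this PR's description and docstrings; the placeholder types and everything else are assumptions, not the actual fme code:

```python
from dataclasses import dataclass, field


# Placeholder stand-ins for the real fme config classes.
@dataclass
class StepLossConfig:
    pass


@dataclass
class ParameterInitializationConfig:
    pass


@dataclass
class TrainStepperConfig:
    loss: StepLossConfig
    n_ensemble: int
    parameter_init: ParameterInitializationConfig


@dataclass
class StepperConfig:
    # Training-specific attributes stay here for now so the API the
    # coupled code sees remains stable in this PR.
    loss: StepLossConfig = field(default_factory=StepLossConfig)
    n_ensemble: int = 1
    parameter_init: ParameterInitializationConfig = field(
        default_factory=ParameterInitializationConfig
    )

    def get_train_stepper_config(self) -> TrainStepperConfig:
        # Temporary helper: build the smaller training-specific config
        # from the attributes still living on StepperConfig.
        return TrainStepperConfig(
            loss=self.loss,
            n_ensemble=self.n_ensemble,
            parameter_init=self.parameter_init,
        )
```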
```
n_ensemble: The number of ensemble members evaluated for each training
    batch member. Default is 2 if the loss type is EnsembleLoss, otherwise
    the default is 1. Must be 2 for EnsembleLoss to be valid.
parameter_init: The parameter initialization configuration.
```
parameter_init is something we only do at the start of training; we should be able to move it.
Agreed, and I plan to handle that in a future PR. I think it will require some significant refactors in fme.coupled.
I don't mind this being in another PR, but doesn't it "just work" in fme.coupled? We're keeping the APIs the coupled code is exposed to (the StepperConfig for example) stable in this PR.
If we do this refactor after this PR, let's make sure we do it before we change the config yaml API.
```python
)


class TrainStepper(
```
I originally had TrainStepperConfig and TrainStepper in a separate file, but decided to put them here to reduce the diff somewhat.
Could move them in a follow-on PR if you like.
mcgibbon left a comment:
Could you please do a once-over of the refactored objects, and make sure any methods or attributes that aren't used are deleted or made private?
```python
    def set_eval(self) -> None:
        for module in self.modules:
            module.eval()

    def set_train(self) -> None:
        for module in self.modules:
            module.train()
```
These were previously implemented on TrainStepperABC. They are required by Stepper because CoupledStepper uses them.
nit: if you removed these from the Stepper, CoupledStepper and TrainStepper would still work, because their existing implementations of set_train and set_eval (from the generic class) access the .modules attribute on Stepper (which is what's really needed) through their own implementations of .modules (which is what the TrainStepperABC code was accessing).
However that version of the code is also a little confusing (which for example led to the impression this wouldn't work, or is giving me the false impression it would work, if I'm just wrong). I'm not sure which is better (since this one does expose more API, which might be confusing for different reasons).
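For reference, a minimal sketch of the generic-base version being described, assuming TrainStepperABC exposes an abstract `.modules` property; the exact signatures are assumptions, not the actual fme code:

```python
from abc import ABC, abstractmethod

from torch import nn


class TrainStepperABC(ABC):
    @property
    @abstractmethod
    def modules(self) -> list[nn.Module]:
        """Subclasses expose their trainable modules here."""

    # Generic implementations: these only need the .modules property,
    # so any subclass that implements .modules gets them for free.
    def set_train(self) -> None:
        for module in self.modules:
            module.train()

    def set_eval(self) -> None:
        for module in self.modules:
            module.eval()


class Stepper(TrainStepperABC):
    def __init__(self, module: nn.Module):
        self._module = module

    @property
    def modules(self) -> list[nn.Module]:
        return [self._module]


# set_eval works without Stepper implementing it directly.
Stepper(nn.Linear(2, 2)).set_eval()
```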
I agree this should work, but I personally prefer this more verbose version. Lmk if you prefer I bring back the TrainStepperABC method, in which case I will make it @final (previously it was overridden by CoupledStepper).
I've done so. As mentioned in #767,
```python
        return self._loss_normalizer

    @property
    def loss_obj(self) -> StepLoss:
```
Issue: loss_obj is training-specific and should be on the training stepper - we don't need it for inference. The coupled code uses TrainStepper, so it should be agnostic to this choice.
Same with effective_loss_scaling.
If for some reason we really did want to keep this information with the inference stepper, it would go on the training history, since multiple losses can get used.
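A minimal sketch of the separation being proposed here; the class shapes are assumptions, with StepLoss stubbed as nn.Module:

```python
from torch import nn


class Stepper:
    # Inference-only in the proposed end state: no loss state here.
    def __init__(self, module: nn.Module):
        self.module = module


class TrainStepper:
    # Training wrapper: owns the training-only loss object.
    def __init__(self, stepper: Stepper, loss_obj: nn.Module):
        self.stepper = stepper
        self._loss_obj = loss_obj

    @property
    def loss_obj(self) -> nn.Module:
        return self._loss_obj

    @property
    def effective_loss_scaling(self) -> float:
        # Also training-specific per the comment above; value is a stub.
        return 1.0


stepper = Stepper(nn.Linear(2, 2))
train_stepper = TrainStepper(stepper, loss_obj=nn.MSELoss())
```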
This is planned for a future PR. The coupled code uses Stepper, not TrainStepper. This is intentional, since CoupledStepper doesn't use TrainStepper.train_on_batch. For the time being I'm leaving it on Stepper to avoid breaking the current CoupledStepper implementation.
Tasks I have in mind for future PR(s) are:

1. Add `CoupledTrainStepper`, analogous to `TrainStepper` (started by you in "Decouple coupled stepper training and inference" #754).
2. Update the coupled `TrainConfig` to support direct configuration of each component's `loss` and `parameter_init`, rather than relying on these attributes from `StepperConfig`, as `CoupledStepper` currently does.
3. With this done, I plan to simultaneously remove `loss_obj` from `Stepper` and `loss` from `StepperConfig`.

There is a way to do 3 without first doing 1 and 2 by building each component's `StepLoss` directly from the `loss: StepLossConfig` attribute on `StepperConfig`, but I'd prefer to avoid this temporary workaround since it mostly won't survive in later stages of refactoring.

That said, task 2 will also involve building the component loss objects in the coupled stepper code, so if you feel strongly then I'm willing to give it a shot, and hopefully some of the effort will be worthwhile later.
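For concreteness, a sketch of that temporary workaround with everything stubbed as placeholders; the `.build()` method and the component names are hypothetical, not the actual fme API:

```python
from dataclasses import dataclass, field


@dataclass
class StepLoss:  # placeholder for the real StepLoss
    weight: float


@dataclass
class StepLossConfig:  # placeholder; .build() is a hypothetical method
    weight: float = 1.0

    def build(self) -> StepLoss:
        return StepLoss(weight=self.weight)


@dataclass
class StepperConfig:  # placeholder showing only the loss attribute
    loss: StepLossConfig = field(default_factory=StepLossConfig)


def build_component_losses(
    components: dict[str, StepperConfig],
) -> dict[str, StepLoss]:
    # Build each coupled component's StepLoss directly from the loss
    # config still living on its StepperConfig.
    return {name: cfg.loss.build() for name, cfg in components.items()}


losses = build_component_losses(
    {"atmosphere": StepperConfig(), "ocean": StepperConfig()}
)
```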
> This is planned for a future PR. The coupled code uses Stepper, not TrainStepper.
I think this is the correct final state, but it’s not necessary for this PR and is getting in the way of making the changes needed in ace. The coupled code currently depends on the stepper in ace that supports training. I don’t think that should be changed in this PR. We should make that change after more of 1-3 is complete.
I'm trying to scope this out but meeting a lot of resistance. I don't see a simple way to have CoupledStepper use TrainStepper without significant refactors in fme.coupled that I think are better left to the next PR.
- While it's true that `CoupledStepper` was using a "training" `Stepper` before, `TrainStepper` is a very different object from what `Stepper` was. On the other hand, aside from no longer having a `train_on_batch` method, `Stepper` is basically unchanged, so I think `CoupledStepper` should still use it for the time being.
- To do what you want, `CoupledStepperConfig` would have to already use `TrainStepperConfig` in this PR. This is planned but involves significant refactoring in `fme.coupled`. I think we should wait until a later PR to do that refactoring.
After looking at the code a bit I'm realizing this boils down to the choice to update the load_stepper function to return the more minimal inference-stepping class within this PR. I don't think at this point we should revert that change, so we can move forward with the way you had planned to do it.
You would have had more freedom to define the separation between the two classes in this PR if the load_stepper update weren't combined with it in the same PR (i.e. if load_stepper returned an object with the same API as it does in main), but we can move forward with the path you're currently on in this PR.
mcgibbon left a comment:
LGTM!
Add `TrainStepper`, which implements `train_on_batch` for training `Stepper` modules. This intermediate PR makes no changes to training configuration. In future PRs, we plan to remove all training-specific stepper config attributes from `StepperConfig` and use `TrainStepperConfig` at the top level of `TrainConfig`.

Changes:
- Added `StepperConfig.get_train_stepper_config`, a temporary helper to create a `TrainStepperConfig` from the training-specific config attributes on `StepperConfig`.
- Tests updated.
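For orientation, a minimal sketch of the `train_on_batch` pattern this PR introduces; the constructor, optimizer handling, and batch shapes here are all placeholders rather than the actual fme API:

```python
import torch
from torch import nn


class Stepper:
    # Placeholder inference stepper: holds modules and steps forward.
    def __init__(self, module: nn.Module):
        self.modules = [module]

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        return self.modules[0](x)


class TrainStepper:
    # Placeholder training wrapper: owns the loss and the update step.
    def __init__(
        self,
        stepper: Stepper,
        loss_fn: nn.Module,
        optimizer: torch.optim.Optimizer,
    ):
        self.stepper = stepper
        self.loss_fn = loss_fn
        self.optimizer = optimizer

    def train_on_batch(self, x: torch.Tensor, target: torch.Tensor) -> float:
        self.optimizer.zero_grad()
        loss = self.loss_fn(self.stepper.predict(x), target)
        loss.backward()
        self.optimizer.step()
        return loss.item()


stepper = Stepper(nn.Linear(4, 4))
train_stepper = TrainStepper(
    stepper,
    loss_fn=nn.MSELoss(),
    optimizer=torch.optim.SGD(stepper.modules[0].parameters(), lr=1e-3),
)
x, y = torch.randn(8, 4), torch.randn(8, 4)
print(train_stepper.train_on_batch(x, y))
```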