[INITIALIZATION] prototype for functional initialization #410
mayank31398 wants to merge 3 commits into main
Conversation
Code Review
This pull request adds support for custom parameter initialization in the LayerNorm and RMSNorm classes by introducing an initialization_function argument and updating the reset_parameters method. The reviewer identified critical multiple inheritance issues in both classes where DTensorModule was not being properly initialized, which could lead to an AttributeError. Suggestions were provided to explicitly call the parent initializers in the correct order to ensure all attributes are correctly set before use.
```python
self.initialization_function = initialization_function
super().__init__(normalized_shape, eps=eps)
```
There's a multiple inheritance issue here. LayerNorm inherits from nn.LayerNorm and DTensorModule, but DTensorModule.__init__ is never called. This will cause an AttributeError when self.is_tp_enabled is accessed later. You should explicitly call the initializers of both parent classes.
Since nn.LayerNorm.__init__ calls reset_parameters(), which you've modified to depend on self.initialization_function, the order of initialization is important. The DTensorModule should be initialized first.
```diff
-self.initialization_function = initialization_function
-super().__init__(normalized_shape, eps=eps)
+DTensorModule.__init__(self)
+self.initialization_function = initialization_function
+nn.LayerNorm.__init__(self, normalized_shape, eps=eps)
```
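The ordering the review describes can be sketched with a minimal, self-contained example. The `DTensorModule` stub, its `is_tp_enabled` default, and the `fn(weight, bias)` signature for `initialization_function` are assumptions for illustration, not the repository's actual definitions:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the repository's DTensorModule; the real class
# and its attributes are assumptions here, used only to illustrate ordering.
class DTensorModule:
    def __init__(self):
        self.is_tp_enabled = False

class LayerNorm(nn.LayerNorm, DTensorModule):
    def __init__(self, normalized_shape, eps=1e-5, initialization_function=None):
        # Initialize DTensorModule first so attributes such as is_tp_enabled
        # exist before anything else runs.
        DTensorModule.__init__(self)
        # Set initialization_function before nn.LayerNorm.__init__, because
        # that constructor calls reset_parameters(), which reads it.
        self.initialization_function = initialization_function
        nn.LayerNorm.__init__(self, normalized_shape, eps=eps)

    def reset_parameters(self):
        # getattr guards against reset_parameters() being invoked before
        # our __init__ has set the attribute (e.g. via other call paths).
        fn = getattr(self, "initialization_function", None)
        if fn is None:
            # Fall back to the default LayerNorm initialization.
            super().reset_parameters()
        else:
            # Assumed signature: fn(weight, bias) mutates parameters in place.
            fn(self.weight, self.bias)
```

Swapping the two `__init__` calls reproduces the bug: `nn.LayerNorm.__init__` would invoke `reset_parameters()` before `initialization_function` exists, and any later access to `self.is_tp_enabled` would raise `AttributeError`.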
```python
self.initialization_function = initialization_function
super().__init__(normalized_shape, eps=eps)
```
Similar to LayerNorm, this __init__ method has a multiple inheritance issue. RMSNorm inherits from nn.RMSNorm and DTensorModule, but DTensorModule.__init__ is never called. This will cause an AttributeError when self.is_tp_enabled is accessed. You should explicitly call the initializers of both parent classes.
```diff
-self.initialization_function = initialization_function
-super().__init__(normalized_shape, eps=eps)
+DTensorModule.__init__(self)
+self.initialization_function = initialization_function
+nn.RMSNorm.__init__(self, normalized_shape, eps=eps)
```