
[INITIALIZATION] prototype for functional initialization #410

Open
mayank31398 wants to merge 3 commits into main from norm

Conversation

@mayank31398 (Collaborator)

No description provided.

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
@gemini-code-assist (bot, Contributor) left a comment


Code Review

This pull request adds support for custom parameter initialization in the LayerNorm and RMSNorm classes by introducing an initialization_function argument and updating the reset_parameters method. The review identifies a critical multiple-inheritance issue in both classes: DTensorModule is never initialized, which can lead to an AttributeError. The suggestions below call both parent initializers explicitly, in an order that ensures all attributes are set before they are used.
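
To make the failure mode concrete, here is a minimal, self-contained illustration of the pitfall described above (the class names are illustrative stand-ins, not the repo's actual classes):

class TensorParallelMixin:
    def __init__(self):
        self.is_tp_enabled = False

class Norm:
    def __init__(self, dim):
        self.dim = dim

class BrokenNorm(Norm, TensorParallelMixin):
    def __init__(self, dim):
        # super() follows the MRO and resolves to Norm.__init__ only;
        # since Norm.__init__ does not call super().__init__() itself,
        # TensorParallelMixin.__init__ never runs
        super().__init__(dim)

layer = BrokenNorm(8)
layer.is_tp_enabled  # AttributeError: 'BrokenNorm' object has no attribute 'is_tp_enabled'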

Comment on lines +36 to 37
self.initialization_function = initialization_function
super().__init__(normalized_shape, eps=eps)

Severity: critical

There's a multiple inheritance issue here. LayerNorm inherits from nn.LayerNorm and DTensorModule, but DTensorModule.__init__ is never called. This will cause an AttributeError when self.is_tp_enabled is accessed later. You should explicitly call the initializers of both parent classes.

Since nn.LayerNorm.__init__ calls reset_parameters(), which you've modified to depend on self.initialization_function, the order of initialization is important. The DTensorModule should be initialized first.

Suggested change
self.initialization_function = initialization_function
super().__init__(normalized_shape, eps=eps)
DTensorModule.__init__(self)
self.initialization_function = initialization_function
nn.LayerNorm.__init__(self, normalized_shape, eps=eps)
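
For reference, a runnable sketch of the class with this suggestion applied. DTensorModule here is a stub standing in for the repo's class, and the reset_parameters override is inferred from the review summary rather than copied from the diff:

import torch
import torch.nn as nn

class DTensorModule(nn.Module):
    # stub for the repo's DTensorModule, assumed to set tensor-parallel state
    def __init__(self):
        super().__init__()
        self.is_tp_enabled = False

def zeros_init(module):
    # example initialization_function: zero the affine weight in place
    with torch.no_grad():
        module.weight.zero_()

class LayerNorm(nn.LayerNorm, DTensorModule):
    def __init__(self, normalized_shape, eps=1e-5, initialization_function=None):
        DTensorModule.__init__(self)                  # set is_tp_enabled first
        self.initialization_function = initialization_function
        nn.LayerNorm.__init__(self, normalized_shape, eps=eps)  # triggers reset_parameters()

    def reset_parameters(self):
        # inferred behavior: fall back to the default init when no function is given
        if getattr(self, "initialization_function", None) is None:
            super().reset_parameters()
        else:
            self.initialization_function(self)

layer = LayerNorm(8, initialization_function=zeros_init)
print(layer.weight.sum().item())  # 0.0 — the custom init ran
print(layer.is_tp_enabled)        # False — DTensorModule state is present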

Comment on lines +82 to 83
self.initialization_function = initialization_function
super().__init__(normalized_shape, eps=eps)

Severity: critical

Similar to LayerNorm, this __init__ method has a multiple inheritance issue. RMSNorm inherits from nn.RMSNorm and DTensorModule, but DTensorModule.__init__ is never called. This will cause an AttributeError when self.is_tp_enabled is accessed. You should explicitly call the initializers of both parent classes.

Suggested change
self.initialization_function = initialization_function
super().__init__(normalized_shape, eps=eps)
DTensorModule.__init__(self)
self.initialization_function = initialization_function
nn.RMSNorm.__init__(self, normalized_shape, eps=eps)
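
Usage would then look like the sketch below; the initialization_function contract (receive the module, mutate its parameters in place) is inferred from the snippets above, and scaled_init is a hypothetical example. The stock nn.RMSNorm (PyTorch >= 2.4) stands in for the repo's class:

import torch
import torch.nn as nn

def scaled_init(module):
    # hypothetical custom init: draw the RMSNorm scale from N(1.0, 0.02)
    with torch.no_grad():
        module.weight.normal_(mean=1.0, std=0.02)

rms = nn.RMSNorm(4096, eps=1e-6)  # stock module, for illustration only
scaled_init(rms)  # what reset_parameters would do when initialization_function is set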

