
Fix A_log precision in mamba.py #1028

Closed

eyupcanakman wants to merge 1 commit into ml-explore:main from eyupcanakman:fix/mamba-alog-float32-565


Conversation

@eyupcanakman
Contributor

Fixes #565.

`mamba.py` computes `mx.exp(self.A_log)` without first casting to float32. When the model is loaded in bf16, the exponential loses precision and log-probs diverge from HuggingFace. `mamba2.py`, `plamo2.py`, and `gated_delta.py` all cast `A_log` to float32 at the usage site; this PR applies the same pattern here.
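The effect described above can be sketched outside MLX. The snippet below is illustrative only (it is not the PR's code): it uses NumPy with `float16` standing in for bf16, and shows that computing the exponential in the low-precision dtype adds rounding error on top of the storage loss, while casting to float32 at the usage site avoids that extra loss.

```python
import numpy as np

# Hypothetical stand-in for the mamba.py situation: a log-parameter
# stored in low precision (float16 here, bf16 in the real model).
true_A = 8.0
A_log = np.float16(np.log(true_A))          # parameter as stored

A_lowp = np.exp(A_log)                      # exp computed/rounded in float16
A_cast = np.exp(A_log.astype(np.float32))   # cast to float32 first (the fix)

print(float(A_lowp), float(A_cast))
```

The cast cannot undo the error already baked into the stored `A_log`, but the cast-first result stays strictly closer to the true value because no additional rounding happens after the exponential.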

@angeloskath
Member

`A_log` is actually stored in float32, so there is no need to cast. Is there a saved model with this issue that you can point to? Given that #565 is also stale and potentially unrelated, I will close this; if you do encounter a problem, file an issue and we can reopen this if needed.


Successfully merging this pull request may close these issues.

Numerical instability with BF16 models
