
Add script for converting to a HF model #37

Merged
danbraunai merged 14 commits into dev from feature/convert-to-hf
Aug 13, 2025

Conversation

@danbraunai (Collaborator) commented Aug 2, 2025

Description

TODO:

  • Script for uploading a trained model to HF. It should first convert the model to HF format and then upload it. May require the user to set environment variables to authenticate with HF.

  • Adds convert_to_hf.py which contains convert_llama_model_to_hf for converting our custom Llama models to a HF LlamaForCausalLM model.

  • Tests that the above works for our canonical runs
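The core of a converter like `convert_llama_model_to_hf` is remapping the custom model's `state_dict` keys onto the names `LlamaForCausalLM` expects. A minimal sketch of that key-renaming step, with hypothetical parameter names (the real mapping lives in `convert_to_hf.py` in this PR and will differ):

```python
import re

# Hypothetical renaming table: these patterns are illustrative, not the
# actual mapping used by convert_llama_model_to_hf.
KEY_PATTERNS = [
    (re.compile(r"^transformer\.wte\.weight$"), "model.embed_tokens.weight"),
    (re.compile(r"^transformer\.h\.(\d+)\.attn\.q_proj\.weight$"),
     r"model.layers.\1.self_attn.q_proj.weight"),
    (re.compile(r"^transformer\.ln_f\.weight$"), "model.norm.weight"),
    (re.compile(r"^lm_head\.weight$"), "lm_head.weight"),
]

def rename_state_dict_keys(state_dict: dict) -> dict:
    """Return a copy of state_dict with keys renamed via KEY_PATTERNS.

    Raises KeyError on any unmapped key, so a schema mismatch fails loudly
    instead of being silently dropped.
    """
    out = {}
    for key, value in state_dict.items():
        for pattern, repl in KEY_PATTERNS:
            if pattern.match(key):
                out[pattern.sub(repl, key)] = value
                break
        else:
            raise KeyError(f"No mapping for parameter {key!r}")
    return out
```

Raising on unmapped keys (rather than skipping them) is the design choice that makes a strict `load_state_dict` viable afterwards.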

Misc changes separate from the main thrust of this PR:

  • Adds rms_norm_eps: float = 1e-6 to the config. Without this, we have to hardcode 1e-6 when converting to a HF model, which is dangerous.
  • Re-enable tests in CI
  • Fix all linting errors.

Related Issue

Closes #35

Motivation and Context

Our models are hosted on HF, but we don't have code to convert our Llama models to HF models.

How Has This Been Tested?

Added tests/test_hf_compatibility, which tests that each of our canonical models can be converted and produce the same logits on the same inputs (with both the custom and HF tokenizers).
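The actual test compares model logits with the usual tolerance-based check (e.g. `torch.allclose`). As a self-contained illustration of the idea, here is a tiny pure-Python stand-in that compares nested lists of logits; the function names and tolerance are hypothetical, not taken from the test file:

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two nested lists
    of equal shape (a tiny stand-in for torch.allclose on logit tensors)."""
    if isinstance(a, (int, float)):
        return abs(a - b)
    assert len(a) == len(b), "shape mismatch"
    return max(max_abs_diff(x, y) for x, y in zip(a, b))

def logits_match(custom_logits, hf_logits, atol=1e-5):
    """True if the converted HF model reproduces the custom model's logits
    within an absolute tolerance."""
    return max_abs_diff(custom_logits, hf_logits) <= atol
```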

Does this PR introduce a breaking change?

No

@danbraunai (Collaborator, Author):

Separate from the main thrust of this PR. Not sure why this was commented out.


from simple_stories_train.utils import print0

# pyright: reportAttributeAccessIssue=false
@danbraunai (Collaborator, Author):

I noticed that we had a bunch of linter errors without suppressions like this. I didn't fix them all "properly"; not worth it at the moment.

) # Note that llama 3.1 n_key_value_heads does not scale with n_heads
use_grouped_query_attention: bool = True
flash_attention: bool = True
rms_norm_eps: float = 1e-6
@danbraunai (Collaborator, Author):

Without this, we have to hardcode 1e-6 when converting to a HF model, which is dangerous.

@danbraunai (Collaborator, Author):

Separate from the main thrust of this PR, sorry.


# Load the custom model
model_config = MODEL_CONFIGS[model_size]
custom_model = Llama.from_pretrained(f"SimpleStories/SimpleStories-{model_size}", model_config)
@chandanms (Collaborator) commented Aug 3, 2025:

I think a better test here is to instantiate a model from Llama, convert it, and then compare the outputs, since Llama.from_pretrained(f"SimpleStories/SimpleStories-{model_size}", model_config) loads directly from HF. Something like:

custom_model = Llama(model_config)
hf_model = convert_llama_model_to_hf(custom_model)

# compare the outputs

@danbraunai (Collaborator, Author) commented Aug 3, 2025:

@chandanms Oh, good spot. Llama.from_pretrained is actually broken. The line that loads the weights is:

        model.load_state_dict(state_dict, strict=False)

The keys of the Llama model.state_dict() are completely different from the LlamaForCausalLM HF state_dict that's being downloaded, but strict=False hides everything. If this method were to actually load the LlamaForCausalLM weights into the Llama model, there would need to be a function that's basically the inverse of the new convert_llama_model_to_hf created in this PR.

I think perhaps the Llama.from_pretrained() does work when it's actually loading some kind of Llama weights from HF, as opposed to LlamaForCausalLM weights? Though we don't seem to save any Llama weights on HF (which is reasonable).

I've added Issue #38 for this. This PR shouldn't be sorted until that's done. We should probably have both convert_llama_to_llama_for_causal_lm as well as convert_llama_for_causal_lm_to_llama, with tests for both of them. We should also avoid strict=False for all of these. (cc @lennart-finke, heads up for this bug)

I probably won't get to this tomorrow or maybe Tuesday. But if it's not fixed in a few days then I'll try to make sure to do it as it's a pretty pressing issue.
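To make the failure mode concrete: with strict=True, load_state_dict raises on exactly the mismatch that strict=False silently swallows. A minimal sketch of that key check, in plain Python (function name and structure are illustrative, not from the codebase):

```python
def check_state_dict_keys(model_keys, loaded_keys):
    """Report the mismatch that strict=True would raise on and that
    strict=False silently ignores."""
    model_keys, loaded_keys = set(model_keys), set(loaded_keys)
    return {
        # Parameters the model expects but the checkpoint lacks:
        # with strict=False these keep their random init values.
        "missing": sorted(model_keys - loaded_keys),
        # Checkpoint weights the model has no slot for:
        # with strict=False these are silently discarded.
        "unexpected": sorted(loaded_keys - model_keys),
    }
```

If every custom-Llama key differs from the HF key, both lists contain the whole state dict, i.e. strict=False "loads" nothing at all while reporting success.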

@chandanms (Collaborator):

Ah yes. The converted HF model would have to be converted back into Llama. I think currently Llama.from_pretrained can only load models that are trained locally. I will try to take this up tomorrow.

@danbraunai danbraunai changed the base branch from main to dev August 13, 2025 16:32
@danbraunai danbraunai marked this pull request as ready for review August 13, 2025 16:32
@danbraunai danbraunai merged commit 3502203 into dev Aug 13, 2025
1 check passed


Development

Successfully merging this pull request may close these issues.

Conversion Script to HF