Add script for converting to a HF model #37
Conversation
Separate to the main thrust of this PR. Not sure why this was commented out.
simple_stories_train/models/llama.py (Outdated)

```python
from simple_stories_train.utils import print0

# pyright: reportAttributeAccessIssue=false
```
I noticed that we had a bunch of linter errors without things like this. I didn't fix them all "properly"; not worth it atm.
```python
)  # Note that llama 3.1 n_key_value_heads does not scale with n_heads
use_grouped_query_attention: bool = True
flash_attention: bool = True
rms_norm_eps: float = 1e-6
```
Without this, we have to hardcode 1e-6 when converting to a HF model, which is dangerous.
Separate to the main thrust of this PR, sorry.
tests/test_hf_compatibility.py (Outdated)

```python
# Load the custom model
model_config = MODEL_CONFIGS[model_size]
custom_model = Llama.from_pretrained(f"SimpleStories/SimpleStories-{model_size}", model_config)
```
I think a better test here is to instantiate a model from `Llama`, convert it, and then compare the outputs, since `Llama.from_pretrained(f"SimpleStories/SimpleStories-{model_size}", model_config)` loads directly from HF. Something like:

```python
custom_model = Llama(model_config)
hf_model = convert_llama_model_to_hf(custom_model)
# compare the outputs
```
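To make the "compare the outputs" step concrete, here is a minimal, self-contained sketch of the shape such a test could take. The `Tiny*` classes and `convert_tiny_to_hf` are stand-ins invented for this example; the real test would use the repo's `Llama` and `convert_llama_model_to_hf` instead.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the real classes; only the test shape matters.
class TinyCustomModel(nn.Module):
    def __init__(self, d=8, vocab=16):
        super().__init__()
        self.embed = nn.Embedding(vocab, d)
        self.lm_head = nn.Linear(d, vocab, bias=False)

    def forward(self, ids):
        return self.lm_head(self.embed(ids))

class TinyHFStyleModel(nn.Module):
    """Same computation, but with HF-style parameter naming."""
    def __init__(self, d=8, vocab=16):
        super().__init__()
        self.model = nn.ModuleDict({"embed_tokens": nn.Embedding(vocab, d)})
        self.lm_head = nn.Linear(d, vocab, bias=False)

    def forward(self, ids):
        return self.lm_head(self.model["embed_tokens"](ids))

def convert_tiny_to_hf(custom: TinyCustomModel) -> TinyHFStyleModel:
    hf = TinyHFStyleModel()
    # Explicit key remapping, loaded with strict=True so any mismatch raises
    # instead of being silently dropped.
    remapped = {
        "model.embed_tokens.weight": custom.embed.weight,
        "lm_head.weight": custom.lm_head.weight,
    }
    hf.load_state_dict(remapped, strict=True)
    return hf

custom = TinyCustomModel().eval()
hf = convert_tiny_to_hf(custom).eval()
ids = torch.randint(0, 16, (1, 5))
with torch.no_grad():
    torch.testing.assert_close(custom(ids), hf(ids))
```

The key point is that the converted model is built from a freshly initialized custom model, so the test never depends on what happens to be uploaded to HF.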
@chandanms Oh, good spot. `Llama.from_pretrained` is actually broken. The line which loads in the weights is:

```python
model.load_state_dict(state_dict, strict=False)
```

The keys of the `Llama` model's `state_dict()` are completely different from the HF `LlamaForCausalLM` state_dict that's being downloaded, but `strict=False` hides everything. If this method were to actually load the `LlamaForCausalLM` weights into the `Llama` model, there would need to be a function that's basically the inverse of the new `convert_llama_model_to_hf` created in this PR.

I think perhaps `Llama.from_pretrained()` does work when it's actually loading some kind of `Llama` weights from HF, as opposed to `LlamaForCausalLM` weights? Though we don't seem to save any `Llama` weights on HF (which is reasonable).

I've added Issue #38 for this. This PR shouldn't be considered sorted until that's done. We should probably have both `convert_llama_to_llama_for_causal_lm` as well as `convert_llama_for_causal_lm_to_llama`, with tests for both of them. We should also avoid `strict=False` for all of these. (cc @lennart-finke, heads up for this bug)
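To illustrate why `strict=False` is dangerous here, the key-mismatch problem can be shown without any model code at all. The key names and the `rename_key` mapping below are invented for this example, not the repo's real state_dict keys:

```python
# Hypothetical illustration: two state_dicts with the same weights
# under different naming schemes (names invented for this sketch).
custom_state_dict = {
    "blocks.0.attn.q_proj.weight": "Q0",
    "blocks.0.attn.k_proj.weight": "K0",
}
hf_state_dict = {
    "model.layers.0.self_attn.q_proj.weight": "Q0",
    "model.layers.0.self_attn.k_proj.weight": "K0",
}

def rename_key(hf_key: str) -> str:
    # Invert the naming scheme used when converting custom -> HF.
    return (hf_key
            .replace("model.layers.", "blocks.")
            .replace("self_attn.", "attn."))

remapped = {rename_key(k): v for k, v in hf_state_dict.items()}

# With an explicit remapping the key sets line up, so strict loading
# would succeed; without it, strict=False silently drops every weight.
assert set(remapped) == set(custom_state_dict)
```

With `strict=False`, loading `hf_state_dict` directly would match zero keys and leave the model randomly initialized without any error, which is exactly the failure mode hidden in `Llama.from_pretrained`.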
I probably won't get to this tomorrow or maybe Tuesday. But if it's not fixed in a few days then I'll try to make sure to do it as it's a pretty pressing issue.
Ah yes. The converted HF model would have to be converted back into `Llama`. I think currently `Llama.from_pretrained` can only load models that are trained locally. I will try to take this up tomorrow.
This reverts commit 2c9aedb. Wrong commit with pytest!
…on the backward compatibility
Added conversion scripts and corresponding tests
Description
TODO:
- Script for uploading a trained model to HF. It should first convert to HF, then upload. May require the user to have environment variables to authenticate with HF.
Adds `convert_to_hf.py`, which contains:
- `convert_llama_model_to_hf` for converting our custom Llama models to a HF `LlamaForCausalLM` model.
- Tests that the above works for our canonical runs.
Misc changes separate to the main thrust of this PR:
- Added `rms_norm_eps: float = 1e-6` to the config. Without this, we have to hardcode 1e-6 when converting to a HF model, which is dangerous.

Related Issue
Closes #35
Motivation and Context
Our models are hosted on HF, but we don't have code to convert our Llama models to HF models.
How Has This Been Tested?
Added `tests/test_hf_compatibility.py`, which tests that each of our canonical models can be converted and produce the same logits on the same inputs (with both the custom and HF tokenizers).
Does this PR introduce a breaking change?
No