Replace candle-core with burn-store in bert-burn #97
Conversation
The bert-burn loader was the last user of candle-core in the workspace. It used candle to parse safetensors files and then manually constructed each Burn `*Record<B>` tensor by tensor. Replace this with burn-store, matching the pattern used by albert-burn and minilm-burn: `SafetensorsStore` plus `KeyRemapper` regex patterns, with `PyTorchToBurnAdapter` handling LayerNorm `weight`/`bias` -> `gamma`/`beta` and Linear weight transposition.

- Drop candle-core from the workspace (only burn-candle pulls it transitively)
- Rewrite bert-burn/src/loader.rs with burn-store (350 -> ~140 lines)
- Expose `BertEmbeddings::word_embeddings_weight()` so the MLM loader can tie the decoder weight to the word embeddings (matches HF runtime)
- Remove candle-based `from_safetensors` methods from `BertModel`/`BertMaskedLM`
- Switch HF download to the sync API and drop the tokio dep
- Update both examples to call the new free functions
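The key-remapping idea can be sketched with a stdlib-only stand-in. The real loader uses burn-store's `KeyRemapper` with regex patterns; the literal renames below are illustrative assumptions, not the actual patterns:

```rust
// Hypothetical sketch of HF-checkpoint-key -> Burn-module-path remapping.
// The actual bert-burn loader uses burn-store's KeyRemapper with regexes;
// these literal string renames are illustrative only.
fn remap_key(key: &str) -> String {
    key.replace("LayerNorm.weight", "norm.gamma") // adapter-style gamma rename
        .replace("LayerNorm.bias", "norm.beta") // adapter-style beta rename
}

fn main() {
    let hf_key = "embeddings.LayerNorm.weight";
    println!("{} -> {}", hf_key, remap_key(hf_key));
}
```

Keys with no matching pattern pass through unchanged, which is the property that makes declarative mappings compact.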
CC @laggui
Pull request overview
This PR removes the last direct candle-core usage from bert-burn by rewriting safetensors loading to use burn-store, aligning BERT/RoBERTa loading with the pattern already used in other model crates in the workspace.
Changes:
- Replaced the candle-based safetensors loader with a burn-store (`SafetensorsStore` + `KeyRemapper` + `PyTorchToBurnAdapter`) implementation.
- Removed `from_safetensors` record constructors from BERT model types and updated examples to use `load_pretrained`/`load_pretrained_masked_lm`.
- Dropped `candle-core`, `tokio`, and the stale `safetensors` feature from `bert-burn` dependencies; updated docs accordingly.
Reviewed changes
Copilot reviewed 8 out of 9 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| bert-burn/src/model.rs | Removes candle-based from_safetensors APIs from model types. |
| bert-burn/src/loader.rs | New burn-store-based safetensors loading + sync HF download. |
| bert-burn/src/embedding.rs | Adds accessor needed for MLM decoder weight tying. |
| bert-burn/examples/masked.rs | Updates example to new loader API (load_pretrained_masked_lm). |
| bert-burn/examples/infer-embedding.rs | Updates example to new loader API (load_pretrained). |
| bert-burn/README.md | Updates usage instructions to burn-flex + burn-store approach. |
| bert-burn/Cargo.toml | Removes candle-core/tokio, adds burn-store dependency. |
| Cargo.toml | Removes workspace-level direct candle-core dependency entry. |
| Cargo.lock | Locks updated dependency graph reflecting the removal. |
Comments suppressed due to low confidence (1)
bert-burn/src/loader.rs:152
`download_hf_model` still panics on errors and returns a tuple, but this module now defines `LoadError::{Download, Config}` and the other loaders (e.g., minilm/albert) use `Result`-based download + config loading. Consider switching `download_hf_model` (and ideally `load_model_config`) to return `Result<_, LoadError>` so callers can handle failures without panicking, or drop the unused error variants if panics are the intended API.
```rust
/// Download model config and weights from Hugging Face Hub.
/// Cached files are reused.
pub fn download_hf_model(model_name: &str) -> (PathBuf, PathBuf) {
    let api = hf_hub::api::sync::Api::new().expect("Failed to create HF API client");
    let repo = api.model(model_name.to_string());
    let model_filepath = repo.get("model.safetensors").unwrap_or_else(|_| {
        panic!(
            "Failed to download: {} weights with name: model.safetensors from HuggingFace Hub",
            model_name
        )
    });
    let config_filepath = repo.get("config.json").unwrap_or_else(|_| {
        panic!(
            "Failed to download: {} config with name: config.json from HuggingFace Hub",
            model_name
        )
    });
    (config_filepath, model_filepath)
}
```
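The reviewer's suggestion amounts to replacing the `expect`/`panic!` calls with `?` plus `map_err` into the module's error enum. Here is a stdlib-only sketch of that shape, using `std::fs::metadata` as a stand-in for the real `hf_hub` `repo.get(...)` call; the `LoadError` variants follow the ones the review mentions, and `fetch_file` is a hypothetical helper:

```rust
use std::path::PathBuf;

// Hypothetical stand-in for the error enum the review mentions.
#[derive(Debug)]
pub enum LoadError {
    Download(String),
    Config(String),
}

// Sketch of the Result-based shape: map each failure into LoadError and
// bubble it up with `?` instead of panicking. std::fs::metadata stands in
// for the real network fetch.
pub fn fetch_file(path: &str) -> Result<PathBuf, LoadError> {
    std::fs::metadata(path)
        .map_err(|e| LoadError::Download(format!("failed to fetch {path}: {e}")))?;
    Ok(PathBuf::from(path))
}

fn main() {
    // A missing file now surfaces as an Err the caller can handle,
    // rather than a panic deep inside the loader.
    match fetch_file("/no/such/checkpoint.safetensors") {
        Ok(p) => println!("got {p:?}"),
        Err(e) => println!("recoverable error: {e:?}"),
    }
}
```

Callers that still want the old fail-fast behavior can `.expect(...)` at the call site, which keeps the policy decision out of the library.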
```rust
use burn_store::{KeyRemapper, ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
```
ModuleSnapshot is imported but never used in this module. Please remove the import or add an explicit #[allow(unused_imports)] if it’s intentionally kept for future debugging, to avoid unused-import warnings.
```diff
- use burn_store::{KeyRemapper, ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
+ use burn_store::{KeyRemapper, PyTorchToBurnAdapter, SafetensorsStore};
```
Keeping ModuleSnapshot — it's the trait that provides load_from, which the module calls on the next line (model.load_from(&mut store)). Removing the import breaks the build. Matches the pattern in albert-burn and minilm-burn which both import it for the same reason.
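The point generalizes: Rust only resolves a trait-provided method when the trait is in scope, so an import can look unused while being load-bearing. A minimal stdlib illustration, with hypothetical `Snapshot`/`Model` names standing in for burn-store's trait and the model type:

```rust
mod store {
    // Stand-in for burn-store's ModuleSnapshot: the method lives on the
    // trait (here with a default body), not on the struct itself.
    pub trait Snapshot {
        fn load_from(&self) -> &'static str {
            "loaded"
        }
    }
    pub struct Model;
    impl Snapshot for Model {}
}

// Removing `Snapshot` from this import makes `load_from` unresolvable,
// even though `Model` implements it: the "unused" import is required.
use store::{Model, Snapshot};

fn demo() -> &'static str {
    Model.load_from()
}

fn main() {
    println!("{}", demo());
}
```

This is why automated "unused import" suggestions on trait imports need a compile check before applying.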
```rust
let mut store = SafetensorsStore::from_file(checkpoint_path)
    .with_from_adapter(PyTorchToBurnAdapter)
    .remap(remapper)
    .allow_partial(true);

model
    .load_from(&mut store)
    .map_err(|e| LoadError::Store(e.to_string()))?;
```
allow_partial(true) on the store can silently keep randomly-initialized parameters if a key mapping is wrong or a tensor is missing, which can make loading failures hard to detect. For BertModel (embeddings/encoder/(optional)pooler), consider removing allow_partial(true) and letting load_from fail on missing tensors, or explicitly validating that only truly-optional parameters are skipped.
Good catch — fixed in f344120. Dropped allow_partial(true) from load_pretrained (BertModel) so missing tensors now error with a specific path, instead of silently falling back to random init. Kept it on load_pretrained_masked_lm because the MLM decoder weight is intentionally not in the HF checkpoint (tied to word_embeddings.weight) and we supply it manually post-load; added a comment explaining why.
- Remove `allow_partial(true)` from `load_pretrained` so missing model tensors error out instead of silently keeping random weights. Keep it on `load_pretrained_masked_lm`, where `lm_head.decoder.weight` is intentionally supplied post-load via weight tying.
- Convert `download_hf_model` and `load_model_config` to return `Result<_, LoadError>` instead of panicking, matching the albert-burn / minilm-burn patterns. Update examples accordingly.
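The tradeoff behind `allow_partial` can be made concrete with a toy loader over name-to-value maps. This is a hypothetical sketch of the semantics, not burn-store's implementation:

```rust
use std::collections::HashMap;

// Toy model of partial loading: params start "randomly initialized" and are
// overwritten by checkpoint entries. With allow_partial=false, a missing
// checkpoint key is an error naming the offending parameter path; with
// allow_partial=true, the init is silently kept (the MLM decoder-weight
// case, which is filled in post-load via weight tying).
fn load_params(
    params: &mut HashMap<String, f32>,
    checkpoint: &HashMap<String, f32>,
    allow_partial: bool,
) -> Result<(), String> {
    for (name, value) in params.iter_mut() {
        match checkpoint.get(name) {
            Some(v) => *value = *v,
            None if allow_partial => {} // keep the (possibly random) init
            None => return Err(format!("missing tensor: {name}")),
        }
    }
    Ok(())
}

fn main() {
    let mut params = HashMap::from([("lm_head.decoder.weight".to_string(), 0.0)]);
    let checkpoint = HashMap::new(); // the HF checkpoint lacks the tied decoder weight
    // Strict load errors out; partial load keeps the init for post-load tying.
    assert!(load_params(&mut params, &checkpoint, false).is_err());
    assert!(load_params(&mut params, &checkpoint, true).is_ok());
    println!("strict load rejects missing tensors; partial load keeps the init");
}
```

Strict-by-default with an explicit opt-out per call site, as the commit does, keeps the silent-fallback surface as small as possible.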
Addressed the suppressed Copilot note about `download_hf_model` panicking. Re-ran both examples against the new `Result`-based loader.
Summary
- Removes `candle-core` from the workspace. `bert-burn` was the last direct user; `burn-candle` still pulls it transitively, but nothing in this repo depends on it directly.
- Rewrites `bert-burn/src/loader.rs` on top of `burn-store`, following the pattern already used by `albert-burn` and `minilm-burn`: `SafetensorsStore` + `KeyRemapper` regexes, with `PyTorchToBurnAdapter` handling LayerNorm `weight`/`bias` → `gamma`/`beta` and Linear weight transposition. The loader shrinks from ~350 lines of per-module candle-to-Burn tensor shuffling to ~140 lines of declarative key mappings.
- Exposes `BertEmbeddings::word_embeddings_weight()` so the MLM loader can tie the decoder weight to the word embeddings after loading (mirroring HF's runtime behavior; RoBERTa checkpoints store `lm_head.bias` but not `lm_head.decoder.weight`).
- Removes the candle-based `from_safetensors` methods from `BertModel`/`BertMaskedLM`/`BertLMHead`; callers now use the free functions `load_pretrained` and `load_pretrained_masked_lm`.
- Switches the HF download to the sync API and drops the `tokio` dep. Also removes the stale `safetensors` feature (the old code enabled it via `candle-core/default`, but the loader never gated on it).

Test plan

- `cargo build --workspace`
- `cargo run -p bert-burn --example infer-embedding --release` produces sentence embeddings for RoBERTa-base (shape `[3, 768]`)
- `cargo run -p bert-burn --example masked --release` fills `"Paris is the <mask> of France"` → `capital` (0.86), confirming the remap, adapter, and tied-weight step all work end-to-end
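The weight-tying step described in the summary, copying the word-embedding matrix into the MLM decoder after the store load, can be sketched with plain matrices. The struct names here are hypothetical; the real code operates on Burn tensors obtained via `BertEmbeddings::word_embeddings_weight()`:

```rust
// Toy sketch of post-load weight tying: the RoBERTa checkpoint has no
// lm_head.decoder.weight, so after loading, the decoder reuses the
// word-embedding matrix (vocab_size x hidden), as HF does at runtime.
struct Embeddings {
    word_embeddings_weight: Vec<Vec<f32>>, // vocab_size x hidden
}

struct LmHead {
    decoder_weight: Vec<Vec<f32>>, // left at its init by the partial load
}

fn tie_weights(head: &mut LmHead, embeddings: &Embeddings) {
    // Copy (in Burn, this would clone the tensor) so both modules share values.
    head.decoder_weight = embeddings.word_embeddings_weight.clone();
}

fn main() {
    let embeddings = Embeddings {
        word_embeddings_weight: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
    };
    let mut head = LmHead {
        decoder_weight: vec![vec![0.0; 2]; 2],
    };
    tie_weights(&mut head, &embeddings);
    assert_eq!(head.decoder_weight, embeddings.word_embeddings_weight);
    println!("decoder weight tied to word embeddings");
}
```

This is why `allow_partial` stays enabled only on the masked-LM path: the "missing" tensor is supplied here rather than by the checkpoint.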