
Replace candle-core with burn-store in bert-burn #97

Merged
antimora merged 2 commits into tracel-ai:main from antimora:bert-burn-store-migration on Apr 17, 2026

Conversation

@antimora (Collaborator)

Summary

  • Drops candle-core from the workspace. bert-burn was the last direct user; burn-candle still pulls it transitively but nothing in this repo depends on it directly.
  • Rewrites bert-burn/src/loader.rs on top of burn-store, following the pattern already used by albert-burn and minilm-burn: SafetensorsStore + KeyRemapper regexes, with PyTorchToBurnAdapter handling the LayerNorm weight/bias → gamma/beta rename and Linear weight transposition. The loader shrinks from ~350 lines of per-module candle-to-Burn tensor shuffling to ~140 lines of declarative key mappings.
  • Adds BertEmbeddings::word_embeddings_weight() so the MLM loader can tie the decoder weight to the word embeddings after loading (mirroring HF's runtime behavior; RoBERTa checkpoints store lm_head.bias but not lm_head.decoder.weight).
  • Removes the candle-based from_safetensors methods from BertModel / BertMaskedLM / BertLMHead; callers now use the free functions load_pretrained and load_pretrained_masked_lm.
  • Switches HF download to the sync API and drops the tokio dep. Also removes the stale safetensors feature (the old code enabled it via candle-core/default but the loader never gated on it).

Test plan

  • cargo build --workspace
  • cargo run -p bert-burn --example infer-embedding --release produces sentence embeddings for RoBERTa-base (shape [3, 768])
  • cargo run -p bert-burn --example masked --release fills "Paris is the <mask> of France" with capital (0.86), confirming the remap, adapter, and tied-weight step all work end-to-end

The bert-burn loader was the last user of candle-core in the workspace. It
used candle to parse safetensors files and then manually constructed each
Burn *Record<B> tensor-by-tensor. Replace this with burn-store, matching
the pattern used by albert-burn and minilm-burn: SafetensorsStore plus
KeyRemapper regex patterns, with PyTorchToBurnAdapter handling LayerNorm
weight/bias -> gamma/beta and Linear weight transposition.

- Drop candle-core from the workspace (only burn-candle pulls it transitively)
- Rewrite bert-burn/src/loader.rs with burn-store (350 -> ~140 lines)
- Expose BertEmbeddings::word_embeddings_weight() so the MLM loader can
  tie the decoder weight to the word embeddings (matches HF runtime)
- Remove candle-based from_safetensors methods from BertModel/BertMaskedLM
- Switch HF download to sync API and drop tokio dep
- Update both examples to call the new free functions
@antimora (Collaborator, Author)

CC @laggui


Copilot AI left a comment


Pull request overview

This PR removes the last direct candle-core usage from bert-burn by rewriting safetensors loading to use burn-store, aligning BERT/RoBERTa loading with the pattern already used in other model crates in the workspace.

Changes:

  • Replaced the candle-based safetensors loader with a burn-store (SafetensorsStore + KeyRemapper + PyTorchToBurnAdapter) implementation.
  • Removed from_safetensors record constructors from BERT model types and updated examples to use load_pretrained / load_pretrained_masked_lm.
  • Dropped candle-core, tokio, and the stale safetensors feature from bert-burn dependencies; updated docs accordingly.

Reviewed changes

Copilot reviewed 8 out of 9 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| bert-burn/src/model.rs | Removes candle-based from_safetensors APIs from model types. |
| bert-burn/src/loader.rs | New burn-store-based safetensors loading + sync HF download. |
| bert-burn/src/embedding.rs | Adds accessor needed for MLM decoder weight tying. |
| bert-burn/examples/masked.rs | Updates example to new loader API (load_pretrained_masked_lm). |
| bert-burn/examples/infer-embedding.rs | Updates example to new loader API (load_pretrained). |
| bert-burn/README.md | Updates usage instructions to burn-flex + burn-store approach. |
| bert-burn/Cargo.toml | Removes candle-core/tokio, adds burn-store dependency. |
| Cargo.toml | Removes workspace-level direct candle-core dependency entry. |
| Cargo.lock | Locks updated dependency graph reflecting the removal. |
Comments suppressed due to low confidence (1)

bert-burn/src/loader.rs:152

  • download_hf_model still panics on errors and returns a tuple, but this module now defines LoadError::{Download, Config} and the other loaders (e.g., minilm/albert) use Result-based download + config loading. Consider switching download_hf_model (and ideally load_model_config) to return Result<_, LoadError> so callers can handle failures without panicking, or drop the unused error variants if panics are the intended API.
```rust
/// Download model config and weights from Hugging Face Hub.
/// Cached files are reused.
pub fn download_hf_model(model_name: &str) -> (PathBuf, PathBuf) {
    let api = hf_hub::api::sync::Api::new().expect("Failed to create HF API client");
    let repo = api.model(model_name.to_string());

    let model_filepath = repo.get("model.safetensors").unwrap_or_else(|_| {
        panic!(
            "Failed to download: {} weights with name: model.safetensors from HuggingFace Hub",
            model_name
        )
    });

    let config_filepath = repo.get("config.json").unwrap_or_else(|_| {
        panic!(
            "Failed to download: {} config with name: config.json from HuggingFace Hub",
            model_name
        )
    });

    (config_filepath, model_filepath)
}
```


Comment thread: bert-burn/src/loader.rs

Context (old candle-based code being replaced):

```rust
let data = TensorData::new(data, Shape::new(array));

Tensor::<B, 1>::from_floats(data, &device.clone())
```

Line under discussion:

```rust
use burn_store::{KeyRemapper, ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
```

Copilot AI Apr 17, 2026


ModuleSnapshot is imported but never used in this module. Please remove the import or add an explicit #[allow(unused_imports)] if it’s intentionally kept for future debugging, to avoid unused-import warnings.

Suggested change:

```diff
-use burn_store::{KeyRemapper, ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
+use burn_store::{KeyRemapper, PyTorchToBurnAdapter, SafetensorsStore};
```

@antimora (Collaborator, Author)


Keeping ModuleSnapshot — it's the trait that provides load_from, which the module calls on the next line (model.load_from(&mut store)). Removing the import breaks the build. Matches the pattern in albert-burn and minilm-burn which both import it for the same reason.

Comment thread bert-burn/src/loader.rs
Comment on lines +80 to +87:

```rust
let mut store = SafetensorsStore::from_file(checkpoint_path)
    .with_from_adapter(PyTorchToBurnAdapter)
    .remap(remapper)
    .allow_partial(true);

model
    .load_from(&mut store)
    .map_err(|e| LoadError::Store(e.to_string()))?;
```

Copilot AI Apr 17, 2026


allow_partial(true) on the store can silently keep randomly-initialized parameters if a key mapping is wrong or a tensor is missing, which can make loading failures hard to detect. For BertModel (embeddings/encoder/(optional)pooler), consider removing allow_partial(true) and letting load_from fail on missing tensors, or explicitly validating that only truly-optional parameters are skipped.

@antimora (Collaborator, Author)


Good catch — fixed in f344120. Dropped allow_partial(true) from load_pretrained (BertModel) so missing tensors now error with a specific path, instead of silently falling back to random init. Kept it on load_pretrained_masked_lm because the MLM decoder weight is intentionally not in the HF checkpoint (tied to word_embeddings.weight) and we supply it manually post-load; added a comment explaining why.

- Remove allow_partial(true) from load_pretrained so missing model tensors
  error out instead of silently keeping random weights. Keep it on
  load_pretrained_masked_lm where lm_head.decoder.weight is intentionally
  supplied post-load via weight tying.
- Convert download_hf_model and load_model_config to return
  Result<_, LoadError> instead of panicking, matching the albert-burn /
  minilm-burn patterns. Update examples accordingly.
@antimora (Collaborator, Author)

Addressed the suppressed Copilot note about download_hf_model panicking: converted download_hf_model and load_model_config to return Result<_, LoadError> in f344120, matching the albert-burn / minilm-burn pattern. Examples updated to propagate the errors (still .expect(...) at the example boundary since that's what the other model examples do).

Re-ran both examples against roberta-base post-change — sentence embeddings produce shape [3, 768] and the MLM example still predicts capital (0.86) for "Paris is the <mask> of France", confirming the weight tying path still works without the loader's previous allow_partial(true).

@antimora antimora merged commit cb0d97e into tracel-ai:main Apr 17, 2026
2 checks passed