feat: optional CUDA/GPU acceleration for embeddings #66
Open
sattva1 wants to merge 3 commits into harshkedia177:main from
Conversation
added 3 commits
April 3, 2026 13:15
…ng pipeline

- Add `configure_cuda()` / `validate_cuda()` public API and `_resolve_cuda()` to honour both the `--cuda` flag and the `AXON_CUDA` env var
- Key the model cache on a `(model_name, cuda)` tuple to prevent CPU/GPU model aliasing; pass `cuda=True` to `TextEmbedding` and surface ONNX `CUDAExecutionProvider` fallback warnings as `RuntimeError`
- Expose the `--cuda` flag on the `analyze`, `watch`, `host`, and `serve` commands via a shared `_configure_and_validate_cuda()` helper
…t OOM

fastembed defaults to `Device.AUTO`, which auto-detects and uses CUDA when `onnxruntime-gpu` is installed. On GPUs with limited VRAM (e.g. 8 GB), the nomic model with `batch_size=32` causes OOM via a 9.5 GB BiasSoftmax allocation. Pass `cuda=False` explicitly in the CPU path.

Also fix test isolation: reset `_cuda_enabled` and the `AXON_CUDA` env var in the autouse fixture to prevent state leaking between tests.
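A minimal sketch of the CPU-path fix. The `cuda` kwarg on `TextEmbedding` is confirmed by the PR itself; the model id and call shape here are illustrative assumptions, not the PR's exact code:

```python
from fastembed import TextEmbedding

# Explicit cuda=False: without it, fastembed's Device.AUTO auto-detects
# CUDA when onnxruntime-gpu is installed and can OOM on low-VRAM GPUs.
model = TextEmbedding(
    model_name="nomic-ai/nomic-embed-text-v1.5",  # assumed fastembed model id
    cuda=False,
)
```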
The nomic-embed-text-v1.5 model has 12 attention heads and a 2048-token context, so each batch element's attention matrices total ~192 MB (12 × 2048² × 4-byte floats). At batch_size=32 this is ~6.4 GB for attention alone, causing OOM on both CPU (physical memory) and GPUs with ≤ 8 GB VRAM. Batch size 8 keeps peak memory under ~2 GB. This was not an issue with the previous BAAI/bge-small-en-v1.5 model (6 heads, 512-token limit), but the batch size was never adjusted when the model was upgraded.
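A quick back-of-envelope check of those figures, assuming fp32 attention scores and one full seq_len × seq_len matrix per head:

```python
heads, seq_len, fp32_bytes = 12, 2048, 4

per_element = heads * seq_len**2 * fp32_bytes
print(f"{per_element / 2**20:.0f} MiB per batch element")   # 192 MiB

for batch in (32, 8):
    print(f"batch {batch}: {batch * per_element / 2**30:.1f} GiB")
# batch 32: 6.0 GiB  (the ~6.4 GB figure, modulo GiB vs GB rounding)
# batch 8:  1.5 GiB  (under the ~2 GB peak cited above)
```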
Author
Two additional fixes pushed:

1. Explicit `cuda=False` in the CPU path: fastembed defaults to `Device.AUTO` and will pick CUDA whenever `onnxruntime-gpu` is installed.
2. Batch size 32 → 8 (the default batch size was never adjusted when the model changed from `BAAI/bge-small-en-v1.5` to `nomic-embed-text-v1.5`).
Summary
Closes #64.
Adds opt-in CUDA support for the embedding pipeline, reducing embedding time from minutes to seconds on GPU-capable machines.
Two activation paths:

- `--cuda` CLI flag on `analyze`, `watch`, `host`, and `serve`
- `AXON_CUDA=1` environment variable (works across all commands without per-command flags)
Design

Uses a module-level configuration in `embedder.py` rather than threading a `cuda` parameter through every function signature. `_get_model()` reads the CUDA state at call time via `_resolve_cuda()`, which checks both the programmatic flag and the env var. This means all embedding call sites (the pipeline, the watcher, and search-time `embed_query()` from MCP/web) automatically use the GPU when enabled. A sketch of the pattern follows.
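A minimal sketch of that pattern. The names `configure_cuda`, `_resolve_cuda`, `_cuda_enabled`, and `AXON_CUDA` come from the PR; the function bodies below are illustrative assumptions, not the actual implementation:

```python
import os

# Module-level CUDA state, set programmatically by the CLI's --cuda flag.
_cuda_enabled: bool | None = None

def configure_cuda(enabled: bool) -> None:
    """Programmatic switch, called by the CLI when --cuda is passed."""
    global _cuda_enabled
    _cuda_enabled = enabled

def _resolve_cuda() -> bool:
    """Effective CUDA setting: the explicit flag wins, then AXON_CUDA."""
    if _cuda_enabled is not None:
        return _cuda_enabled
    return os.environ.get("AXON_CUDA", "").strip() in {"1", "true", "yes"}

# _get_model() would read _resolve_cuda() at call time and key its cache
# on (model_name, cuda), so CPU and GPU instances never alias each other.
```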
CUDA validation

When CUDA is requested, `_get_model()` captures fastembed's `RuntimeWarning` on `CUDAExecutionProvider` fallback and converts it to a `RuntimeError` with actionable install instructions. CLI commands call `validate_cuda()` before the pipeline starts, so the error surfaces immediately rather than being swallowed by the pipeline's broad `except Exception` handler. A sketch of the fallback check appears below.
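A hypothetical sketch of that fallback detection. The warning category and the `CUDAExecutionProvider` substring follow the PR's description; the exact warning text emitted by fastembed/onnxruntime may differ:

```python
import warnings

from fastembed import TextEmbedding

def _load_model_with_cuda(model_name: str) -> TextEmbedding:
    # Record warnings raised during model init instead of printing them.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model = TextEmbedding(model_name=model_name, cuda=True)
    # Promote a silent CPU fallback to a hard, actionable error.
    for w in caught:
        if issubclass(w.category, RuntimeWarning) and "CUDAExecutionProvider" in str(w.message):
            raise RuntimeError(
                "CUDA was requested but onnxruntime fell back to CPU. "
                "Install GPU support, e.g.: pip install onnxruntime-gpu"
            )
    return model
```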
Changes

- `src/axon/core/embeddings/embedder.py`: `configure_cuda()`, `_resolve_cuda()`, `validate_cuda()`; `_get_model()` uses a `(model_name, cuda)` compound cache key and post-init CUDA fallback detection.
- `src/axon/cli/main.py`: `_configure_and_validate_cuda()` helper; `--cuda` flag on `analyze`, `watch`, `host`, `serve`.
- `tests/core/test_embedder.py`: 15 new tests covering the flag, the env var, cache separation, and fallback detection.

Usage
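Illustrative invocations based on the two activation paths above (the `axon` entry-point name is an assumption):

```sh
# Per-command opt-in via the --cuda flag
axon analyze --cuda
axon serve --cuda

# Global opt-in via the environment variable (covers all commands)
AXON_CUDA=1 axon watch
```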