
Skip prefix block TKG BHSD materialization #155

Open

dario-fumarola wants to merge 1 commit into aws-neuron:main from dario-fumarola:perf/prefix-block-tkg-cache-path

Conversation

dario-fumarola commented May 5, 2026

Description

Avoid redundant BHSD materialization for prefix-cached token generation when block KV layout and block TKG attention are enabled.

Today NeuronBaseModel.get_model_output() calls kv_mgr.get_cache() before entering the decoder layers. For prefix caching with BlockKVCacheManager, that call gathers block KV and reshapes it into flat BHSD. The block TKG attention path then fetches raw block KV through kv_mgr._fetch_cache() and indexes it with active_block_table, so the earlier BHSD materialization is never consumed.

This change skips that eager BHSD materialization only for the block-TKG token-generation path. Other paths keep the existing behavior.
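
A minimal sketch of where the skip sits, assuming illustrative names for everything except get_model_output, get_cache, _fetch_cache, and active_block_table (which come from the description above); the real NeuronBaseModel signature and decoder-layer plumbing differ:

```python
# Illustrative toy only; not the NxDI implementation.
class _FakeBlockKVManager:
    def get_cache(self):
        # Eager materialization: gather block KV and reshape to flat BHSD.
        return "flat_bhsd_kv"

    def _fetch_cache(self):
        # Raw block KV, indexed later with the active block table.
        return "raw_block_kv"

def get_model_output(kv_mgr, active_block_table, skip_bhsd_for_block_tkg):
    past_kv = None
    if not skip_bhsd_for_block_tkg:
        # Existing behavior: materialize flat BHSD before the decoder layers.
        past_kv = kv_mgr.get_cache()
    # Block TKG attention always re-fetches raw block KV and uses the block
    # table, so the fast path simply drops the unused materialization above.
    raw_kv = kv_mgr._fetch_cache()
    return raw_kv, active_block_table, past_kv

get_model_output(_FakeBlockKVManager(), active_block_table=[3, 7],
                 skip_bhsd_for_block_tkg=True)
```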

Checklist

Core Change

  • Skip prefix-cache BHSD materialization only for token generation with block KV layout and block TKG attention
  • Preserve existing behavior for context encoding
  • Preserve existing behavior when block TKG attention is disabled
  • Preserve existing behavior when KV cache quantization is enabled
  • Preserve existing behavior for chunked/sliding-window attention
  • Preserve existing behavior for normal contiguous KV cache paths
  • Leave NKI kernels unchanged

Test Coverage

  • Fast-path regression test for prefix + block KV + block TKG (assertion pattern sketched after this list)
  • Fallback test for prefix + block KV without block TKG
  • KV-quant safety test to keep quantized KV on the existing materialization path
  • Adjacent block KV cache manager tests
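
A hedged sketch of the fast-path assertion pattern; the real test lives in test/unit/models/test_prefix_block_tkg_cache_path.py and builds the actual model, whereas the stand-in below is illustrative only:

```python
from unittest.mock import MagicMock

def test_prefix_block_tkg_fast_path_skips_get_cache():
    kv_mgr = MagicMock()

    # Stand-in for a token-generation forward pass with prefix caching,
    # block KV layout, and block TKG attention enabled (hypothetical toy,
    # not the NxDI API).
    def run_token_gen(fast_path_enabled):
        if not fast_path_enabled:
            kv_mgr.get_cache()      # eager BHSD materialization
        kv_mgr._fetch_cache()       # raw block KV for block TKG attention

    run_token_gen(fast_path_enabled=True)
    kv_mgr.get_cache.assert_not_called()
    kv_mgr._fetch_cache.assert_called_once()
```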

Testing

How did you test this change?

Validated on inf2.xlarge using the Neuron SDK 2.29 inference DLC:

  • pytorch-inference-neuronx:2.9.0-sdk2.29.0-ubuntu24.04
  • torch 2.9.1
  • torch-neuronx 2.9.0.2.13
  • neuronx_distributed 0.18.27753

Test Results:

PYTHONPATH=src pytest -q test/unit/models/test_prefix_block_tkg_cache_path.py
# 3 passed

PYTHONPATH=src pytest -q test/unit/modules/kvcache/test_block_kv_cache_manager.py
# 19 passed

Also confirmed the patched branch is reachable during a real Neuron TKG trace on Inf2.

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.29
  • Instance Type(s): inf2.xlarge
  • PyTorch Version: 2.9.1
  • torch-neuronx Version: 2.9.0.2.13
  • neuronx_distributed Version: 0.18.27753

Additional Information

The skip is gated on all of the following (a sketch of the combined check follows the list):

  • token generation, not context encoding
  • prefix caching enabled
  • block KV layout enabled
  • block TKG attention enabled
  • non-empty active_block_table
  • no KV cache quantization
  • no chunked/sliding-window attention
  • not the seq-id-mask cache-update subpath
  • BlockKVCacheManager
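
A minimal sketch of the combined check, assuming plain boolean flags; the actual attribute names in NxDI differ, and only the conditions mirror the list above:

```python
def should_skip_prefix_bhsd_materialization(
    is_token_gen: bool,
    prefix_caching_enabled: bool,
    block_kv_layout: bool,
    block_tkg_attention: bool,
    has_active_block_table: bool,
    kv_cache_quantized: bool,
    chunked_or_sliding_window_attention: bool,
    seq_id_mask_cache_update: bool,
    is_block_kv_cache_manager: bool,
) -> bool:
    # Every positive condition must hold and every exclusion must be absent;
    # otherwise the existing materialization path is preserved.
    return (
        is_token_gen
        and prefix_caching_enabled
        and block_kv_layout
        and block_tkg_attention
        and has_active_block_table
        and not kv_cache_quantized
        and not chunked_or_sliding_window_attention
        and not seq_id_mask_cache_update
        and is_block_kv_cache_manager
    )
```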

KV quantization remains on the existing path because prefix cache reads dequantize before block selection, while the raw block-TKG path does not currently receive dequant scale parameters.

Related Issues

N/A

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • The code follows the existing NxDI model/cache manager patterns
  • Relevant tests are included
