This guide focuses on NxD inference models, decoder graphs, attention, and porting practices. For project-wide workflows see AGENTS.md.
- Context encoding graph: multi-token prompt → KV cache
- Token generation graph: 1 token in → 1 token out
- Speculation graph: optional multi-token proposal
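Because each graph is compiled for fixed shapes, a variable-length prompt must be padded to a pre-compiled sequence bucket before the context-encoding graph can run. A minimal host-side sketch of that idea (bucket sizes and the `pad_to_bucket` helper are illustrative, not library APIs):

```python
def pad_to_bucket(token_ids, buckets=(128, 512, 2048), pad_id=0):
    """Pad a prompt to the smallest compiled bucket length.

    Neuron graphs are compiled for fixed shapes, so variable-length
    prompts are padded up to one of the pre-compiled buckets.
    (Bucket sizes here are illustrative, not the library defaults.)
    """
    n = len(token_ids)
    for b in buckets:
        if n <= b:
            return token_ids + [pad_id] * (b - n), n
    raise ValueError(f"prompt length {n} exceeds largest bucket {buckets[-1]}")
```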
Implementation details live in:
- optimum/neuron/models/inference/backend/modules/decoder/modeling_decoder.py
- optimum/neuron/models/inference/backend/modules/decoder/decoder_builders.py
- optimum/neuron/models/inference/backend/modules/decoder/decoder_wrappers.py
KV cache is managed by KVCacheManager with a BHSD layout and in-place aliasing.
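To make the BHSD (batch, heads, seq, head_dim) layout concrete, here is a hypothetical flat-buffer index computation (not the KVCacheManager API). In-place aliasing means token generation writes each new K/V entry into this pre-allocated buffer at its sequence position instead of reallocating, keeping shapes static across steps:

```python
def kv_write_index(b, h, s, d, H, S, D):
    """Flat-buffer offset for element (b, h, s, d) in a BHSD cache.

    H, S, D are the per-batch head count, max sequence length, and
    head dimension baked in at compile time. Illustrative helper only.
    """
    return ((b * H + h) * S + s) * D + d
```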
Sampling on NeuronCores uses nxd_topk, nxd_argmax, and NKI cumsum kernels.
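A host-side illustration of the same math those kernels implement (top-k selection, softmax, cumulative sum, inverse-CDF draw). This sketch only shows the algorithm; the device implementation uses the nxd_topk/argmax and NKI cumsum kernels, and the function name here is hypothetical:

```python
import math
import random

def topk_sample(logits, k=4, temperature=1.0, rng=random.Random(0)):
    """Top-k sampling on the host, for illustration only."""
    # Keep the k highest-logit candidates.
    top = sorted(enumerate(logits), key=lambda p: p[1], reverse=True)[:k]
    # Softmax over the retained candidates (max-subtracted for stability).
    m = max(v for _, v in top)
    exps = [(i, math.exp((v - m) / temperature)) for i, v in top]
    z = sum(e for _, e in exps)
    # Inverse-CDF draw via a running cumulative sum.
    u, acc = rng.random() * z, 0.0
    for i, e in exps:
        acc += e
        if u <= acc:
            return i
    return exps[-1][0]
```

With `k=1` this degenerates to argmax, which is why the same kernel family covers both greedy and sampled decoding.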
- Runtime shapes must match compiled shapes.
- Call context encoding before token generation.
- TP degree must match compiled model.
- Decoder graph changes require a cache prune: `python tools/prune_test_models.py`.
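The constraints above can be expressed as a simple pre-flight guard. The `compiled` dict and `validate_request` function are hypothetical stand-ins; the real checks live inside the NxD runtime:

```python
def validate_request(batch, seq_len, tp_degree, compiled):
    """Reject requests whose shapes differ from the compiled graph.

    `compiled` is a hypothetical record of the shapes baked in at
    compilation time; this is a sketch, not the library's API.
    """
    if batch != compiled["batch"]:
        raise ValueError(f"batch {batch} != compiled batch {compiled['batch']}")
    if seq_len > compiled["max_seq_len"]:
        raise ValueError(f"seq_len {seq_len} exceeds compiled {compiled['max_seq_len']}")
    if tp_degree != compiled["tp_degree"]:
        raise ValueError(f"tp_degree {tp_degree} != compiled {compiled['tp_degree']}")
```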
- Sharding strategy selection: REPLICATE_TO_TP_DEGREE vs CONVERT_TO_MHA
- Logic in optimum/neuron/models/inference/backend/modules/attention/gqa.py
- Attention module guide (dispatch table, NKI kernel for head_dim > 128, sliding window): optimum/neuron/models/inference/backend/modules/attention/AGENTS.md
- The training path uses `attn_implementation="flash_attention_2"`.
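The choice between REPLICATE_TO_TP_DEGREE and CONVERT_TO_MHA can be sketched as a function of KV head count versus TP degree. This is a hypothetical simplification of the logic in gqa.py (the real rules are more involved, and the `SPLIT` branch name is invented here for the even-split case):

```python
def pick_kv_sharding(num_kv_heads, tp_degree):
    """Hypothetical sketch of the GQA sharding decision."""
    if num_kv_heads % tp_degree == 0:
        return "SPLIT"                   # even split: each rank owns num_kv_heads // tp_degree
    if tp_degree % num_kv_heads == 0:
        return "REPLICATE_TO_TP_DEGREE"  # duplicate KV heads so every rank holds one
    return "CONVERT_TO_MHA"              # fall back to one KV head per query head
```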
Parallel QKV and output projections use `ColumnParallelLinear`/`RowParallelLinear`.
- Replace HF `nn.Linear`/`Embedding` with TP-aware parallel layers.
- Replace HF attention with `NeuronAttentionBase` for static shapes.
- Use `KVCacheManager` instead of HF dynamic cache.
- Optional fused QKV/MLP kernels (Neuron-only).
- State dict remaps (e.g., QKV concatenation).
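The QKV-concatenation remap can be sketched as follows. Key names follow common HF Llama-style conventions but are illustrative, and nested lists stand in for tensors; real remaps operate on torch tensors and the model's actual key layout:

```python
def fuse_qkv_state_dict(sd, layer_idx):
    """Concatenate separate HF q/k/v projection weights into one
    fused QKV weight for a given layer. Illustrative sketch only."""
    prefix = f"model.layers.{layer_idx}.self_attn"
    q = sd.pop(f"{prefix}.q_proj.weight")
    k = sd.pop(f"{prefix}.k_proj.weight")
    v = sd.pop(f"{prefix}.v_proj.weight")
    # Concatenate along the output dimension (rows of the weight matrix).
    sd[f"{prefix}.qkv_proj.weight"] = q + k + v
    return sd
```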
See reference implementation:
- Expert routing and sharding are TP/EP aware.
- Expert capacity and dispatch are statically shaped.
- Expert MLPs use parallel layers or fused kernels.
- State dict remaps for expert sharding when required.
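Static expert capacity is what keeps MoE dispatch tensors at a compile-time shape: each expert accepts at most `capacity` tokens per batch, and overflow tokens are dropped or re-routed. The formula below is the common MoE convention, shown as an illustration rather than the exact NxD computation:

```python
import math

def expert_capacity(tokens_per_batch, num_experts, capacity_factor=1.25):
    """Per-expert token budget for statically shaped dispatch."""
    return math.ceil(tokens_per_batch * capacity_factor / num_experts)
```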
Porting from NxDI
Use NxDI for Neuron-specific graph changes and HF Transformers for the base architecture.
The Optimum Neuron implementation prioritizes stability, maintainability, and HF ecosystem compatibility over cutting-edge performance optimizations. For production deployments requiring maximum throughput, NxDI remains the reference implementation.
Track numerical differences using module-level tests before full graph tests:
- tests/decoder/test_modules.py compares HF layers to their Neuron equivalents using `nxd_testing.build_module()` and `validate_accuracy()`.
- tests/decoder/test_attention.py validates attention with explicit rotary embedding and mask handling.
These isolate drift or state-dict conversion issues early.
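The comparison those tests perform reduces to an element-wise tolerance check between a reference (HF) output and the Neuron output. A stdlib-only sketch of that check (the real tests use the nxd_testing helpers and torch tensors, not plain lists):

```python
def max_abs_diff(ref, test):
    """Worst absolute deviation between two flat output sequences."""
    return max(abs(r - t) for r, t in zip(ref, test))

def assert_close(ref, test, atol=1e-3):
    """Fail if any element drifts beyond the absolute tolerance."""
    diff = max_abs_diff(ref, test)
    assert diff <= atol, f"max abs diff {diff} exceeds {atol}"
```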
When adding a new model directory:
- Create `CLAUDE.md` in the model directory containing `@AGENTS.md` so Claude Code auto-loads the model-specific guide.