
Optimum Neuron Inference Models (NxD) Guide

This guide focuses on NxD inference models, decoder graphs, attention, and porting practices. For project-wide workflows see AGENTS.md.

NxD Decoder Models

Three-Graph Architecture

  • Context encoding graph: multi-token prompt → KV cache
  • Token generation graph: 1 token in → 1 token out
  • Speculation graph: optional multi-token proposal
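The dispatch between the three graphs can be sketched as follows. This is an illustrative sketch only: the graph names and the selection rule are assumptions for exposition, not the actual Optimum Neuron API.

```python
# Illustrative sketch: pick which compiled decoder graph handles a request.
# Names and thresholds are assumptions, not the real dispatch logic.
def select_graph(input_len: int, speculation_len: int = 0) -> str:
    if input_len > 1:
        return "context_encoding"   # multi-token prompt fills the KV cache
    if speculation_len > 1:
        return "speculation"        # optional multi-token proposal
    return "token_generation"       # steady state: 1 token in, 1 token out

assert select_graph(128) == "context_encoding"
assert select_graph(1) == "token_generation"
assert select_graph(1, speculation_len=4) == "speculation"
```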

Implementation details live in:

KV Cache Management

KV cache is managed by KVCacheManager, using a BHSD (batch, heads, sequence, head-dim) layout and in-place aliasing so compiled graphs reuse the same buffers:
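A minimal pure-Python sketch of the idea, assuming BHSD means a [batch, heads, seq, dim] buffer allocated once at the maximum sequence length, with every update writing into it in place (the class and method names here are hypothetical, not the real KVCacheManager):

```python
# Hypothetical sketch of a statically shaped, in-place KV cache.
class SimpleKVCache:
    """Statically shaped cache in BHSD layout: [batch, heads, seq, dim]."""
    def __init__(self, batch: int, heads: int, max_seq: int, dim: int):
        def alloc():
            return [[[[0.0] * dim for _ in range(max_seq)]
                     for _ in range(heads)] for _ in range(batch)]
        self.k, self.v = alloc(), alloc()

    def update(self, pos: int, k_new, v_new):
        """Write one token's keys/values at `pos` in place (no realloc)."""
        for b in range(len(self.k)):
            for h in range(len(self.k[b])):
                self.k[b][h][pos] = k_new[b][h]
                self.v[b][h][pos] = v_new[b][h]

cache = SimpleKVCache(batch=1, heads=2, max_seq=4, dim=3)
buf = cache.k
cache.update(0, [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]],
                [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]])
assert cache.k is buf                      # same buffer: updated in place
assert cache.k[0][1][0] == [4.0, 5.0, 6.0]
```

Pre-allocating at the maximum sequence length is what lets the compiled graphs keep static shapes across decoding steps.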

On-Device Sampling

Sampling on NeuronCores uses nxd_topk, nxd_argmax, NKI cumsum kernels:
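A host-side sketch of the top-k / top-p math that such kernels implement (top-k selection, then a cumulative-probability cutoff). The function below is illustrative pure Python, not the NKI kernel API:

```python
import math
import random

# Illustrative top-k + top-p (nucleus) sampling; the on-device version
# replaces the sort/cumsum steps with nxd_topk and cumsum kernels.
def sample_top_k_top_p(logits, k=50, p=0.9, rng=random.Random(0)):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]          # stable softmax
    total = sum(exps)
    probs = sorted(((i, e / total) for i, e in enumerate(exps)),
                   key=lambda t: t[1], reverse=True)[:k]   # top-k
    kept, cum = [], 0.0
    for i, q in probs:                                # cumsum until mass p
        kept.append((i, q))
        cum += q
        if cum >= p:
            break
    z = sum(q for _, q in kept)                       # renormalize and draw
    r = rng.random() * z
    for i, q in kept:
        r -= q
        if r <= 0:
            return i
    return kept[-1][0]

# One dominant logit: the nucleus collapses to a single token.
assert sample_top_k_top_p([0.0, 0.0, 10.0, 0.0], k=2, p=0.5) == 2
```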

Common Pitfalls

  • Runtime shapes must match compiled shapes.
  • Call context encoding before token generation.
  • TP degree must match compiled model.
  • Decoder graph changes require pruning the test-model cache: python tools/prune_test_models.py.
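The first pitfall can be caught early with a guard like the following. The function and dictionary keys are hypothetical, purely to illustrate the check:

```python
# Hypothetical guard: runtime input shapes must match the compiled shapes.
def check_shapes(compiled: dict, runtime: dict) -> None:
    for name, expected in compiled.items():
        got = runtime.get(name)
        if got != expected:
            raise ValueError(
                f"{name}: runtime shape {got} != compiled shape {expected}; "
                "pad inputs to the compiled bucket or recompile")

compiled = {"input_ids": (1, 128), "attention_mask": (1, 128)}
check_shapes(compiled, {"input_ids": (1, 128), "attention_mask": (1, 128)})
try:
    check_shapes(compiled, {"input_ids": (1, 64), "attention_mask": (1, 64)})
    raise AssertionError("expected a ValueError")
except ValueError:
    pass  # mismatch detected as intended
```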

Attention Mechanisms

Grouped Query Attention (GQA)
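The core of GQA is the head mapping: each KV head is shared by a contiguous group of query heads, so the number of query heads must be divisible by the number of KV heads. A minimal sketch of that mapping (the function name is illustrative):

```python
# GQA head grouping: query head q attends against KV head q // group_size.
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    assert num_q_heads % num_kv_heads == 0, "q heads must divide evenly"
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

# Llama-3-8B-style config: 32 query heads share 8 KV heads (groups of 4).
assert [kv_head_for_query_head(q, 32, 8) for q in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
```

Shrinking the KV head count is what shrinks the KV cache, which matters for the statically allocated buffers described above.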

Flash Attention on Neuron
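The idea behind flash attention, independent of the Neuron kernel details, is tiled online softmax: keys and values are processed in blocks while tracking a running max, normalizer, and partial output, so the full score row is never materialized. A pure-Python sketch for a single query row (illustrative, not the NKI kernel):

```python
import math

# Online-softmax accumulation over KV tiles for one query vector.
def flash_attention_row(q, keys, values, tile=2):
    m, l = float("-inf"), 0.0          # running max and softmax denominator
    out = [0.0] * len(values[0])       # running (unnormalized) output
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)    # rescale earlier partial results
        l = l * scale + sum(math.exp(s - m_new) for s in scores)
        out = [o * scale for o in out]
        for s, v in zip(scores, v_tile):
            w = math.exp(s - m_new)
            out = [o + w * vi for o, vi in zip(out, v)]
        m = m_new
    return [o / l for o in out]
```

Because each tile only rescales the previous partial sums, the result matches the untiled softmax-weighted average exactly (up to floating-point error).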

Parallel Attention Layers

Parallel QKV and output projections use ColumnParallelLinear/RowParallelLinear in:
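The tensor-parallel math behind those layers can be shown in a single process: a column-parallel layer shards the weight's output columns across ranks and concatenates the shard outputs, while a row-parallel layer shards the input features and sums the partial outputs (the all-reduce). This sketch only demonstrates the arithmetic, not the neuronx-distributed implementation:

```python
# Single-process sketch of column- vs row-parallel linear layers.
def matmul(x, w):  # x: [in], w: [in][out] -> [out]
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def column_parallel(x, w, tp=2):
    cols = len(w[0]) // tp
    shards = [[row[r * cols:(r + 1) * cols] for row in w] for r in range(tp)]
    # each rank computes its column shard; outputs are concatenated
    return [y for r in range(tp) for y in matmul(x, shards[r])]

def row_parallel(x, w, tp=2):
    rows = len(w) // tp
    partials = [matmul(x[r * rows:(r + 1) * rows], w[r * rows:(r + 1) * rows])
                for r in range(tp)]
    # each rank computes a partial product; the sum plays the all-reduce role
    return [sum(p[j] for p in partials) for j in range(len(partials[0]))]

x = [1.0, 2.0, 3.0, 4.0]
w = [[1, 2], [3, 4], [5, 6], [7, 8]]
assert column_parallel(x, w) == matmul(x, w) == row_parallel(x, w)
```

Pairing a column-parallel QKV projection with a row-parallel output projection needs only one collective per attention block, which is why the two are used together.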

Neuron vs HF Modeling Differences

Llama-like models (decoder-only)

  • Replace HF nn.Linear/Embedding with TP-aware parallel layers.
  • Replace HF attention with NeuronAttentionBase for static shapes.
  • Use KVCacheManager instead of HF dynamic cache.
  • Optional fused QKV/MLP kernels (Neuron-only).
  • State dict remaps (e.g., QKV concatenation).
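The QKV-concatenation remap in the last bullet can be sketched as follows. The checkpoint key names follow the usual HF Llama layout, but the fused key name and helper are hypothetical:

```python
# Hypothetical state-dict remap: concatenate separate HF q/k/v projection
# weights into one fused QKV weight along the output (row) dimension.
def fuse_qkv(state_dict: dict, layer: int) -> dict:
    prefix = f"model.layers.{layer}.self_attn"
    q = state_dict.pop(f"{prefix}.q_proj.weight")
    k = state_dict.pop(f"{prefix}.k_proj.weight")
    v = state_dict.pop(f"{prefix}.v_proj.weight")
    # weights stored as lists of rows, so list concat stacks the projections
    state_dict[f"{prefix}.qkv_proj.weight"] = q + k + v
    return state_dict

sd = {"model.layers.0.self_attn.q_proj.weight": [[1, 0]],
      "model.layers.0.self_attn.k_proj.weight": [[0, 1]],
      "model.layers.0.self_attn.v_proj.weight": [[1, 1]]}
sd = fuse_qkv(sd, 0)
assert sd == {"model.layers.0.self_attn.qkv_proj.weight": [[1, 0], [0, 1], [1, 1]]}
```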

See reference implementation:

MoE models (Mixtral, Qwen3 MoE)

  • Expert routing and sharding are TP/EP aware.
  • Expert capacity and dispatch are statically shaped.
  • Expert MLPs use parallel layers or fused kernels.
  • State dict remaps for expert sharding when required.
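Statically shaped expert dispatch can be illustrated as follows: every expert gets a fixed-capacity slot table, and tokens routed past capacity are dropped, so buffer shapes never depend on the routing pattern. The function and padding convention are assumptions for illustration:

```python
# Hypothetical fixed-capacity MoE dispatch with static output shapes.
def dispatch(expert_ids, num_experts, capacity, pad=-1):
    slots = [[pad] * capacity for _ in range(num_experts)]  # static shape
    fill = [0] * num_experts
    dropped = []
    for tok, e in enumerate(expert_ids):
        if fill[e] < capacity:
            slots[e][fill[e]] = tok
            fill[e] += 1
        else:
            dropped.append(tok)   # overflow: token skips its expert
    return slots, dropped

# Five tokens routed to two experts, two slots each: token 3 overflows.
slots, dropped = dispatch([0, 1, 0, 0, 1], num_experts=2, capacity=2)
assert slots == [[0, 2], [1, 4]]
assert dropped == [3]
```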

Porting from NxDI

When porting, use NxDI as the reference for Neuron-specific graph changes and HF Transformers as the reference for the base architecture.

The Optimum Neuron implementation prioritizes stability, maintainability, and HF ecosystem compatibility over cutting-edge performance optimizations. For production deployments requiring maximum throughput, NxDI remains the reference implementation.

Per-Module Parity Tests

Track numerical differences using module-level tests before full graph tests:

These tests isolate numerical drift and state-dict conversion issues early, before they are buried inside a full compiled graph.
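A parity check of this shape compares one ported module's output against the HF reference module on identical inputs, with an explicit tolerance. The helper names are hypothetical:

```python
# Hypothetical per-module parity helpers: compare flat output vectors
# from the reference (HF) module and the ported (Neuron) module.
def max_abs_diff(ref, test):
    assert len(ref) == len(test), "output shape mismatch"
    return max(abs(a - b) for a, b in zip(ref, test))

def assert_parity(ref, test, atol=1e-3):
    diff = max_abs_diff(ref, test)
    assert diff <= atol, f"parity failure: max |diff| = {diff} > {atol}"

assert_parity([1.0, 2.0, 3.0], [1.0, 2.0005, 3.0])   # within tolerance
```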

New Model Checklist

When adding a new model directory:

  • Create CLAUDE.md in the model directory containing @AGENTS.md so Claude Code auto-loads the model-specific guide.