A grab-bag of experimental transformer kernels and utilities (mostly PyTorch + Triton).
transformer_nuggets/flash: Triton FlashAttention experiments + masking/bias utilities.
transformer_nuggets/quant: NF4 tensor subclass + QLoRA building blocks (pure PyTorch).
transformer_nuggets/fp8: FP8 casting / scaled-quantization kernels (Triton).
transformer_nuggets/cute: CUTE DSL experiments and tooling (includes an intra-kernel profiler).
transformer_nuggets/misc: Odds and ends (e.g. attention wrappers, utilities).
transformer_nuggets/llama: LLaMA-ish model + training/finetune scripts (research-grade).
This repository is research code: APIs are not stable and may change.
You’ll need a working PyTorch install first (CPU or CUDA). Follow the official PyTorch install instructions.
To install from PyPI:
pip install transformer_nuggets

To hack on the code locally:
git clone https://github.com/drisspg/transformer_nuggets.git
cd transformer_nuggets
pip install -e .

Optional extras:
pip install "transformer_nuggets[flash]" # triton
pip install "transformer_nuggets[qlora]" # bitsandbytes (optional comparisons)
pip install "transformer_nuggets[llama]" # llama training utilities

NF4 tensor round-trip:
import torch
from transformer_nuggets.quant.nf4_tensor import NF4Tensor
w = torch.randn(4096, 4096, dtype=torch.float16).contiguous()
w.requires_grad_(False)
nf4_w = NF4Tensor.from_tensor(w)
w_dequant = nf4_w.get_original_weight()

FlashAttention (requires CUDA + Triton; API is experimental):
import math
import torch
from transformer_nuggets.flash import BiasMode, attention
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
o, _ = attention(q, k, v, causal=True, sm_scale=1 / math.sqrt(q.size(-1)), bias_choice=BiasMode.none)

CUTE intra-kernel profiling (writes a Perfetto trace):
python -m transformer_nuggets.cute.profiler.example

transformer_nuggets/: Python package.
benchmarks/: Microbenchmarks and profiling scripts.
examples/: Small runnable examples.
scripts/: One-off utilities.
test/: PyTest suite.
pip install -e ".[dev]"
pre-commit install
pytest
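
For intuition about the NF4 tensor round-trip shown earlier, here is a minimal pure-Python sketch of blockwise 4-bit quantization with absmax scaling. It is an illustration under stated assumptions, not the code path used by `NF4Tensor`: the 16-entry codebook is the NF4 table as published with QLoRA, while the block size of 64 and the helper names `quantize_block`, `dequantize_block`, and `round_trip` are hypothetical and chosen only for this example.

```python
# Illustrative sketch only; NOT how NF4Tensor is implemented in this repo.
# Assumption: the 16 NF4 levels below follow the table published with QLoRA.
NF4_LEVELS = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]


def quantize_block(block):
    """Return (scale, codes): absmax scale plus one 4-bit index per value."""
    # Fall back to 1.0 for an all-zero block to avoid dividing by zero.
    scale = max(abs(x) for x in block) or 1.0
    codes = [
        min(range(16), key=lambda i: abs(NF4_LEVELS[i] - x / scale))
        for x in block
    ]
    return scale, codes


def dequantize_block(scale, codes):
    """Map 4-bit indices back to floats and undo the absmax scaling."""
    return [NF4_LEVELS[c] * scale for c in codes]


def round_trip(values, block_size=64):
    """Quantize then dequantize a flat list, one block at a time."""
    out = []
    for start in range(0, len(values), block_size):
        scale, codes = quantize_block(values[start:start + block_size])
        out.extend(dequantize_block(scale, codes))
    return out


if __name__ == "__main__":
    import random

    random.seed(0)
    w = [random.gauss(0.0, 1.0) for _ in range(256)]
    w_dq = round_trip(w)
    max_err = max(abs(a - b) for a, b in zip(w, w_dq))
    print(f"max abs round-trip error: {max_err:.4f}")
```

The round trip is lossy: each value snaps to the nearest of 16 levels, so the per-element error is bounded by half the widest gap between levels times the block's absmax. Smaller blocks tighten that bound at the cost of storing more scales.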