## Background

We recently added support for the native device extension, which maps PyTorch operations to TT-NN for eager execution on Tenstorrent hardware. A basic example is already in place, but to unlock real-world usage we now want to support a complex transformer model: BERT Large.
## Objective

Make Hugging Face's `transformers.BertModel.from_pretrained("bert-large-uncased")` run eagerly on the Tenstorrent device via `model.to("ttnn:0")`.

This requires identifying all PyTorch operations used in the model's forward pass and mapping them to their TT-NN equivalents in the native extension.
## Tasks

### 1. Identify Required Operations

- Analyze the forward pass of BERT Large to extract all PyTorch operations it uses (a trace-based sketch follows this list).
- Add the full list of required ops as a checklist in this issue.
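One way to harvest the op list automatically is a dispatch-mode trace on CPU. This is a minimal sketch using standard PyTorch (`TorchDispatchMode`); the exact op set it prints will depend on your `torch` and `transformers` versions, so treat the output as a starting point:

```python
# Sketch: collect the ATen ops hit by a BERT Large forward pass on CPU.
import torch
import transformers
from torch.utils._python_dispatch import TorchDispatchMode

class OpCollector(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.ops = set()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.add(str(func))  # e.g. "aten.addmm.default"
        return func(*args, **(kwargs or {}))

model = transformers.BertModel.from_pretrained("bert-large-uncased")
model.eval()
input_ids = torch.ones(1, 128, dtype=torch.long)
attention_mask = torch.ones(1, 128, dtype=torch.long)

with torch.no_grad(), OpCollector() as collector:
    model(input_ids=input_ids, attention_mask=attention_mask)

for op in sorted(collector.ops):
    print(f"- [ ] {op}")  # ready to paste into this issue as a checklist
```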
### 2. Map Operations to TT-NN

- For each op:
  - Check whether a TT-NN equivalent exists.
  - If not, implement it in the native extension (C++) or extend an existing one.
- Ensure each op works correctly when invoked via `model.to("ttnn:0")` (a sketch of one mapping follows this list).
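For intuition, here is a Python-level sketch of what one mapping could look like. The real implementations belong in the C++ native extension; the `ttnn` calls used below (`open_device`, `from_torch`, `add`, `to_torch`), the per-op device open/close, and the round-trip through host memory are illustrative assumptions, not the extension's actual mechanics:

```python
# Sketch only: a real kernel keeps tensors resident on the device and
# is registered from C++, not Python.
import torch
import ttnn  # TT-NN Python bindings from tt-metal

# "PrivateUse1" is the dispatch key PyTorch reserves for out-of-tree
# backends such as the ttnn device.
lib = torch.library.Library("aten", "IMPL")

def ttnn_add_tensor(a: torch.Tensor, b: torch.Tensor, alpha=1):
    device = ttnn.open_device(device_id=0)
    try:
        ta = ttnn.from_torch(a.cpu(), device=device, layout=ttnn.TILE_LAYOUT)
        tb = ttnn.from_torch(b.cpu(), device=device, layout=ttnn.TILE_LAYOUT)
        if alpha != 1:
            tb = ttnn.multiply(tb, alpha)  # assumed scalar broadcast support
        return ttnn.to_torch(ttnn.add(ta, tb))
    finally:
        ttnn.close_device(device)

lib.impl("add.Tensor", ttnn_add_tensor, "PrivateUse1")
```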
### 3. Handle Device Layouts

- After moving to `ttnn:0`, tensors may use Row Major layout by default.
- Some TT-NN ops don't yet support Row Major.
- When this happens:
  - ❗ Open an issue in the [tt-metal](https://github.com/tenstorrent/tt-metal) repository.
  - ✅ Optionally: submit a PR to tt-metal adding layout conversion logic (e.g., Row Major → Tiled).
- Ensure layout compatibility is handled gracefully for each mapped op (a possible fallback pattern is sketched below).
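As a starting point, a graceful fallback could convert Row Major inputs to Tile layout before ops that require it. A minimal sketch, assuming the ttnn Python API's `to_layout` and layout constants (the C++ extension would apply the same conversion internally):

```python
import ttnn

def ensure_tile_layout(tensor):
    """Convert a Row Major ttnn tensor to Tile layout if needed."""
    if tensor.layout == ttnn.ROW_MAJOR_LAYOUT:
        # Tile layout works on 32x32 tiles; to_layout handles padding.
        return ttnn.to_layout(tensor, ttnn.TILE_LAYOUT)
    return tensor

def tiled_op(op, *tensors, **kwargs):
    """Run a TT-NN op that only supports Tile layout on its inputs."""
    return op(*[ensure_tile_layout(t) for t in tensors], **kwargs)

# Usage: out = tiled_op(ttnn.matmul, a, b)
```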
### 4. Ensure Model Correctness

- The BERT model must produce correct outputs.
- Accuracy is key; performance is not a priority at this stage. A simple check is to compare against CPU outputs, as sketched below.
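One way to validate correctness is to run the same inputs through the CPU reference and the device model and compare. A hedged sketch; the tolerances are placeholders and will likely need tuning for device numerics (e.g., bf16):

```python
import torch
import transformers

model = transformers.BertModel.from_pretrained("bert-large-uncased")
model.eval()

input_ids = torch.ones(1, 128, dtype=torch.long)
attention_mask = torch.ones(1, 128, dtype=torch.long)

# CPU reference.
with torch.no_grad():
    reference = model(input_ids=input_ids, attention_mask=attention_mask)

# Device run; assumes inputs follow the model onto the ttnn device.
model.to("ttnn:0")
with torch.no_grad():
    device_out = model(
        input_ids=input_ids.to("ttnn:0"),
        attention_mask=attention_mask.to("ttnn:0"),
    )

diff = (reference.last_hidden_state
        - device_out.last_hidden_state.cpu()).abs().max()
print(f"max abs diff: {diff}")
assert torch.allclose(
    reference.last_hidden_state,
    device_out.last_hidden_state.cpu(),
    atol=1e-2, rtol=1e-2,  # placeholder tolerances
)
```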
## Deliverables

- A list of required PyTorch operations for BERT Large
- Mapped TT-NN implementations for these operations (in the native extension)
- Layout handling logic (conversion or fallbacks as needed)
- A test or validation script showing:

  ```python
  model = transformers.BertModel.from_pretrained("bert-large-uncased")
  model.to("ttnn:0")
  output = model(input_ids=..., attention_mask=...)
  ```

- Documentation including:
  - Mapped ops and their TT-NN equivalents
  - Layout handling approach
  - Any open issues or known limitations
## Notes

- Start from Hugging Face's `BertModel.from_pretrained("bert-large-uncased")`.
- You can use simplified inputs for testing (`input_ids=torch.ones(...)`, etc.).
- Aim for functional parity, not speed or memory efficiency (yet).
- Improving layout abstraction or reusability is highly welcome.
- See the build and test instructions for the native device integration.
## Who This Is For
This task is best suited for someone who:
- Is comfortable working in C++
- Has experience with (or interest in) low-level integration of PyTorch and custom backends
- Wants to dive into how Tenstorrent’s native device integration works with PyTorch