
Commit 6f1debc

metascroy authored and malfet committed
various fixes to make et_export and et_wrapper work
1 parent 2f29dc3 commit 6f1debc

File tree

3 files changed: +27 -12 lines changed


README.md

Lines changed: 12 additions & 5 deletions
````diff
@@ -6,8 +6,8 @@ For a list of devices, see below, under *SUPPORTED SYSTEMS*
 
 A goal of this repo, and the design of the PT2 components was to offer seamless integration and consistent workflows.
 Both mobile and server/desktop paths start with torch.export() receiving the same model description. Similarly,
-integration into runners for Python (for initial testing) and Python-free environments (for deployment, in runner-posix
-and runner-mobile, respectively) offer very consistent experiences across backends and offer developers consistent interfaces
+integration into runners for Python (for initial testing) and Python-free environments (for deployment, in runner-aoti
+and runner-et, respectively) offer a consistent experience across backends and offer developers consistent interfaces
 and user experience whether they target server, desktop or mobile & edge use cases, and/or all of them.
 
 
````
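As a companion to the paragraph in the hunk above, here is a minimal sketch of the shared torch.export() entry point both paths start from; TinyModel and its input shape are hypothetical stand-ins for the llama model this repo actually exports:

```python
import torch
from torch.export import export

# Hypothetical stand-in for the llama model this repo actually exports.
class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) * 2

# Both the runner-aoti (server/desktop) and runner-et (mobile) paths
# begin from the same ExportedProgram captured here.
exported_program = export(TinyModel(), (torch.randn(4),))
print(exported_program.graph_module.code)  # the captured graph, as Python
```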

````diff
@@ -85,12 +85,14 @@ The environment variable MODEL_REPO should point to a directory with the `model.
 The command below will add the file "llama-fast.pte" to your MODEL_REPO directory.
 
 ```
-python et_export.py --checkpoint_path $MODEL_REPO/model.pth -d fp32 --xnnpack --out-path ${MODEL_REPO}
+python et_export.py --checkpoint_path $MODEL_REPO/model.pth -d fp32 --out-path ${MODEL_REPO}
 ```
 
-How do run is problematic -- I would love to run it with
+TODO(fix this): the export command works with "--xnnpack" flag, but the next generate.py command will not run it so we do not set it right now.
+
+To run the pte file, run this. Note that this is very slow at the moment.
 ```
-python generate.py --pte ./${MODEL_REPO}.pte --prompt "Hello my name is" --device cpu
+python generate.py --checkpoint_path $MODEL_REPO/model.pth --pte $MODEL_REPO/llama-fast.pte --prompt "Hello my name is" --device cpu
 ```
 but *that requires xnnpack to work in python!*
 
````
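For context, the generate.py invocation above loads the .pte through ExecuTorch's Python bindings, the same path et_wrapper.py below uses. A minimal sketch of that load-and-run flow; the portable_lib module path is assumed (the diff only shows it aliased as `exec_lib`), and llama-fast.pte is a placeholder:

```python
import torch
# Module path assumed; et_wrapper.py only shows it aliased as `exec_lib`.
from executorch.extension.pybindings import portable_lib as exec_lib

# Load the serialized ExecuTorch program (path is a placeholder).
module = exec_lib._load_for_executorch("llama-fast.pte")

# Inputs go in as a tuple; outputs come back as a list of tensors,
# matching how PTEModel.forward calls the module below.
tokens = torch.tensor([[1, 2, 3]], dtype=torch.long)
input_pos = torch.tensor([0, 1, 2], dtype=torch.long)
(logits,) = module.forward((tokens, input_pos))
```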

````diff
@@ -233,6 +235,11 @@ List dependencies for these backends
 ### ExecuTorch
 Set up executorch by following the instructions [here](https://pytorch.org/executorch/stable/getting-started-setup.html#setting-up-executorch).
 
+Make sure when you run the installation script in the executorch repo, you enable pybind.
+```
+./install_requirements.sh --pybind
+```
+
 
 
 # Acknowledgements
````
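The --pybind flag matters because generate.py's --pte path relies on ExecuTorch's Python bindings. A quick sanity check (module path assumed, as above) that the bindings were built:

```python
# If this import fails, re-run ./install_requirements.sh --pybind.
from executorch.extension.pybindings import portable_lib

print("pybindings OK:", hasattr(portable_lib, "_load_for_executorch"))
```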

et_export.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -172,7 +172,11 @@ def export_model(model, device, output_path, args=None) -> str:  # noqa: C901
         )
     )
 
-    save_pte_program(export_program, "llama-fast", output_path)
+    print("The methods are: ", export_program.methods)
+    path = f"{output_path}/llama-fast.pte"
+    with open(path, "wb") as f:
+        export_program.write_to_file(f)
+    # save_pte_program(export_program, "llama-fast", output_path)
 
     return output_path
 
```
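For orientation, export_program here is the ExecuTorch program object produced earlier in export_model. A hedged sketch of how such an object is typically built with the standard torch.export + exir pipeline, not this repo's exact code:

```python
import torch
from torch.export import export
from executorch.exir import to_edge

# Hypothetical model standing in for the repo's llama Transformer.
class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

# Capture, lower to the edge dialect, then to an ExecuTorch program.
exported = export(TinyModel(), (torch.randn(4),))
executorch_program = to_edge(exported).to_executorch()

# Same serialization call the diff above adds.
with open("llama-fast.pte", "wb") as f:
    executorch_program.write_to_file(f)
```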

et_wrapper.py

Lines changed: 10 additions & 6 deletions
```diff
@@ -9,13 +9,17 @@ class PTEModel(nn.Module):
     def __init__(self, config, path) -> None:
         super().__init__()
         self.config = config
-        self.model_ = exec_lib._load_for_executorch(path)
+        self.model_ = exec_lib._load_for_executorch(str(path))
 
-    def ccorward(self, x, input_pos):
-        logits = module.forward(
-            x.to(torch.long),
-            input_pos.to(torch.long),
-        )
+    def forward(self, x, input_pos):
+        # model_.forward expects inputs to be wrapped in a tuple
+        forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
+        logits = self.model_.forward(forward_inputs)
+
+        # After wrapping in a tuple, we get a list back, so we need to grab
+        # the first element to get the tensor
+        assert len(logits) == 1
+        logits = logits[0]
         return logits
 
     def setup_caches(self, max_batch_size, max_seq_length):
```
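A short usage sketch of the fixed wrapper; config and the .pte path are hypothetical stand-ins for the repo's real objects:

```python
import torch
from et_wrapper import PTEModel

# config is whatever model config the repo passes in; None works for
# this sketch since PTEModel only stores it.
model = PTEModel(config=None, path="llama-fast.pte")

tokens = torch.tensor([[1, 2, 3]], dtype=torch.long)
input_pos = torch.tensor([0, 1, 2], dtype=torch.long)

# nn.Module.__call__ dispatches to the corrected forward, which wraps
# the inputs in a tuple and unwraps the single-tensor result.
logits = model(tokens, input_pos)
```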
