Experimental: Executorch export to CoreML and MPS #742

Closed
wants to merge 4 commits

Conversation


@swolchok swolchok commented May 10, 2024

Known issues:

  • custom SDPA is necessary to 1) get MPS FP16 to work at all (it produces ?? tokens otherwise) and 2) get MPS FP32 to perform well
  • Neither backend performs well yet
  • Core ML produces nonsense and stutters
  • Core ML hits a 0-size tensor, which causes BNNSCopy to crash; the following executorch patch is needed to avoid the crash:
diff --git a/backends/apple/coreml/runtime/delegate/multiarray.mm b/backends/apple/coreml/runtime/delegate/multiarray.mm
index 74996fb8d..94e8e4daa 100644
--- a/backends/apple/coreml/runtime/delegate/multiarray.mm
+++ b/backends/apple/coreml/runtime/delegate/multiarray.mm
@@ -95,7 +95,7 @@ std::optional<BNNSDataType> get_bnns_data_type(MultiArray::DataType datatype) {
 /// @retval `true` if the initialization succeeded otherwise `false`.
 bool init_bnns_descriptor(BNNSNDArrayDescriptor& bnns_descriptor, const MultiArray& multi_array) {
     const auto& layout = multi_array.layout();
-    if (layout.num_elements() == 1) {
+    if (layout.num_elements() <= 1) {
         return false;
     }

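To make the patch's intent concrete, here is a minimal Python sketch (not the actual Objective-C++ code) of the guard: with `== 1`, a zero-element array slips past the early return and reaches BNNSCopy, which crashes; `<= 1` makes degenerate arrays fall back to the plain copy path as well.

```python
from math import prod

def init_bnns_ok(shape) -> bool:
    # Mirrors the guard in init_bnns_descriptor: returning False makes the
    # caller fall back to a non-BNNS copy instead of calling BNNSCopy.
    num_elements = prod(shape)
    if num_elements <= 1:  # was `== 1`, which let 0-element arrays through
        return False
    return True

assert init_bnns_ok([1, 8192, 6, 48])       # normal tensor: use BNNS
assert not init_bnns_ok([1, 1])             # single element: fall back
assert not init_bnns_ok([1, 0, 6, 48])      # zero-size tensor: previously crashed
```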
swolchok added 2 commits May 10, 2024 15:19
Tried on stories15M and it produces nonsense, but this is the start of integration.
@facebook-github-bot added the CLA Signed label May 10, 2024
@swolchok force-pushed the executorch-export branch from 7257854 to f73d717 on May 10, 2024 22:58
@swolchok

Example export command: python3 torchchat.py export stories15M --output-pte-path stories15M_mps.pte --device cpu --executorch-backend mps
To run: python3 torchchat.py generate stories15M --pte-path stories15M_mps.pte --prompt "Hello my name is" --temperature 0 --device cpu --num-samples 3 2>/dev/null (the stderr redirect hides executorch log spam from the pybindings, which we need to fix)

I'm getting 100-110 tokens/sec with MPS export. For comparison, I get 60-68 tokens/sec with MPS eager and over 400 tokens/sec with XNNPACK export.
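For context, the tokens/sec numbers above can be collected with a trivial harness along these lines (a sketch; `generate_fn` is a hypothetical stand-in for whatever drives the model, not a torchchat API):

```python
import time

def tokens_per_sec(generate_fn, num_tokens: int) -> float:
    # Rough throughput measurement: time a call that produces num_tokens
    # tokens and divide. Real measurements should also exclude model load
    # and warm-up, which this sketch ignores.
    start = time.perf_counter()
    generate_fn(num_tokens)
    elapsed = time.perf_counter() - start
    return num_tokens / elapsed
```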

@swolchok

Here is the FP32 exported MPS graph:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_layers_0_attention_kv_cache_k_cache: "f32[1, 8192, 6, 48]", b_layers_0_attention_kv_cache_v_cache: "f32[1, 8192, 6, 48]", b_layers_1_attention_kv_cache_k_cache: "f32[1, 8192, 6, 48]", b_layers_1_attention_kv_cache_v_cache: "f32[1, 8192, 6, 48]", b_layers_2_attention_kv_cache_k_cache: "f32[1, 8192, 6, 48]", b_layers_2_attention_kv_cache_v_cache: "f32[1, 8192, 6, 48]", b_layers_3_attention_kv_cache_k_cache: "f32[1, 8192, 6, 48]", b_layers_3_attention_kv_cache_v_cache: "f32[1, 8192, 6, 48]", b_layers_4_attention_kv_cache_k_cache: "f32[1, 8192, 6, 48]", b_layers_4_attention_kv_cache_v_cache: "f32[1, 8192, 6, 48]", b_layers_5_attention_kv_cache_k_cache: "f32[1, 8192, 6, 48]", b_layers_5_attention_kv_cache_v_cache: "f32[1, 8192, 6, 48]", idx: "i64[1, 1]", input_pos: "i64[1]"):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, input_pos, idx);  lowered_module_0 = idx = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:78 in forward, code: input_pos[-1],
            getitem: "i64[]" = executorch_call_delegate[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:192 in forward, code: x = self.tok_embeddings(idx)
            getitem_1: "f32[1, 1, 288]" = executorch_call_delegate[1]

            # File: /Users/swolchok/src/torchchat/build/model.py:191 in forward, code: freqs_cis = self.freqs_cis[input_pos]
            getitem_2: "f32[1, 24, 2]" = executorch_call_delegate[2]

            # File: /Users/swolchok/src/torchchat/export_et_util.py:64 in forward, code: v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            getitem_3: "f32[1, 1, 6, 48]" = executorch_call_delegate[3]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_4: "f32[1, 1, 6, 48]" = executorch_call_delegate[4]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_5: "f32[1, 1, 6, 48]" = executorch_call_delegate[5];  executorch_call_delegate = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:72 in forward, code: output = torch.ops.llama.sdpa_with_kv_cache(
            auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.llama.sdpa_with_kv_cache.default, query = getitem_4, key = getitem_5, value = getitem_3, key_cache = b_layers_0_attention_kv_cache_k_cache, value_cache = b_layers_0_attention_kv_cache_v_cache, start_pos = getitem, seq_len = 1, attn_mask = None, drpout_p = 0.0, is_causal = False, scale = None);  getitem_4 = getitem_5 = getitem_3 = b_layers_0_attention_kv_cache_k_cache = b_layers_0_attention_kv_cache_v_cache = getitem = None
            getitem_6: "f32[1, 1, 6, 48]" = auto_functionalized[0]
            getitem_7: "f32[1, 8192, 6, 48]" = auto_functionalized[1]
            getitem_8: "f32[1, 8192, 6, 48]" = auto_functionalized[2];  auto_functionalized = None

            # No stacktrace found for following nodes
            lowered_module_1 = self.lowered_module_1
            executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, input_pos, getitem_2, getitem_6, getitem_1);  lowered_module_1 = getitem_6 = getitem_1 = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:78 in forward, code: input_pos[-1],
            getitem_9: "i64[]" = executorch_call_delegate_1[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:235 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
            getitem_10: "f32[1, 1, 288]" = executorch_call_delegate_1[1]

            # File: /Users/swolchok/src/torchchat/export_et_util.py:64 in forward, code: v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            getitem_11: "f32[1, 1, 6, 48]" = executorch_call_delegate_1[2]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_12: "f32[1, 1, 6, 48]" = executorch_call_delegate_1[3]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_13: "f32[1, 1, 6, 48]" = executorch_call_delegate_1[4];  executorch_call_delegate_1 = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:72 in forward, code: output = torch.ops.llama.sdpa_with_kv_cache(
            auto_functionalized_1 = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.llama.sdpa_with_kv_cache.default, query = getitem_12, key = getitem_13, value = getitem_11, key_cache = b_layers_1_attention_kv_cache_k_cache, value_cache = b_layers_1_attention_kv_cache_v_cache, start_pos = getitem_9, seq_len = 1, attn_mask = None, drpout_p = 0.0, is_causal = False, scale = None);  getitem_12 = getitem_13 = getitem_11 = b_layers_1_attention_kv_cache_k_cache = b_layers_1_attention_kv_cache_v_cache = getitem_9 = None
            getitem_14: "f32[1, 1, 6, 48]" = auto_functionalized_1[0]
            getitem_15: "f32[1, 8192, 6, 48]" = auto_functionalized_1[1]
            getitem_16: "f32[1, 8192, 6, 48]" = auto_functionalized_1[2];  auto_functionalized_1 = None

            # No stacktrace found for following nodes
            lowered_module_2 = self.lowered_module_2
            executorch_call_delegate_2 = torch.ops.higher_order.executorch_call_delegate(lowered_module_2, input_pos, getitem_2, getitem_14, getitem_10);  lowered_module_2 = getitem_14 = getitem_10 = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:78 in forward, code: input_pos[-1],
            getitem_17: "i64[]" = executorch_call_delegate_2[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:235 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
            getitem_18: "f32[1, 1, 288]" = executorch_call_delegate_2[1]

            # File: /Users/swolchok/src/torchchat/export_et_util.py:64 in forward, code: v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            getitem_19: "f32[1, 1, 6, 48]" = executorch_call_delegate_2[2]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_20: "f32[1, 1, 6, 48]" = executorch_call_delegate_2[3]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_21: "f32[1, 1, 6, 48]" = executorch_call_delegate_2[4];  executorch_call_delegate_2 = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:72 in forward, code: output = torch.ops.llama.sdpa_with_kv_cache(
            auto_functionalized_2 = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.llama.sdpa_with_kv_cache.default, query = getitem_20, key = getitem_21, value = getitem_19, key_cache = b_layers_2_attention_kv_cache_k_cache, value_cache = b_layers_2_attention_kv_cache_v_cache, start_pos = getitem_17, seq_len = 1, attn_mask = None, drpout_p = 0.0, is_causal = False, scale = None);  getitem_20 = getitem_21 = getitem_19 = b_layers_2_attention_kv_cache_k_cache = b_layers_2_attention_kv_cache_v_cache = getitem_17 = None
            getitem_22: "f32[1, 1, 6, 48]" = auto_functionalized_2[0]
            getitem_23: "f32[1, 8192, 6, 48]" = auto_functionalized_2[1]
            getitem_24: "f32[1, 8192, 6, 48]" = auto_functionalized_2[2];  auto_functionalized_2 = None

            # No stacktrace found for following nodes
            lowered_module_3 = self.lowered_module_3
            executorch_call_delegate_3 = torch.ops.higher_order.executorch_call_delegate(lowered_module_3, input_pos, getitem_2, getitem_22, getitem_18);  lowered_module_3 = getitem_22 = getitem_18 = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:78 in forward, code: input_pos[-1],
            getitem_25: "i64[]" = executorch_call_delegate_3[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:235 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
            getitem_26: "f32[1, 1, 288]" = executorch_call_delegate_3[1]

            # File: /Users/swolchok/src/torchchat/export_et_util.py:64 in forward, code: v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            getitem_27: "f32[1, 1, 6, 48]" = executorch_call_delegate_3[2]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_28: "f32[1, 1, 6, 48]" = executorch_call_delegate_3[3]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_29: "f32[1, 1, 6, 48]" = executorch_call_delegate_3[4];  executorch_call_delegate_3 = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:72 in forward, code: output = torch.ops.llama.sdpa_with_kv_cache(
            auto_functionalized_3 = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.llama.sdpa_with_kv_cache.default, query = getitem_28, key = getitem_29, value = getitem_27, key_cache = b_layers_3_attention_kv_cache_k_cache, value_cache = b_layers_3_attention_kv_cache_v_cache, start_pos = getitem_25, seq_len = 1, attn_mask = None, drpout_p = 0.0, is_causal = False, scale = None);  getitem_28 = getitem_29 = getitem_27 = b_layers_3_attention_kv_cache_k_cache = b_layers_3_attention_kv_cache_v_cache = getitem_25 = None
            getitem_30: "f32[1, 1, 6, 48]" = auto_functionalized_3[0]
            getitem_31: "f32[1, 8192, 6, 48]" = auto_functionalized_3[1]
            getitem_32: "f32[1, 8192, 6, 48]" = auto_functionalized_3[2];  auto_functionalized_3 = None

            # No stacktrace found for following nodes
            lowered_module_4 = self.lowered_module_4
            executorch_call_delegate_4 = torch.ops.higher_order.executorch_call_delegate(lowered_module_4, input_pos, getitem_2, getitem_30, getitem_26);  lowered_module_4 = getitem_30 = getitem_26 = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:78 in forward, code: input_pos[-1],
            getitem_33: "i64[]" = executorch_call_delegate_4[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:235 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
            getitem_34: "f32[1, 1, 288]" = executorch_call_delegate_4[1]

            # File: /Users/swolchok/src/torchchat/export_et_util.py:64 in forward, code: v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            getitem_35: "f32[1, 1, 6, 48]" = executorch_call_delegate_4[2]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_36: "f32[1, 1, 6, 48]" = executorch_call_delegate_4[3]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_37: "f32[1, 1, 6, 48]" = executorch_call_delegate_4[4];  executorch_call_delegate_4 = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:72 in forward, code: output = torch.ops.llama.sdpa_with_kv_cache(
            auto_functionalized_4 = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.llama.sdpa_with_kv_cache.default, query = getitem_36, key = getitem_37, value = getitem_35, key_cache = b_layers_4_attention_kv_cache_k_cache, value_cache = b_layers_4_attention_kv_cache_v_cache, start_pos = getitem_33, seq_len = 1, attn_mask = None, drpout_p = 0.0, is_causal = False, scale = None);  getitem_36 = getitem_37 = getitem_35 = b_layers_4_attention_kv_cache_k_cache = b_layers_4_attention_kv_cache_v_cache = getitem_33 = None
            getitem_38: "f32[1, 1, 6, 48]" = auto_functionalized_4[0]
            getitem_39: "f32[1, 8192, 6, 48]" = auto_functionalized_4[1]
            getitem_40: "f32[1, 8192, 6, 48]" = auto_functionalized_4[2];  auto_functionalized_4 = None

            # No stacktrace found for following nodes
            lowered_module_5 = self.lowered_module_5
            executorch_call_delegate_5 = torch.ops.higher_order.executorch_call_delegate(lowered_module_5, getitem_38, getitem_2, input_pos, getitem_34);  lowered_module_5 = getitem_38 = getitem_2 = input_pos = getitem_34 = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:78 in forward, code: input_pos[-1],
            getitem_41: "i64[]" = executorch_call_delegate_5[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:235 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
            getitem_42: "f32[1, 1, 288]" = executorch_call_delegate_5[1]

            # File: /Users/swolchok/src/torchchat/export_et_util.py:64 in forward, code: v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            getitem_43: "f32[1, 1, 6, 48]" = executorch_call_delegate_5[2]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_44: "f32[1, 1, 6, 48]" = executorch_call_delegate_5[3]

            # File: /Users/swolchok/src/torchchat/build/model.py:390 in apply_rotary_emb, code: x_out2 = x_out2.flatten(3)
            getitem_45: "f32[1, 1, 6, 48]" = executorch_call_delegate_5[4];  executorch_call_delegate_5 = None

            # File: /Users/swolchok/src/torchchat/export_et_util.py:72 in forward, code: output = torch.ops.llama.sdpa_with_kv_cache(
            auto_functionalized_5 = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.llama.sdpa_with_kv_cache.default, query = getitem_44, key = getitem_45, value = getitem_43, key_cache = b_layers_5_attention_kv_cache_k_cache, value_cache = b_layers_5_attention_kv_cache_v_cache, start_pos = getitem_41, seq_len = 1, attn_mask = None, drpout_p = 0.0, is_causal = False, scale = None);  getitem_44 = getitem_45 = getitem_43 = b_layers_5_attention_kv_cache_k_cache = b_layers_5_attention_kv_cache_v_cache = getitem_41 = None
            getitem_46: "f32[1, 1, 6, 48]" = auto_functionalized_5[0]
            getitem_47: "f32[1, 8192, 6, 48]" = auto_functionalized_5[1]
            getitem_48: "f32[1, 8192, 6, 48]" = auto_functionalized_5[2];  auto_functionalized_5 = None

            # No stacktrace found for following nodes
            lowered_module_6 = self.lowered_module_6
            executorch_call_delegate_6 = torch.ops.higher_order.executorch_call_delegate(lowered_module_6, getitem_46, getitem_42);  lowered_module_6 = getitem_46 = getitem_42 = None
            getitem_49: "f32[1, 1, 32000]" = executorch_call_delegate_6[0];  executorch_call_delegate_6 = None
            return (getitem_7, getitem_8, getitem_15, getitem_16, getitem_23, getitem_24, getitem_31, getitem_32, getitem_39, getitem_40, getitem_47, getitem_48, getitem_49)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_0_attention_kv_cache_k_cache'), target='layers_0_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_0_attention_kv_cache_v_cache'), target='layers_0_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_1_attention_kv_cache_k_cache'), target='layers_1_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_1_attention_kv_cache_v_cache'), target='layers_1_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_2_attention_kv_cache_k_cache'), target='layers_2_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_2_attention_kv_cache_v_cache'), target='layers_2_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_3_attention_kv_cache_k_cache'), target='layers_3_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_3_attention_kv_cache_v_cache'), target='layers_3_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_4_attention_kv_cache_k_cache'), target='layers_4_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_4_attention_kv_cache_v_cache'), target='layers_4_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_5_attention_kv_cache_k_cache'), target='layers_5_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, 
arg=TensorArgument(name='b_layers_5_attention_kv_cache_v_cache'), target='layers_5_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='idx'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_pos'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_7'), target='layers_0_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_8'), target='layers_0_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_15'), target='layers_1_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_16'), target='layers_1_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_23'), target='layers_2_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_24'), target='layers_2_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_31'), target='layers_3_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_32'), target='layers_3_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_39'), target='layers_4_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_40'), target='layers_4_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_47'), target='layers_5_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_48'), target='layers_5_attention_kv_cache_v_cache'), 
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_49'), target=None)])
Range constraints: {}

@@ -97,7 +98,30 @@ def export_model(model, device, output_path, args=None) -> str: # noqa: C901
dynamic_shapes=dynamic_shapes,
edge_compile_config=edge_config,
)
edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner())
if backend == "xnnpack":

Nit: is it possible to do something like https://github.com/pytorch/executorch/blob/main/examples/models/llama2/export_llama_lib.py#L393-L407 so that it's easier to read?
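The linked pattern amounts to mapping backend names to partitioner constructors instead of chaining if/elif branches. A hypothetical sketch (the placeholder classes stand in for real executorch imports such as XnnpackDynamicallyQuantizedPartitioner):

```python
# Placeholders for the real executorch partitioner classes.
class XnnpackPartitioner: ...
class MPSPartitioner: ...
class CoreMLPartitioner: ...

# One table instead of an if/elif chain per backend.
PARTITIONER_BY_BACKEND = {
    "xnnpack": XnnpackPartitioner,
    "mps": MPSPartitioner,
    "coreml": CoreMLPartitioner,
}

def get_partitioner(backend: str):
    try:
        return PARTITIONER_BY_BACKEND[backend]()
    except KeyError:
        raise ValueError(f"unsupported executorch backend: {backend!r}")
```

Adding a backend then only touches the table, which also gives a clear error for unknown names.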


cccclai commented May 10, 2024

Is there a command to generate stories15M_mps.pte?

@swolchok

With replace_attention_with_custom_sdpa_attention(model) commented out in export_et_util, here is the final exported program for stories15M (exported as stories15M_mps.nocustomop.pte to avoid confusion: python3 torchchat.py export stories15M --output-pte-path stories15M_mps.nocustomop.pte --device cpu --executorch-backend mps)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_layers_0_attention_kv_cache_k_cache: "f32[1, 6, 8192, 48]", b_layers_0_attention_kv_cache_v_cache: "f32[1, 6, 8192, 48]", b_layers_1_attention_kv_cache_k_cache: "f32[1, 6, 8192, 48]", b_layers_1_attention_kv_cache_v_cache: "f32[1, 6, 8192, 48]", b_layers_2_attention_kv_cache_k_cache: "f32[1, 6, 8192, 48]", b_layers_2_attention_kv_cache_v_cache: "f32[1, 6, 8192, 48]", b_layers_3_attention_kv_cache_k_cache: "f32[1, 6, 8192, 48]", b_layers_3_attention_kv_cache_v_cache: "f32[1, 6, 8192, 48]", b_layers_4_attention_kv_cache_k_cache: "f32[1, 6, 8192, 48]", b_layers_4_attention_kv_cache_v_cache: "f32[1, 6, 8192, 48]", b_layers_5_attention_kv_cache_k_cache: "f32[1, 6, 8192, 48]", b_layers_5_attention_kv_cache_v_cache: "f32[1, 6, 8192, 48]", idx: "i64[1, 1]", input_pos: "i64[1]"):
            # File: /Users/swolchok/src/torchchat/build/model.py:331 in forward, code: y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
            scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
            scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
            scalar_tensor_2: "f32[]" = torch.ops.aten.scalar_tensor.default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
            scalar_tensor_3: "f32[]" = torch.ops.aten.scalar_tensor.default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
            scalar_tensor_4: "f32[]" = torch.ops.aten.scalar_tensor.default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
            scalar_tensor_5: "f32[]" = torch.ops.aten.scalar_tensor.default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))

            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, idx, input_pos);  lowered_module_0 = idx = None

            # File: /Users/swolchok/src/torchchat/build/model.py:192 in forward, code: x = self.tok_embeddings(idx)
            getitem: "f32[1, 1, 288]" = executorch_call_delegate[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:191 in forward, code: freqs_cis = self.freqs_cis[input_pos]
            getitem_1: "f32[1, 24, 2]" = executorch_call_delegate[1]

            # File: /Users/swolchok/src/torchchat/build/model.py:361 in forward, code: return output * self.weight
            getitem_2: "f32[1, 1, 288]" = executorch_call_delegate[2]

            # File: /Users/swolchok/src/torchchat/build/model.py:324 in <genexpr>, code: q, k, v = (x.transpose(1, 2) for x in (q, k, v))
            getitem_3: "f32[1, 6, 1, 48]" = executorch_call_delegate[3]
            getitem_4: "f32[1, 6, 1, 48]" = executorch_call_delegate[4];  executorch_call_delegate = None

            # File: /Users/swolchok/src/torchchat/build/model.py:140 in update, code: v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
            aten_index_put_default: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_0_attention_kv_cache_v_cache, [None, None, input_pos], getitem_3);  b_layers_0_attention_kv_cache_v_cache = getitem_3 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:139 in update, code: k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
            aten_index_put_default_1: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_0_attention_kv_cache_k_cache, [None, None, input_pos], getitem_4);  b_layers_0_attention_kv_cache_k_cache = getitem_4 = None

            # No stacktrace found for following nodes
            lowered_module_1 = self.lowered_module_1
            executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, getitem_1, getitem_2, aten_index_put_default, aten_index_put_default_1, input_pos, scalar_tensor, getitem);  lowered_module_1 = getitem_2 = scalar_tensor = getitem = None

            # File: /Users/swolchok/src/torchchat/build/model.py:190 in forward, code: mask = self.causal_mask[None, None, input_pos]
            getitem_5: "b8[1, 1, 1, 8192]" = executorch_call_delegate_1[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:235 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
            getitem_6: "f32[1, 1, 288]" = executorch_call_delegate_1[1]

            # File: /Users/swolchok/src/torchchat/build/model.py:361 in forward, code: return output * self.weight
            getitem_7: "f32[1, 1, 288]" = executorch_call_delegate_1[2]

            # File: /Users/swolchok/src/torchchat/build/model.py:324 in <genexpr>, code: q, k, v = (x.transpose(1, 2) for x in (q, k, v))
            getitem_8: "f32[1, 6, 1, 48]" = executorch_call_delegate_1[3]
            getitem_9: "f32[1, 6, 1, 48]" = executorch_call_delegate_1[4];  executorch_call_delegate_1 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:140 in update, code: v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
            aten_index_put_default_2: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_1_attention_kv_cache_v_cache, [None, None, input_pos], getitem_8);  b_layers_1_attention_kv_cache_v_cache = getitem_8 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:139 in update, code: k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
            aten_index_put_default_3: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_1_attention_kv_cache_k_cache, [None, None, input_pos], getitem_9);  b_layers_1_attention_kv_cache_k_cache = getitem_9 = None

            # No stacktrace found for following nodes
            lowered_module_2 = self.lowered_module_2
            executorch_call_delegate_2 = torch.ops.higher_order.executorch_call_delegate(lowered_module_2, getitem_1, getitem_5, getitem_7, aten_index_put_default_2, aten_index_put_default_3, scalar_tensor_1, getitem_6);  lowered_module_2 = getitem_7 = scalar_tensor_1 = getitem_6 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:235 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
            getitem_10: "f32[1, 1, 288]" = executorch_call_delegate_2[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:361 in forward, code: return output * self.weight
            getitem_11: "f32[1, 1, 288]" = executorch_call_delegate_2[1]

            # File: /Users/swolchok/src/torchchat/build/model.py:324 in <genexpr>, code: q, k, v = (x.transpose(1, 2) for x in (q, k, v))
            getitem_12: "f32[1, 6, 1, 48]" = executorch_call_delegate_2[2]
            getitem_13: "f32[1, 6, 1, 48]" = executorch_call_delegate_2[3];  executorch_call_delegate_2 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:140 in update, code: v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
            aten_index_put_default_4: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_2_attention_kv_cache_v_cache, [None, None, input_pos], getitem_12);  b_layers_2_attention_kv_cache_v_cache = getitem_12 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:139 in update, code: k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
            aten_index_put_default_5: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_2_attention_kv_cache_k_cache, [None, None, input_pos], getitem_13);  b_layers_2_attention_kv_cache_k_cache = getitem_13 = None

            # No stacktrace found for following nodes
            lowered_module_3 = self.lowered_module_3
            executorch_call_delegate_3 = torch.ops.higher_order.executorch_call_delegate(lowered_module_3, getitem_1, getitem_5, getitem_11, aten_index_put_default_4, aten_index_put_default_5, scalar_tensor_2, getitem_10);  lowered_module_3 = getitem_11 = scalar_tensor_2 = getitem_10 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:235 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
            getitem_14: "f32[1, 1, 288]" = executorch_call_delegate_3[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:361 in forward, code: return output * self.weight
            getitem_15: "f32[1, 1, 288]" = executorch_call_delegate_3[1]

            # File: /Users/swolchok/src/torchchat/build/model.py:324 in <genexpr>, code: q, k, v = (x.transpose(1, 2) for x in (q, k, v))
            getitem_16: "f32[1, 6, 1, 48]" = executorch_call_delegate_3[2]
            getitem_17: "f32[1, 6, 1, 48]" = executorch_call_delegate_3[3];  executorch_call_delegate_3 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:140 in update, code: v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
            aten_index_put_default_6: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_3_attention_kv_cache_v_cache, [None, None, input_pos], getitem_16);  b_layers_3_attention_kv_cache_v_cache = getitem_16 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:139 in update, code: k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
            aten_index_put_default_7: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_3_attention_kv_cache_k_cache, [None, None, input_pos], getitem_17);  b_layers_3_attention_kv_cache_k_cache = getitem_17 = None

            # No stacktrace found for following nodes
            lowered_module_4 = self.lowered_module_4
            executorch_call_delegate_4 = torch.ops.higher_order.executorch_call_delegate(lowered_module_4, getitem_1, getitem_5, getitem_15, aten_index_put_default_6, aten_index_put_default_7, scalar_tensor_3, getitem_14);  lowered_module_4 = getitem_15 = scalar_tensor_3 = getitem_14 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:235 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
            getitem_18: "f32[1, 1, 288]" = executorch_call_delegate_4[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:361 in forward, code: return output * self.weight
            getitem_19: "f32[1, 1, 288]" = executorch_call_delegate_4[1]

            # File: /Users/swolchok/src/torchchat/build/model.py:324 in <genexpr>, code: q, k, v = (x.transpose(1, 2) for x in (q, k, v))
            getitem_20: "f32[1, 6, 1, 48]" = executorch_call_delegate_4[2]
            getitem_21: "f32[1, 6, 1, 48]" = executorch_call_delegate_4[3];  executorch_call_delegate_4 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:140 in update, code: v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
            aten_index_put_default_8: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_4_attention_kv_cache_v_cache, [None, None, input_pos], getitem_20);  b_layers_4_attention_kv_cache_v_cache = getitem_20 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:139 in update, code: k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
            aten_index_put_default_9: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_4_attention_kv_cache_k_cache, [None, None, input_pos], getitem_21);  b_layers_4_attention_kv_cache_k_cache = getitem_21 = None

            # No stacktrace found for following nodes
            lowered_module_5 = self.lowered_module_5
            executorch_call_delegate_5 = torch.ops.higher_order.executorch_call_delegate(lowered_module_5, getitem_1, getitem_5, getitem_19, aten_index_put_default_8, aten_index_put_default_9, scalar_tensor_4, getitem_18);  lowered_module_5 = getitem_19 = scalar_tensor_4 = getitem_18 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:235 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
            getitem_22: "f32[1, 1, 288]" = executorch_call_delegate_5[0]

            # File: /Users/swolchok/src/torchchat/build/model.py:361 in forward, code: return output * self.weight
            getitem_23: "f32[1, 1, 288]" = executorch_call_delegate_5[1]

            # File: /Users/swolchok/src/torchchat/build/model.py:324 in <genexpr>, code: q, k, v = (x.transpose(1, 2) for x in (q, k, v))
            getitem_24: "f32[1, 6, 1, 48]" = executorch_call_delegate_5[2]
            getitem_25: "f32[1, 6, 1, 48]" = executorch_call_delegate_5[3];  executorch_call_delegate_5 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:140 in update, code: v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
            aten_index_put_default_10: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_5_attention_kv_cache_v_cache, [None, None, input_pos], getitem_24);  b_layers_5_attention_kv_cache_v_cache = getitem_24 = None

            # File: /Users/swolchok/src/torchchat/build/model.py:139 in update, code: k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
            aten_index_put_default_11: "f32[1, 6, 8192, 48]" = executorch_exir_dialects_edge__ops_aten_index_put_default(b_layers_5_attention_kv_cache_k_cache, [None, None, input_pos], getitem_25);  b_layers_5_attention_kv_cache_k_cache = input_pos = getitem_25 = None

            # No stacktrace found for following nodes
            lowered_module_6 = self.lowered_module_6
            executorch_call_delegate_6 = torch.ops.higher_order.executorch_call_delegate(lowered_module_6, getitem_23, getitem_1, aten_index_put_default_11, aten_index_put_default_10, getitem_5, scalar_tensor_5, getitem_22);  lowered_module_6 = getitem_23 = getitem_1 = getitem_5 = scalar_tensor_5 = getitem_22 = None
            getitem_26: "f32[1, 1, 32000]" = executorch_call_delegate_6[0];  executorch_call_delegate_6 = None
            return (aten_index_put_default_1, aten_index_put_default, aten_index_put_default_3, aten_index_put_default_2, aten_index_put_default_5, aten_index_put_default_4, aten_index_put_default_7, aten_index_put_default_6, aten_index_put_default_9, aten_index_put_default_8, aten_index_put_default_11, aten_index_put_default_10, getitem_26)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_0_attention_kv_cache_k_cache'), target='layers_0_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_0_attention_kv_cache_v_cache'), target='layers_0_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_1_attention_kv_cache_k_cache'), target='layers_1_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_1_attention_kv_cache_v_cache'), target='layers_1_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_2_attention_kv_cache_k_cache'), target='layers_2_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_2_attention_kv_cache_v_cache'), target='layers_2_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_3_attention_kv_cache_k_cache'), target='layers_3_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_3_attention_kv_cache_v_cache'), target='layers_3_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_4_attention_kv_cache_k_cache'), target='layers_4_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_4_attention_kv_cache_v_cache'), target='layers_4_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_layers_5_attention_kv_cache_k_cache'), target='layers_5_attention_kv_cache_k_cache', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, 
arg=TensorArgument(name='b_layers_5_attention_kv_cache_v_cache'), target='layers_5_attention_kv_cache_v_cache', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='idx'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_pos'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_1'), target='layers_0_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default'), target='layers_0_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_3'), target='layers_1_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_2'), target='layers_1_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_5'), target='layers_2_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_4'), target='layers_2_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_7'), target='layers_3_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_6'), target='layers_3_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_9'), target='layers_4_attention_kv_cache_k_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_8'), target='layers_4_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_11'), target='layers_5_attention_kv_cache_k_cache'), 
OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_index_put_default_10'), target='layers_5_attention_kv_cache_v_cache'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_26'), target=None)])
Range constraints: {}
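For readers skimming the trace above: the repeated `aten_index_put_default` nodes all come from the KV-cache `update` method referenced in the stack traces (`build/model.py:139-140`). A minimal, standalone sketch of that pattern (shapes follow the trace — 6 heads, head dim 48; the small `max_seq_len` here is illustrative, not the 8192 used in the real cache):

```python
import torch

class KVCache(torch.nn.Module):
    """Minimal KV cache matching the index_put_ pattern in the trace."""

    def __init__(self, n_heads=6, max_seq_len=16, head_dim=48):
        super().__init__()
        # Buffers become the b_layers_*_attention_kv_cache_* inputs in the
        # exported graph signature.
        self.register_buffer("k_cache", torch.zeros(1, n_heads, max_seq_len, head_dim))
        self.register_buffer("v_cache", torch.zeros(1, n_heads, max_seq_len, head_dim))

    def update(self, input_pos, k_val, v_val):
        # In-place scatter along the sequence dimension; export traces this
        # as aten.index_put_ with a [None, None, input_pos] index list,
        # i.e. cache[:, :, input_pos] = val.
        k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
        v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
        return k_out, v_out

cache = KVCache()
pos = torch.tensor([3])                 # write a single decode step at position 3
k = torch.ones(1, 6, 1, 48)
v = torch.full((1, 6, 1, 48), 2.0)
k_out, v_out = cache.update(pos, k, v)
print(k_out[0, 0, 3, 0].item())         # the written slot now holds the new key value
```

Because `index_put_` mutates the buffers, export lifts the mutation into the `BUFFER_MUTATION` output specs shown in the graph signature above.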

@swolchok

Contributor Author

Graph breaks with the custom SDPA op removed are fixed if I manually patch the index_put portion of pytorch/executorch#3399 into ExecuTorch. However, I'm still getting garbage MPS FP16 results, and only 55 tokens/sec with FP32 + MPS.
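For context, the stock attention path that the custom SDPA op replaces is a standard `scaled_dot_product_attention` call over the cached keys and values. A minimal sketch of the per-step shapes involved (head count and head dim taken from the trace above; the cache length `S` here is an illustrative assumption):

```python
import torch
import torch.nn.functional as F

# One decode step: a single query token attends over the KV cache.
# 6 heads and head_dim 48 match the trace; S = 128 is illustrative.
q = torch.randn(1, 6, 1, 48)     # [batch, heads, new_tokens, head_dim]
k = torch.randn(1, 6, 128, 48)   # cached keys
v = torch.randn(1, 6, 128, 48)   # cached values
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 6, 1, 48])
```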

@swolchok swolchok force-pushed the executorch-export branch from f73d717 to c14845d Compare May 11, 2024 00:15
@swolchok
Contributor Author

torchchat/build/model.py and executorch/examples/model/llama_transformer.py have diverged. Need to reconcile why they're different (and maybe debug the Core ML backend with the current torchchat copy?)


pytorch-bot bot commented May 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/742

Note: Links to docs will display an error until the docs builds have been completed.

❌ 25 New Failures, 1 Pending, 6 Unrelated Failures

As of commit 2385491 with merge base 6455aa2:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Jack-Khuu
Contributor

Closing old PRs to increase attention of new PRs

@Jack-Khuu Jack-Khuu closed this Jul 17, 2024