[torchtitan] TorchFunctionMode + SAC issue #1434

Open: wants to merge 1 commit into base gh/XilunWu/23/base

Conversation

@XilunWu (Contributor) commented on Jul 21, 2025

Stack from ghstack (oldest at bottom):

Summary
Context Parallel has two ways to dispatch SDPA to the corresponding KV all-gather variant: (1) monkey-patching, or (2) using a TorchFunctionMode. The first approach works well with SAC, but the second fails with "RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead."
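
For clarity, here is a minimal sketch of the two dispatch strategies. The cp_sdpa helper is a hypothetical stand-in for the KV all-gather SDPA variant, not the actual torchtitan or PyTorch implementation:

```python
import torch
import torch.nn.functional as F
from torch.overrides import TorchFunctionMode

_original_sdpa = F.scaled_dot_product_attention


def cp_sdpa(q, k, v, *args, **kwargs):
    # Hypothetical placeholder: the real variant would all-gather K/V
    # across the context-parallel mesh before running attention.
    return _original_sdpa(q, k, v, *args, **kwargs)


# Approach 1: monkey-patch the function the model calls.
def apply_monkey_patch():
    F.scaled_dot_product_attention = cp_sdpa


# Approach 2: intercept the call with a TorchFunctionMode, leaving the model untouched.
class CPAttentionMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is _original_sdpa:
            return cp_sdpa(*args, **kwargs)
        return func(*args, **kwargs)


# Approach 2 wraps the whole training step, so the mode is also active during
# backward, which is when SAC re-runs the checkpointed forward:
#
#   with CPAttentionMode():
#       loss = model(inputs)
#       loss.backward()
```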

Commands
CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml ./run_train.sh --model.flavor=8B (fails)
CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml ./run_train.sh --model.flavor=8B --activation_checkpoint.mode="none" (succeeds)
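
The failure surfaces at output.view(bs, seqlen, -1) in the llama3 attention forward (see the traceback below): during SAC recompute the attention output arrives with strides that cannot be flattened in place. The snippet below is a standalone illustration of that error class, not the torchtitan code path; it builds a non-contiguous tensor by transposing and shows that .reshape() succeeds where .view() raises the same RuntimeError:

```python
import torch

bs, seqlen, n_heads, head_dim = 1, 8, 4, 16

# Attention output laid out as (bs, n_heads, seqlen, head_dim), then transposed
# to (bs, seqlen, n_heads, head_dim) without calling .contiguous(): the result
# has strides that prevent merging the last two dimensions in place.
output = torch.randn(bs, n_heads, seqlen, head_dim).transpose(1, 2)

try:
    output.view(bs, seqlen, -1)        # raises: "view size is not compatible ..."
except RuntimeError as e:
    print(e)

flat = output.reshape(bs, seqlen, -1)  # succeeds; copies when a view is impossible
print(flat.shape)                      # torch.Size([1, 8, 64])
```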

Error log

(pytorch-3.12) [[email protected] /data/users/xilunwu/oss/torchtitan (sac_tf_issue_in_view)]$ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml ./run_train.sh --model.flavor=8B 
+ NGPU=8
+ export LOG_RANK=0
+ LOG_RANK=0
+ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml
+ overrides=
+ '[' 1 -ne 0 ']'
+ overrides=--model.flavor=8B
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml --model.flavor=8B
W0721 14:19:06.593000 3585637 torch/distributed/run.py:774] 
W0721 14:19:06.593000 3585637 torch/distributed/run.py:774] *****************************************
W0721 14:19:06.593000 3585637 torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0721 14:19:06.593000 3585637 torch/distributed/run.py:774] *****************************************
[rank0]:[titan] 2025-07-21 14:19:12,694 - root - INFO - Starting job: Llama 3 8B training
[rank0]:[titan] 2025-07-21 14:19:14,568 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-21 14:19:14,570 - root - INFO - Building 1-D device mesh with ['cp'], [8]
[rank0]:[titan] 2025-07-21 14:19:14,571 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-07-21 14:19:17,944 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-07-21 14:19:18,218 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:[titan] 2025-07-21 14:19:23,811 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank0]:[titan] 2025-07-21 14:19:23,966 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250721-1419
[rank0]:[titan] 2025-07-21 14:19:23,966 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory
[rank0]:[titan] 2025-07-21 14:19:24,006 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:[titan] 2025-07-21 14:19:24,007 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-07-21 14:19:24,074 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-21 14:19:24,074 - root - INFO - Applied Context Parallel to the model
[rank0]:[titan] 2025-07-21 14:19:24,341 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-07-21 14:19:24,341 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:[titan] 2025-07-21 14:19:24,342 - root - WARNING - Warmup steps (200) exceed total training steps (20). Adjusting warmup steps to 20.
[rank0]:[titan] 2025-07-21 14:19:24,342 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-21 14:19:24,342 - root - INFO - Trainer is initialized with local batch size 1, global batch size 1, gradient accumulation steps 1, sequence length 8192, total steps 20 (warmup 200).
[rank0]:[titan] 2025-07-21 14:19:24,343 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-21 14:19:24,343 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:[rank0]:   File "/data/users/xilunwu/oss/torchtitan/torchtitan/train.py", line 583, in <module>
[rank0]:[rank0]:     trainer.train()
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
[rank0]:[rank0]:     return f(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/oss/torchtitan/torchtitan/train.py", line 513, in train
[rank0]:[rank0]:     self.train_step(data_iterator)
[rank0]:[rank0]:   File "/data/users/xilunwu/oss/torchtitan/torchtitan/train.py", line 445, in train_step
[rank0]:[rank0]:     loss = self.forward_backward_step(input_dict, labels)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/oss/torchtitan/torchtitan/train.py", line 427, in forward_backward_step
[rank0]:[rank0]:     loss.backward()
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/_tensor.py", line 616, in backward
[rank0]:[rank0]:     return handle_torch_function(
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/overrides.py", line 1725, in handle_torch_function
[rank0]:[rank0]:     result = mode.__torch_function__(public_api, types, args, kwargs)
[rank0]:[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/distributed/tensor/experimental/_attention.py", line 1304, in __torch_function__
[rank0]:[rank0]:     return func(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/_tensor.py", line 625, in backward
[rank0]:[rank0]:     torch.autograd.backward(
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/autograd/__init__.py", line 354, in backward
[rank0]:[rank0]:     _engine_run_backward(
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
[rank0]:[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/utils/checkpoint.py", line 1146, in unpack_hook
[rank0]:[rank0]:     _run_fn_with_dynamo_disabled(frame.recompute_fn, *args)
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/_compile.py", line 53, in inner
[rank0]:[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/_dynamo/eval_frame.py", line 1004, in _fn
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/utils/checkpoint.py", line 1116, in _run_fn_with_dynamo_disabled
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/utils/checkpoint.py", line 1542, in recompute_fn
[rank0]:[rank0]:     fn(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/oss/torchtitan/torchtitan/models/llama3/model/model.py", line 300, in forward
[rank0]:[rank0]:     h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/oss/torchtitan/torchtitan/models/llama3/model/model.py", line 197, in forward
[rank0]:[rank0]:     output = output.view(bs, seqlen, -1)
[rank0]:[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/utils/checkpoint.py", line 1353, in __torch_dispatch__
[rank0]:[rank0]:     out = func(*args, **kwargs)
[rank0]:[rank0]:           ^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/_ops.py", line 840, in __call__
[rank0]:[rank0]:     return self._op(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/_compile.py", line 53, in inner
[rank0]:[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/_dynamo/eval_frame.py", line 1004, in _fn
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/distributed/tensor/_api.py", line 358, in __torch_dispatch__
[rank0]:[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/distributed/tensor/_dispatch.py", line 201, in dispatch
[rank0]:[rank0]:     local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
[rank0]:[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/_ops.py", line 840, in __call__
[rank0]:[rank0]:     return self._op(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
W0721 14:19:28.408000 3585637 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3586048 closing signal SIGTERM
W0721 14:19:28.409000 3585637 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3586050 closing signal SIGTERM
W0721 14:19:28.410000 3585637 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3586051 closing signal SIGTERM
W0721 14:19:28.411000 3585637 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3586052 closing signal SIGTERM
W0721 14:19:28.412000 3585637 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3586053 closing signal SIGTERM
W0721 14:19:28.413000 3585637 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3586054 closing signal SIGTERM
W0721 14:19:28.413000 3585637 torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3586055 closing signal SIGTERM
E0721 14:19:29.791000 3585637 torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: 1) local_rank: 1 (pid: 3586049) of binary: /home/xilunwu/.conda/envs/pytorch-3.12/bin/python
E0721 14:19:29.794000 3585637 torch/distributed/elastic/multiprocessing/errors/error_handler.py:141] no error file defined for parent, to copy child error file (/tmp/torchelastic_oq386bu4/none_sc1fjzmr/attempt_0/1/error.json)
Traceback (most recent call last):
  File "/home/xilunwu/.conda/envs/pytorch-3.12/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xilunwu/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/data/users/xilunwu/pytorch/torch/distributed/run.py", line 901, in main
    run(args)
  File "/data/users/xilunwu/pytorch/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/data/users/xilunwu/pytorch/torch/distributed/launcher/api.py", line 143, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xilunwu/pytorch/torch/distributed/launcher/api.py", line 277, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
torchtitan.train FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-07-21_14:19:27
  host      : devgpu263.prn2.facebook.com
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3586049)
  error_file: /tmp/torchelastic_oq386bu4/none_sc1fjzmr/attempt_0/1/error.json
  traceback : Traceback (most recent call last):
    File "/data/users/xilunwu/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/oss/torchtitan/torchtitan/train.py", line 513, in train
      self.train_step(data_iterator)
    File "/data/users/xilunwu/oss/torchtitan/torchtitan/train.py", line 445, in train_step
      loss = self.forward_backward_step(input_dict, labels)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/oss/torchtitan/torchtitan/train.py", line 427, in forward_backward_step
      loss.backward()
    File "/data/users/xilunwu/pytorch/torch/_tensor.py", line 616, in backward
      return handle_torch_function(
             ^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/overrides.py", line 1725, in handle_torch_function
      result = mode.__torch_function__(public_api, types, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/distributed/tensor/experimental/_attention.py", line 1304, in __torch_function__
      return func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/_tensor.py", line 625, in backward
      torch.autograd.backward(
    File "/data/users/xilunwu/pytorch/torch/autograd/__init__.py", line 354, in backward
      _engine_run_backward(
    File "/data/users/xilunwu/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
      return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/utils/checkpoint.py", line 1146, in unpack_hook
      _run_fn_with_dynamo_disabled(frame.recompute_fn, *args)
    File "/data/users/xilunwu/pytorch/torch/_compile.py", line 53, in inner
      return disable_fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/_dynamo/eval_frame.py", line 1004, in _fn
      return fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/utils/checkpoint.py", line 1116, in _run_fn_with_dynamo_disabled
      return fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/utils/checkpoint.py", line 1542, in recompute_fn
      fn(*args, **kwargs)
    File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/oss/torchtitan/torchtitan/models/llama3/model/model.py", line 300, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/oss/torchtitan/torchtitan/models/llama3/model/model.py", line 197, in forward
      output = output.view(bs, seqlen, -1)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/utils/checkpoint.py", line 1353, in __torch_dispatch__
      out = func(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/_ops.py", line 840, in __call__
      return self._op(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/_compile.py", line 53, in inner
      return disable_fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/_dynamo/eval_frame.py", line 1004, in _fn
      return fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/distributed/tensor/_api.py", line 358, in __torch_dispatch__
      return DTensor._op_dispatcher.dispatch(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/distributed/tensor/_dispatch.py", line 201, in dispatch
      local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/xilunwu/pytorch/torch/_ops.py", line 840, in __call__
      return self._op(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
  
============================================================
(pytorch-3.12) [[email protected] /data/users/xilunwu/oss/torchtitan (sac_tf_issue_in_view)]$ 

XilunWu added a commit that referenced this pull request Jul 21, 2025
ghstack-source-id: 706b9f4
Pull Request resolved: #1434
@meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Jul 21, 2025