Llama3 hybrid implementation using submeshes #18777

Draft
wants to merge 15 commits into main

Conversation

@ipotkonjak-tt ipotkonjak-tt commented Mar 7, 2025

Problem description

Missing support for data / hybrid parallelism in the Llama3 models.

What's changed

Adds hybrid (data + tensor) parallelism to the Llama3 code base using the concept of submeshes. The implementation lives mainly at the LlamaGenerator level: the MeshDevice is partitioned into submeshes, and each subset of devices runs an independent model instance. Within each submesh, the model remains tensor parallel.
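A minimal sketch of the approach described above. create_submeshes, ttnn.MeshShape, and get_num_devices appear in this PR's diff; the helper name and the create_model callable are illustrative only:

import ttnn

def build_data_parallel_models(mesh_device, data_parallel, create_model):
    # Illustrative only: split the full mesh into `data_parallel` submeshes and
    # build one independent, tensor-parallel model replica per submesh.
    num_devices = mesh_device.get_num_devices()
    assert num_devices % data_parallel == 0, "devices must split evenly across DP groups"
    submeshes = mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel))
    return [create_model(submesh) for submesh in submeshes]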

Checklist

@yieldthought yieldthought left a comment


Clean 👌

To do:

  • Add at least one CI test that will exercise DP. I suggest adding a demo to the t3k tests.
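A rough sketch of what such a t3k DP demo test could look like, assuming pytest parametrization like the existing demo tests; run_llama_demo, the fixture name, and the exact parameters are illustrative:

import pytest

@pytest.mark.parametrize(
    "batch_size, data_parallel",
    [(1, 4), (1, 8)],
    ids=["batch-1-DP-4", "batch-1-DP-8"],  # mirrors the IDs added in this PR
)
def test_llama_demo_data_parallel(mesh_device, batch_size, data_parallel):
    # Exercise the submesh / hybrid path in CI (helper name is hypothetical).
    run_llama_demo(mesh_device, batch_size=batch_size, data_parallel=data_parallel)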

@ipotkonjak-tt ipotkonjak-tt requested a review from cfjchu March 7, 2025 14:34
@ipotkonjak-tt ipotkonjak-tt self-assigned this Mar 8, 2025
Comment on lines +366 to +367
"batch-1-DP-4", # DP 4 latency
"batch-1-DP-8", # DP 8 latency

Could you add a batch 32 + DP test?
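For example, following the existing ID pattern (the exact DP factor is an assumption):

"batch-32-DP-4",  # DP 4, global batch 32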

Comment on lines +452 to +453
if is_ci_env and num_devices == 8 and data_parallel > 1 and not ("3.2-1B" in llama_dir or "3.1-8B" in llama_dir):
pytest.skip("CI runs only hybrid Llama3 1b and 8b on T3K")

What about 3B?

@@ -335,6 +439,19 @@ def test_llama_demo_text(
]: # If the flag is provided, use it. Take an int instead of bool due to parser limitations
stop_at_eos = request.config.getoption("--stop_at_eos")

num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1
batch_size *= data_parallel # input batch_size is interpreted as size per DP group

Can batch_size be renamed for clarity (here and throughout the demo)? e.g. global_batch_size for batch_size * data_parallel
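A sketch of the suggested rename, preserving the semantics of the line above; the new names follow the reviewer's suggestion and are not in the PR yet:

batch_size_per_dp = batch_size  # what the test parameter actually means
global_batch_size = batch_size_per_dp * data_parallel  # total batch across all DP groups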

# Hybrid requires a model per submesh
model_args = []
model = []
page_table = []

unused page_table var (overwritten below)

max_num_blocks=page_params["page_max_num_blocks"],
)
# Implied shuffling of blocks
permutation = torch.randperm(paged_attention_config.max_num_blocks)

According to this, max_num_blocks now represents the max blocks per dp group right? Can it be renamed to max_num_blocks_per_dp for clarity?
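A minimal sketch of the rename, assuming page_params comes from the surrounding demo setup; only the variable name changes:

import torch

max_num_blocks_per_dp = page_params["page_max_num_blocks"]  # per-DP-group block pool
permutation = torch.randperm(max_num_blocks_per_dp)  # implied shuffling of blocks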

)
model_args.append(model_args_i)
model.append(model_i)

return cls(model, model_args, mesh_device)

@property

Should be model_args[0] in cache_path and max_cross_attn_tokens?
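A sketch of the suggested fix, assuming model_args is now a list with one entry per submesh; the attribute names are assumptions, the point is the [0] indexing:

@property
def cache_path(self):
    # model_args is a list after the submesh change; use the first replica's args.
    return self.model_args[0].model_cache_path  # attribute name assumed

@property
def max_cross_attn_tokens(self):
    return self.model_args[0].max_cross_attn_tokens  # likewise index the first entry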


def prefill_forward(self, *args, **kwargs):
return super().prefill_forward_text(*args, **kwargs)

def decode_forward(self, *args, **kwargs):
return super().decode_forward_text(*args, **kwargs)

def allocate_kv_cache(self, *args, **kwargs):
return allocate_kv_cache(*args, **kwargs)


class TtQwen2ForCausalLM(LlamaGenerator):
def __init__(self, *args, **kwargs):

Should be model_args[0] in cache_path?


model_args = []
model = []


missing state_dict=None for first loop iter
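A sketch of the fix: initialize state_dict before the per-submesh loop so the first iteration passes None and triggers the weight load. The helper name and its return values are assumptions based on the snippet above:

model_args = []
model = []
state_dict = None  # the missing initialization for the first loop iteration

for submesh in submeshes:
    # Hypothetical helper: loads weights when state_dict is None, reuses them otherwise.
    model_args_i, model_i, state_dict = create_submesh_model(submesh, state_dict=state_dict)
    model_args.append(model_args_i)
    model.append(model_i)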

@@ -20,6 +20,44 @@
from vllm.model_executor.models.mllama import MLLAMA_IMAGE_TOKEN_ID, MLLAMA_IMAGE_TOKEN


def generate_submeshes(mesh_device):
data_parallel = int(os.getenv("TT_DATA_PARALLEL", 1))

Could you replace this env var with a new arg like tt_data_parallel in the initialize_vllm_model class methods? I think it would be better than propagating an env var all the way here

return data_parallel, mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel))
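A sketch of the reviewer's suggestion above: take data parallelism as an explicit argument (e.g. a tt_data_parallel kwarg on initialize_vllm_model) instead of reading TT_DATA_PARALLEL here. Only the parameter is new; the ttnn calls are the ones already used in this diff:

import ttnn

def generate_submeshes(mesh_device, data_parallel=1):
    # data_parallel is now passed down from the vLLM entry point rather than an env var.
    num_devices = mesh_device.get_num_devices()
    return data_parallel, mesh_device.create_submeshes(
        ttnn.MeshShape(1, num_devices // data_parallel)
    )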


def allocate_kv_cache(kv_cache_shape, dtype, num_layers, mesh_device):

TODO (@ipotkonjak-tt and/or @skhorasganiTT) Modify KV creation in vLLM to use this function and test with DP
