Llama3 hybrid implementation using submeshes #18777
base: main
Conversation
Clean 👌
To do:
- Add at least one CI test that will exercise DP. I suggest adding a demo to the t3k tests.
"batch-1-DP-4", # DP 4 latency | ||
"batch-1-DP-8", # DP 8 latency |
Could you add a batch 32 + DP test?
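A minimal sketch of how a batch 32 + DP case could be added to the parametrized demo (the parameter tuple and id format are assumptions based on the ids shown above):

import pytest

@pytest.mark.parametrize(
    "batch_size, data_parallel",
    [
        (1, 4),   # existing "batch-1-DP-4" latency case
        (1, 8),   # existing "batch-1-DP-8" latency case
        (32, 4),  # proposed batch 32 + DP throughput case
    ],
    ids=["batch-1-DP-4", "batch-1-DP-8", "batch-32-DP-4"],
)
def test_llama_demo_text_sketch(batch_size, data_parallel):
    # Stand-in body: the real demo wires these values into its model/demo fixtures.
    global_batch_size = batch_size * data_parallel
    assert global_batch_size in (4, 8, 128)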
if is_ci_env and num_devices == 8 and data_parallel > 1 and not ("3.2-1B" in llama_dir or "3.1-8B" in llama_dir):
    pytest.skip("CI runs only hybrid Llama3 1b and 8b on T3K")
What about 3B?
@@ -335,6 +439,19 @@ def test_llama_demo_text(
]:  # If the flag is provided, use it. Take an int instead of bool due to parser limitations
stop_at_eos = request.config.getoption("--stop_at_eos")

num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1
batch_size *= data_parallel  # input batch_size is interpreted as size per DP group
Can batch_size be renamed for clarity (here and throughout the demo)? E.g. global_batch_size for batch_size * data_parallel.
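A tiny sketch of the suggested naming, with example values standing in for the parametrized inputs:

batch_size_per_dp_group = 1   # what the user passes in, interpreted per DP group
data_parallel = 4
global_batch_size = batch_size_per_dp_group * data_parallel  # 4 prompts in total across submeshes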
# Hybrid requires a model per submesh
model_args = []
model = []
page_table = []
unused page_table var (overwritten below)
max_num_blocks=page_params["page_max_num_blocks"],
)
# Implied shuffling of blocks
permutation = torch.randperm(paged_attention_config.max_num_blocks)
According to this, max_num_blocks now represents the max blocks per DP group, right? Can it be renamed to max_num_blocks_per_dp for clarity?
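A short sketch of the rename, reusing the shuffle from the diff; the even per-user split at the end is an assumption about how the page table is built:

import torch

max_num_blocks_per_dp = 1024       # example value: block budget for one DP group
batch_size_per_dp_group = 32       # example value: users served by this DP group

# Implied shuffling of blocks, as in the diff, under the clearer name.
permutation = torch.randperm(max_num_blocks_per_dp)

# Assumed: the shuffled blocks are split evenly across the users of this DP group.
blocks_per_user = max_num_blocks_per_dp // batch_size_per_dp_group
page_table = permutation.reshape(batch_size_per_dp_group, blocks_per_user)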
)
model_args.append(model_args_i)
model.append(model_i)

return cls(model, model_args, mesh_device)

@property
Should be model_args[0] in cache_path and max_cross_attn_tokens?
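A sketch of the fix the comment suggests, assuming model_args is now the per-submesh list built above (the attribute names on the right-hand side are assumptions):

class GeneratorPropertiesSketch:
    def __init__(self, model, model_args, mesh_device):
        self.model = model
        self.model_args = model_args  # list with one entry per submesh
        self.mesh_device = mesh_device

    @property
    def cache_path(self):
        # Every submesh loads the same model, so entry 0 is representative.
        return self.model_args[0].model_cache_path  # attribute name assumed

    @property
    def max_cross_attn_tokens(self):
        return self.model_args[0].max_cross_attn_tokens  # attribute name assumed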

def prefill_forward(self, *args, **kwargs):
    return super().prefill_forward_text(*args, **kwargs)

def decode_forward(self, *args, **kwargs):
    return super().decode_forward_text(*args, **kwargs)

def allocate_kv_cache(self, *args, **kwargs):
    return allocate_kv_cache(*args, **kwargs)


class TtQwen2ForCausalLM(LlamaGenerator):
    def __init__(self, *args, **kwargs):
Should be model_args[0] in cache_path?
model_args = []
model = []
missing state_dict=None for first loop iter
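A stand-alone sketch of the pattern the comment asks for: state_dict starts as None so the first iteration loads weights, and later iterations reuse them (the loading class below is a stand-in, not the real model constructor):

class DummySubmeshModel:
    def __init__(self, name, state_dict=None):
        # Load weights only when no state_dict was handed in.
        self.state_dict = state_dict if state_dict is not None else {"weights": f"loaded-for-{name}"}

model = []
state_dict = None                    # must be None for the first loop iteration
for submesh_name in ("submesh-0", "submesh-1"):
    model_i = DummySubmeshModel(submesh_name, state_dict=state_dict)
    state_dict = model_i.state_dict  # reuse the already-loaded weights afterwards
    model.append(model_i)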
@@ -20,6 +20,44 @@
from vllm.model_executor.models.mllama import MLLAMA_IMAGE_TOKEN_ID, MLLAMA_IMAGE_TOKEN


def generate_submeshes(mesh_device):
    data_parallel = int(os.getenv("TT_DATA_PARALLEL", 1))
Could you replace this env var with a new arg like tt_data_parallel in the initialize_vllm_model class methods? I think it would be better than propagating an env var all the way here.
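A sketch of the suggested change, passing the DP factor down as an argument instead of reading the env var (the call-site name and keyword are assumptions; the create_submeshes call mirrors the diff):

import ttnn

def generate_submeshes(mesh_device, tt_data_parallel=1):
    # The DP factor now arrives as an argument rather than via TT_DATA_PARALLEL.
    num_devices = mesh_device.get_num_devices()
    return tt_data_parallel, mesh_device.create_submeshes(
        ttnn.MeshShape(1, num_devices // tt_data_parallel)
    )

# Hypothetical call site inside an initialize_vllm_model classmethod:
#   data_parallel, submeshes = generate_submeshes(mesh_device, tt_data_parallel=tt_data_parallel)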
    return data_parallel, mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel))


def allocate_kv_cache(kv_cache_shape, dtype, num_layers, mesh_device):
TODO (@ipotkonjak-tt and/or @skhorasganiTT) Modify KV creation in vLLM to use this function and test with DP
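A hedged sketch of the call site the TODO points at; only the allocate_kv_cache signature comes from the diff above, while the shape layout and dtype are illustrative assumptions:

import ttnn

def make_kv_cache_for_submesh(mesh_device, num_layers):
    # Illustrative shape (num_blocks, num_kv_heads, block_size, head_dim);
    # the real layout is whatever vLLM's worker computes for the model.
    kv_cache_shape = (1024, 8, 64, 128)
    return allocate_kv_cache(
        kv_cache_shape=kv_cache_shape,
        dtype=ttnn.bfloat8_b,
        num_layers=num_layers,
        mesh_device=mesh_device,
    )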
Problem description
Llama3 models are missing support for data parallelism and hybrid (data + tensor) parallelism.
What's changed
Adds hybrid parallelism to the Llama3 code base using the concept of submeshes. The implementation is mainly at the LlamaGenerator level: the MeshDevice is partitioned into submeshes, and each subset of devices runs an independent model. Within each submesh, the model remains tensor parallel.
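A minimal sketch of the partitioning described above, assuming the caller supplies a mesh_device and a model factory (the factory is a stand-in for the real per-submesh model construction):

import ttnn

def build_hybrid_models(mesh_device, data_parallel, create_model):
    # Split the mesh into `data_parallel` submeshes; each hosts an independent replica.
    num_devices = mesh_device.get_num_devices()
    submeshes = mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel))

    models = []
    for submesh in submeshes:
        # Within a submesh the model remains tensor parallel across its devices.
        models.append(create_model(submesh))
    return models

On a T3K (8 devices), for example, data_parallel=4 yields four 2-device submeshes, each running its own tensor-parallel model.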
Checklist