Llama3 hybrid implementation using submeshes #18777
base: main
Changes from 14 commits
0bb92df
3dadf7d
e8ced18
f0b7794
ef4fdd0
be91311
2f0780d
cbe7ce6
2856a62
39b9de7
e6a084e
4676286
e8e6ddf
7321390
cbb4a66
@@ -109,6 +109,7 @@ def create_tt_model(
    page_params,
    dtype=ttnn.bfloat8_b,
    use_paged_kv_cache=False,
    state_dict=None,
):
    from models.demos.llama3.tt.llama_model import TtTransformer
    from models.demos.llama3.tt.model_config import TtModelArgs

@@ -120,9 +121,11 @@ def create_tt_model(
        optimizations=optimizations,
        max_seq_len=max_seq_len,
    )
    state_dict = tt_model_args.load_state_dict()

    page_table = None
    # Avoid loading state_dict for every DP model
    if not state_dict:
        state_dict = tt_model_args.load_state_dict()

    paged_attention_config = None
    tt_kv_cache = None

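The state_dict guard above is what lets data-parallel replicas share a single weight load: only the first call to create_tt_model reads from disk, and later calls reuse the dict that was handed back. A minimal sketch of the pattern, where load_weights and build_model are hypothetical placeholders rather than functions from this PR:

```python
def build_replicas(num_replicas, load_weights, build_model):
    state_dict = None
    replicas = []
    for _ in range(num_replicas):
        if not state_dict:  # only the first replica actually loads weights
            state_dict = load_weights()
        replicas.append(build_model(state_dict))
    return replicas

# Usage sketch: 4 replicas share one (fake) weight load
replicas = build_replicas(4, load_weights=lambda: {"w": 1.0}, build_model=dict)
print(len(replicas))  # 4
```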
@@ -131,17 +134,6 @@ def create_tt_model(
            block_size=page_params["page_block_size"],
            max_num_blocks=page_params["page_max_num_blocks"],
        )
        # Implied shuffling of blocks
        permutation = torch.randperm(paged_attention_config.max_num_blocks)
        # Page table which maps virtual blocks to physical
        reverse_permutation = torch.argsort(permutation)
        page_table = reverse_permutation.reshape(
            tt_model_args.max_batch_size, paged_attention_config.max_num_blocks // tt_model_args.max_batch_size
        )
        paged_attention_config = PagedAttentionConfig(
            block_size=page_params["page_block_size"],
            max_num_blocks=page_params["page_max_num_blocks"],
        )

    model = TtTransformer(
        args=tt_model_args,

@@ -155,7 +147,82 @@ def create_tt_model(
    if use_paged_kv_cache:
        tt_kv_cache = [l.attention.layer_past for l in model.layers]

    return tt_model_args, model, page_table, tt_kv_cache
    return tt_model_args, model, tt_kv_cache, state_dict


def create_tt_page_table(max_batch_size, data_parallel, page_params, use_paged_kv_cache):
    page_table = None
    paged_attention_config = None

    if use_paged_kv_cache:
        paged_attention_config = PagedAttentionConfig(
            block_size=page_params["page_block_size"],
            max_num_blocks=page_params["page_max_num_blocks"],
        )
        # Implied shuffling of blocks
        permutation = torch.randperm(paged_attention_config.max_num_blocks)
        # Page table which maps virtual blocks to physical
        reverse_permutation = torch.argsort(permutation).repeat(data_parallel)
        page_table = reverse_permutation.reshape(
            max_batch_size, paged_attention_config.max_num_blocks // (max_batch_size // data_parallel)
        )
        paged_attention_config = PagedAttentionConfig(
            block_size=page_params["page_block_size"],
            max_num_blocks=page_params["page_max_num_blocks"],
        )
    return page_table
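To make the block shuffling concrete, here is a small standalone sketch with made-up sizes (the PR's version additionally repeats the permutation across DP groups before reshaping):

```python
import torch

max_num_blocks = 8   # stand-in for page_params["page_max_num_blocks"]
max_batch_size = 2   # stand-in for the number of users

permutation = torch.randperm(max_num_blocks)       # shuffled physical block order
reverse_permutation = torch.argsort(permutation)   # virtual block -> physical block lookup
page_table = reverse_permutation.reshape(max_batch_size, max_num_blocks // max_batch_size)

# page_table[user, virtual_block] gives the physical block holding that user's KV data
print(page_table.shape)  # torch.Size([2, 4])
```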

def prepare_generator_args(
    num_devices,
    data_parallel,
    mesh_device,
    instruct,
    batch_size,
    optimizations,
    max_seq_len,
    page_params,
    paged_attention,
):
    # Partition the mesh, singular model implemented for TP on 1xN mesh
    submesh_devices = (
        mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel))
        if isinstance(mesh_device, ttnn.MeshDevice) and data_parallel > 1
        else [mesh_device]
    )
    state_dict = None

    # Hybrid requires a model per submesh
    model_args = []
    model = []
    page_table = []

Review comment: unused page_table var (overwritten below)
    tt_kv_cache = []

    for submesh in submesh_devices:
        model_args_i, model_i, tt_kv_cache_i, state_dict = create_tt_model(
            submesh,
            instruct=instruct,
            max_batch_size=batch_size // data_parallel,
            optimizations=optimizations,
            max_seq_len=max_seq_len,
            page_params=page_params,
            dtype=ttnn.bfloat8_b,
            use_paged_kv_cache=paged_attention,
            state_dict=state_dict,
        )
        model_args.append(model_args_i)
        model.append(model_i)
        tt_kv_cache.append(tt_kv_cache_i)

    page_table = create_tt_page_table(
        max_batch_size=batch_size,
        data_parallel=data_parallel,
        page_params=page_params,
        use_paged_kv_cache=paged_attention,
    )
    # Host code, safe to reuse tokenizer from the 1st model
    tokenizer = model_args[0].tokenizer
    return model_args, model, page_table, tt_kv_cache, tokenizer
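As a rough illustration of the submesh split performed by create_submeshes above, a pure-Python sketch of the partitioning arithmetic; the device ids and counts are made up, and this is not the ttnn API:

```python
def partition_devices(device_ids, data_parallel):
    # Mirrors the "uneven split of devices per DP group not supported" check in the test
    assert len(device_ids) % data_parallel == 0, "uneven split of devices per DP group not supported"
    tp_size = len(device_ids) // data_parallel  # tensor-parallel width of each DP group
    return [device_ids[i * tp_size:(i + 1) * tp_size] for i in range(data_parallel)]

# e.g. an 8-device (T3K-like) system split into 4 DP groups of TP=2 each
print(partition_devices(list(range(8)), data_parallel=4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```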
# List of supported Parameters for demo.py

@@ -174,7 +241,7 @@ def create_tt_model(
# optimization (LlamaOptimizations): Optimization level to use for the model (performance or accuracy)
# FAKE_DEVICE (str): Fake device to use for testing (N150, N300, T3K, TG). Usage: `export FAKE_DEVICE=N150`, will enable running a single-chip demo on a multi-chip system.
@pytest.mark.parametrize(
    "input_prompts, instruct, repeat_batches, max_seq_len, batch_size, max_generated_tokens, paged_attention, page_params, sampling_params, stop_at_eos, ci_only",
    "input_prompts, instruct, repeat_batches, max_seq_len, batch_size, max_generated_tokens, paged_attention, page_params, sampling_params, stop_at_eos, ci_only, data_parallel",
    [
        (  # Batch-1 run (Latency) - single user, small prompt
            "models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json",  # input_prompts

@@ -188,6 +255,7 @@ def create_tt_model(
            {"temperature": 0, "top_p": 0.08},  # sampling_params (argmax)
            True,  # stop_at_eos
            False,  # ci_only
            1,
        ),
        (  # Batch-32 run (Throughput) - 32 users, small prompt
            "models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json",  # input_prompts

@@ -201,6 +269,7 @@ def create_tt_model(
            {"temperature": 0, "top_p": 0.08},  # sampling_params (argmax)
            True,  # stop_at_eos
            False,  # ci_only
            1,  # data_parallel
        ),
        (  # Long-context run - Single user, long prompt (adapted to the model being used and architecture)
            "models/demos/llama3/demo/sample_prompts/input_data_long_64k.json",  # input_prompts

@@ -214,6 +283,7 @@ def create_tt_model(
            {"temperature": 0, "top_p": 0.08},  # sampling_params (argmax)
            True,  # stop_at_eos
            False,  # ci_only
            1,  # data_parallel
        ),
        (  # Batch-1 run (Reasoning) - single user, small prompt, long thinking time
            "models/demos/llama3/demo/input_data_questions_reasoning.json",  # input_prompts

@@ -227,6 +297,7 @@ def create_tt_model(
            {"temperature": 0, "top_p": 0.08},  # sampling_params (argmax)
            False,  # stop_at_eos
            False,  # ci_only
            1,  # data_parallel
        ),
        (  # CI Batch-1 run - Measures the performance of a single user over 4096 iterations
            "models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json",  # input_prompts

@@ -240,6 +311,7 @@ def create_tt_model(
            {"temperature": 0, "top_p": 0.08},  # sampling_params (argmax)
            False,  # stop_at_eos
            True,  # ci_only
            1,  # data_parallel
        ),
        (  # CI Batch-32 run - Measures the performance of a 32 users over 4096 iterations
            "models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json",  # input_prompts

@@ -253,6 +325,35 @@ def create_tt_model(
            {"temperature": 0, "top_p": 0.08},  # sampling_params (argmax)
            False,  # stop_at_eos
            True,  # ci_only
            1,  # data_parallel
        ),
        (  # Batch-1 run (Latency) - single user, small prompt
            "models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json",  # input_prompts
            True,  # instruct mode
            1,  # repeat_batches
            1024,  # max_seq_len
            1,  # batch_size
            200,  # max_generated_tokens
            True,  # paged_attention
            {"page_block_size": 32, "page_max_num_blocks": 1024},  # page_params
            {"temperature": 0, "top_p": 0.08},  # sampling_params (argmax)
            True,  # stop_at_eos
            False,  # ci_only
            4,  # data_parallel
        ),
        (  # Batch-1 run (Latency) - single user, small prompt
            "models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json",  # input_prompts
            True,  # instruct mode
            1,  # repeat_batches
            1024,  # max_seq_len
            1,  # batch_size
            200,  # max_generated_tokens
            True,  # paged_attention
            {"page_block_size": 32, "page_max_num_blocks": 1024},  # page_params
            {"temperature": 0, "top_p": 0.08},  # sampling_params (argmax)
            True,  # stop_at_eos
            False,  # ci_only
            8,  # data_parallel
        ),
    ],
    ids=[

@@ -262,6 +363,8 @@ def create_tt_model(
        "reasoning-1",  # reasoning
        "ci-1",  # CI batch 1
        "ci-32",  # CI batch 32
        "batch-1-DP-4",  # DP 4 latency
        "batch-1-DP-8",  # DP 8 latency

Review comment on lines +366 to +367: Could you add a batch 32 + DP test?

    ],
)
@pytest.mark.parametrize(

@@ -297,6 +400,7 @@ def test_llama_demo_text(
    use_program_cache,
    is_ci_env,
    ci_only,
    data_parallel,
    reset_seeds,
    request,
):

@@ -335,6 +439,19 @@ def test_llama_demo_text(
    ]:  # If the flag is provided, use it. Take an int instead of bool due to parser limitations
        stop_at_eos = request.config.getoption("--stop_at_eos")

    num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1
    batch_size *= data_parallel  # input batch_size is interpreted as size per DP group

Review comment: Can batch_size be renamed for clarity (here and throughout the demo)? e.g.

    # uneven split of devices per DP group not supported
    if data_parallel > num_devices or num_devices % data_parallel != 0:
        pytest.skip(f"Invalid number of DP groups: {data_parallel}, for {num_devices} devices")

    llama_dir = os.getenv("LLAMA_DIR")
    if is_ci_env and num_devices == 32 and (data_parallel > 4 or (data_parallel == 4 and "3.1-70B" not in llama_dir)):
        pytest.skip("CI runs only Llama3 70b DP = 4, TP = 8 on TG")
    if is_ci_env and num_devices == 8 and data_parallel > 1 and not ("3.2-1B" in llama_dir or "3.1-8B" in llama_dir):
        pytest.skip("CI runs only hybrid Llama3 1b and 8b on T3K")

Review comment on lines +452 to +453: What about 3B?

    if not stop_at_eos:
        logger.info(f"The decode generation will only stop at the max_generated_tokens limit == {max_generated_tokens}")

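Since the parametrized batch_size is interpreted per DP group, the derived quantities work out as in this small sketch (values assumed for illustration):

```python
batch_size = 1       # parametrized value, interpreted per DP group
data_parallel = 4
num_devices = 8

global_batch_size = batch_size * data_parallel  # what the rest of the test treats as batch_size
valid_split = data_parallel <= num_devices and num_devices % data_parallel == 0
print(global_batch_size, valid_split)  # 4 True
```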
@@ -366,18 +483,17 @@ def test_llama_demo_text(
    for i in range(repeat_batches):
        repeat_batch_prompts.append([input_prompts[(j + i) % len(input_prompts)] for j in range(len(input_prompts))])

    model_args, model, page_table, tt_kv_cache = create_tt_model(
        mesh_device,
    model_args, model, page_table, tt_kv_cache, tokenizer = prepare_generator_args(
        num_devices=num_devices,
        data_parallel=data_parallel,
        mesh_device=mesh_device,
        instruct=instruct,
        max_batch_size=batch_size,
        batch_size=batch_size,
        optimizations=optimizations,
        max_seq_len=max_seq_len,
        page_params=page_params,
        dtype=ttnn.bfloat8_b,
        use_paged_kv_cache=paged_attention,
        paged_attention=paged_attention,
    )

    tokenizer = model_args.tokenizer
    generator = LlamaGenerator(model, model_args, mesh_device, tokenizer=tokenizer)

    num_tokens_generated_decode = []

@@ -447,8 +563,8 @@ def test_llama_demo_text(

    user_done = [False] * batch_size  # Keeps track when a user reaches EoD token

    # TODO Argmax on device is only supported for batch_size=1
    argmax_on_device = False if (batch_size > 1 or sampling_params["temperature"] != 0) else True
    # TODO Argmax on device is only supported for batch_size=1 (per submesh)
    argmax_on_device = False if (batch_size // data_parallel > 1 or sampling_params["temperature"] != 0) else True

    # Initial positions
    current_pos = torch.tensor([decoding_pos[b] for b in range(batch_size)])

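The updated condition keys off the per-submesh batch rather than the global one; a tiny sketch with assumed values:

```python
def can_argmax_on_device(batch_size, data_parallel, temperature):
    # batch_size is the global batch; each submesh serves batch_size // data_parallel users,
    # and on-device argmax only applies to a single greedy-decoded user per submesh
    return batch_size // data_parallel == 1 and temperature == 0

print(can_argmax_on_device(batch_size=4, data_parallel=4, temperature=0))   # True
print(can_argmax_on_device(batch_size=32, data_parallel=4, temperature=0))  # False
```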
@@ -545,7 +661,7 @@ def test_llama_demo_text(
        for i, (output, prompt) in enumerate(zip(all_outputs, input_prompts)):
            text = tokenizer.decode(output)
            prompt_including_assistant_tags = tokenizer.decode(
                model_args.encode_prompt(prompt, instruct=instruct)
                model_args[0].encode_prompt(prompt, instruct=instruct)
            )
            text_after_prompt = text.replace(prompt_including_assistant_tags, "", 1)
            if print_to_file:

@@ -648,9 +764,9 @@ def test_llama_demo_text(
    supported_models = ["Llama3.2-1B", "Llama3.2-3B", "Llama3.1-8B", "Llama3.2-11B", "Llama3.1-70B"]
    supported_devices = ["N150", "N300", "T3K", "TG"]

    tt_device_name = model_args.device_name
    tt_device_name = model_args[0].device_name

    if model_args.base_model_name in supported_models:
    if model_args[0].base_model_name in supported_models:
        assert tt_device_name in supported_devices, f"Device {tt_device_name} not supported"

        # Set the target times to first token for every combination of device and model

@@ -679,7 +795,7 @@ def test_llama_demo_text(
            "N300_Llama3.1-70B": 1050,  # TODO Update target
            "T3K_Llama3.1-70B": 1050,  # TODO Update target
            "TG_Llama3.1-70B": 1050,  # TODO Update target
        }[f"{tt_device_name}_{model_args.base_model_name}"]
        }[f"{tt_device_name}_{model_args[0].base_model_name}"]

        # Set the target decode times for every combination of device and model
        target_decode_tok_s_u = {

@@ -705,7 +821,7 @@ def test_llama_demo_text(
            #
            "T3K_Llama3.1-70B": 20,  # TODO Update target
            "TG_Llama3.1-70B": 20,  # TODO Update target
        }[f"{tt_device_name}_{model_args.base_model_name}"]
        }[f"{tt_device_name}_{model_args[0].base_model_name}"]

        target_decode_tok_s = target_decode_tok_s_u * batch_size
        targets = {

@@ -714,7 +830,7 @@ def test_llama_demo_text(
            "decode_t/s/u": target_decode_tok_s_u,
        }
    else:
        logger.warning(f"Model {model_args.base_model_name} not does not have performance targets set")
        logger.warning(f"Model {model_args[0].base_model_name} not does not have performance targets set")
        targets = {}

    # Save benchmark data for CI dashboard

@@ -752,9 +868,9 @@ def test_llama_demo_text(
    benchmark_data.save_partial_run_json(
        profiler,
        run_type=f"{tt_device_name}-demo",
        ml_model_name=model_args.base_model_name,
        ml_model_name=model_args[0].base_model_name,
        ml_model_type="llm",
        num_layers=model_args.n_layers,
        num_layers=model_args[0].n_layers,
        batch_size=batch_size,
        input_sequence_length=max(prefill_lens),
        output_sequence_length=num_tokens_generated_decode[0],

Review comment: According to this, max_num_blocks now represents the max blocks per DP group, right? Can it be renamed to max_num_blocks_per_dp for clarity?