Llama3 hybrid implementation using submeshes #18777

Draft. Wants to merge 15 commits into base: main.
Changes from 14 commits
6 changes: 4 additions & 2 deletions models/demos/llama3/demo/multimodal_demo_chat.py
@@ -64,10 +64,12 @@ def test_llama_multimodal_demo_chat(
logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices")
mesh_device.enable_program_cache()
mesh_device.enable_async(True)
model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len)
model_args, model, _ = create_multimodal_model(
mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len
)
tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)
generator = LlamaGenerator(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)
generator = LlamaGenerator([model], [model_args], mesh_device, tokenizer=tokenizer, formatter=formatter)

# image understanding
dialogs = []
6 changes: 4 additions & 2 deletions models/demos/llama3/demo/multimodal_demo_text.py
@@ -70,10 +70,12 @@ def test_llama_multimodal_demo_text(
logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices")
mesh_device.enable_program_cache()
mesh_device.enable_async(True)
model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len)
model_args, model, _ = create_multimodal_model(
mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len
)
tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)
generator = LlamaGenerator(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)
generator = LlamaGenerator([model], [model_args], mesh_device, tokenizer=tokenizer, formatter=formatter)

with open(IMG_PATH / "dog.jpg", "rb") as f:
img = PIL_Image.open(f).convert("RGB")
180 changes: 148 additions & 32 deletions models/demos/llama3/demo/simple_text_demo.py
@@ -109,6 +109,7 @@ def create_tt_model(
page_params,
dtype=ttnn.bfloat8_b,
use_paged_kv_cache=False,
state_dict=None,
):
from models.demos.llama3.tt.llama_model import TtTransformer
from models.demos.llama3.tt.model_config import TtModelArgs
@@ -120,9 +121,11 @@
optimizations=optimizations,
max_seq_len=max_seq_len,
)
state_dict = tt_model_args.load_state_dict()

page_table = None
# Avoid loading state_dict for every DP model
if not state_dict:
state_dict = tt_model_args.load_state_dict()

paged_attention_config = None
tt_kv_cache = None

@@ -131,17 +134,6 @@
block_size=page_params["page_block_size"],
max_num_blocks=page_params["page_max_num_blocks"],
)
# Implied shuffling of blocks
permutation = torch.randperm(paged_attention_config.max_num_blocks)
# Page table which maps virtual blocks to physical
reverse_permutation = torch.argsort(permutation)
page_table = reverse_permutation.reshape(
tt_model_args.max_batch_size, paged_attention_config.max_num_blocks // tt_model_args.max_batch_size
)
paged_attention_config = PagedAttentionConfig(
block_size=page_params["page_block_size"],
max_num_blocks=page_params["page_max_num_blocks"],
)

model = TtTransformer(
args=tt_model_args,
@@ -155,7 +147,82 @@
if use_paged_kv_cache:
tt_kv_cache = [l.attention.layer_past for l in model.layers]

return tt_model_args, model, page_table, tt_kv_cache
return tt_model_args, model, tt_kv_cache, state_dict


def create_tt_page_table(max_batch_size, data_parallel, page_params, use_paged_kv_cache):
page_table = None
paged_attention_config = None

if use_paged_kv_cache:
paged_attention_config = PagedAttentionConfig(
block_size=page_params["page_block_size"],
max_num_blocks=page_params["page_max_num_blocks"],
)
# Implied shuffling of blocks
permutation = torch.randperm(paged_attention_config.max_num_blocks)
Reviewer (Contributor): According to this, max_num_blocks now represents the max blocks per DP group, right? Can it be renamed to max_num_blocks_per_dp for clarity?

# Page table which maps virtual blocks to physical
reverse_permutation = torch.argsort(permutation).repeat(data_parallel)
page_table = reverse_permutation.reshape(
max_batch_size, paged_attention_config.max_num_blocks // (max_batch_size // data_parallel)
)
paged_attention_config = PagedAttentionConfig(
block_size=page_params["page_block_size"],
max_num_blocks=page_params["page_max_num_blocks"],
)
return page_table
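
To make the block construction above concrete, here is a minimal standalone sketch (not part of the diff) that replays the same randperm/argsort/reshape steps with illustrative values; max_batch_size, data_parallel, and max_num_blocks below are made-up numbers, not the demo's defaults.

```python
# Minimal sketch of the page-table construction above (illustrative values only).
import torch

max_batch_size = 4   # global batch across all DP groups
data_parallel = 2    # number of submeshes / model replicas
max_num_blocks = 8   # paged-KV blocks available to one DP group

permutation = torch.randperm(max_num_blocks)       # shuffled physical block order
reverse_permutation = torch.argsort(permutation)   # virtual block -> physical block
page_table = reverse_permutation.repeat(data_parallel).reshape(
    max_batch_size, max_num_blocks // (max_batch_size // data_parallel)
)
# One row per user: with these numbers page_table has shape (4, 4), and the same
# virtual-to-physical layout is repeated once per DP group.
print(page_table.shape)  # torch.Size([4, 4])
```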


def prepare_generator_args(
num_devices,
data_parallel,
mesh_device,
instruct,
batch_size,
optimizations,
max_seq_len,
page_params,
paged_attention,
):
# Partition the mesh; each single model instance implements TP on a 1xN (sub)mesh
submesh_devices = (
mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel))
if isinstance(mesh_device, ttnn.MeshDevice) and data_parallel > 1
else [mesh_device]
)
state_dict = None

# Hybrid requires a model per submesh
model_args = []
model = []
page_table = []
Reviewer (Contributor): unused page_table var (it is overwritten below).

tt_kv_cache = []

for submesh in submesh_devices:
model_args_i, model_i, tt_kv_cache_i, state_dict = create_tt_model(
submesh,
instruct=instruct,
max_batch_size=batch_size // data_parallel,
optimizations=optimizations,
max_seq_len=max_seq_len,
page_params=page_params,
dtype=ttnn.bfloat8_b,
use_paged_kv_cache=paged_attention,
state_dict=state_dict,
)
model_args.append(model_args_i)
model.append(model_i)
tt_kv_cache.append(tt_kv_cache_i)

page_table = create_tt_page_table(
max_batch_size=batch_size,
data_parallel=data_parallel,
page_params=page_params,
use_paged_kv_cache=paged_attention,
)
# Host code, safe to reuse tokenizer from the 1st model
tokenizer = model_args[0].tokenizer
return model_args, model, page_table, tt_kv_cache, tokenizer
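
For intuition on the submesh split used above: the mesh is cut into data_parallel submeshes, each submesh holds one full tensor-parallel replica of the model, and the weights are read from disk once and reused for every replica via the returned state_dict. The sketch below uses the same ttnn call as the diff, but the device counts are illustrative and mesh_device stands in for the fixture the test receives.

```python
# Sketch only: illustrative device counts; mesh_device is assumed to be the test fixture.
num_devices = 8     # e.g. a T3K system
data_parallel = 4   # number of model replicas (DP groups)

# Hybrid split: each replica is tensor-parallel across num_devices // data_parallel
# devices, so an 8-device mesh becomes four 1x2 submeshes.
submesh_devices = mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel))
# len(submesh_devices) == data_parallel; prepare_generator_args then builds one
# TtTransformer per submesh, passing the state_dict loaded for the first replica
# into create_tt_model for the remaining ones.
```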


# List of supported Parameters for demo.py
@@ -174,7 +241,7 @@ def create_tt_model(
# optimization (LlamaOptimizations): Optimization level to use for the model (performance or accuracy)
# FAKE_DEVICE (str): Fake device to use for testing (N150, N300, T3K, TG). Usage: `export FAKE_DEVICE=N150`, will enable running a single-chip demo on a multi-chip system.
@pytest.mark.parametrize(
"input_prompts, instruct, repeat_batches, max_seq_len, batch_size, max_generated_tokens, paged_attention, page_params, sampling_params, stop_at_eos, ci_only",
"input_prompts, instruct, repeat_batches, max_seq_len, batch_size, max_generated_tokens, paged_attention, page_params, sampling_params, stop_at_eos, ci_only, data_parallel",
[
( # Batch-1 run (Latency) - single user, small prompt
"models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json", # input_prompts
@@ -188,6 +255,7 @@ def create_tt_model(
{"temperature": 0, "top_p": 0.08}, # sampling_params (argmax)
True, # stop_at_eos
False, # ci_only
1,  # data_parallel
),
( # Batch-32 run (Throughput) - 32 users, small prompt
"models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json", # input_prompts
@@ -201,6 +269,7 @@ def create_tt_model(
{"temperature": 0, "top_p": 0.08}, # sampling_params (argmax)
True, # stop_at_eos
False, # ci_only
1, # data_parallel
),
( # Long-context run - Single user, long prompt (adapted to the model being used and architecture)
"models/demos/llama3/demo/sample_prompts/input_data_long_64k.json", # input_prompts
@@ -214,6 +283,7 @@ def create_tt_model(
{"temperature": 0, "top_p": 0.08}, # sampling_params (argmax)
True, # stop_at_eos
False, # ci_only
1, # data_parallel
),
( # Batch-1 run (Reasoning) - single user, small prompt, long thinking time
"models/demos/llama3/demo/input_data_questions_reasoning.json", # input_prompts
@@ -227,6 +297,7 @@ def create_tt_model(
{"temperature": 0, "top_p": 0.08}, # sampling_params (argmax)
False, # stop_at_eos
False, # ci_only
1, # data_parallel
),
( # CI Batch-1 run - Measures the performance of a single user over 4096 iterations
"models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json", # input_prompts
@@ -240,6 +311,7 @@ def create_tt_model(
{"temperature": 0, "top_p": 0.08}, # sampling_params (argmax)
False, # stop_at_eos
True, # ci_only
1, # data_parallel
),
( # CI Batch-32 run - Measures the performance of a 32 users over 4096 iterations
"models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json", # input_prompts
@@ -253,6 +325,7 @@ def create_tt_model(
{"temperature": 0, "top_p": 0.08}, # sampling_params (argmax)
False, # stop_at_eos
True, # ci_only
1, # data_parallel
),
( # Batch-1 run (Latency, DP=4) - single user per DP group, small prompt
"models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json", # input_prompts
True, # instruct mode
1, # repeat_batches
1024, # max_seq_len
1, # batch_size
200, # max_generated_tokens
True, # paged_attention
{"page_block_size": 32, "page_max_num_blocks": 1024}, # page_params
{"temperature": 0, "top_p": 0.08}, # sampling_params (argmax)
True, # stop_at_eos
False, # ci_only
4, # data_parallel
),
( # Batch-1 run (Latency, DP=8) - single user per DP group, small prompt
"models/demos/llama3/demo/sample_prompts/input_data_questions_prefill_128.json", # input_prompts
True, # instruct mode
1, # repeat_batches
1024, # max_seq_len
1, # batch_size
200, # max_generated_tokens
True, # paged_attention
{"page_block_size": 32, "page_max_num_blocks": 1024}, # page_params
{"temperature": 0, "top_p": 0.08}, # sampling_params (argmax)
True, # stop_at_eos
False, # ci_only
8, # data_parallel
),
],
ids=[
@@ -262,6 +363,8 @@ def create_tt_model(
"reasoning-1", # reasoning
"ci-1", # CI batch 1
"ci-32", # CI batch 32
"batch-1-DP-4", # DP 4 latency
"batch-1-DP-8", # DP 8 latency
Reviewer (Contributor), on lines +366 to +367: Could you add a batch-32 + DP test?

],
)
@pytest.mark.parametrize(
@@ -297,6 +400,7 @@ def test_llama_demo_text(
use_program_cache,
is_ci_env,
ci_only,
data_parallel,
reset_seeds,
request,
):
@@ -335,6 +439,19 @@ def test_llama_demo_text(
]: # If the flag is provided, use it. Take an int instead of bool due to parser limitations
stop_at_eos = request.config.getoption("--stop_at_eos")

num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1
batch_size *= data_parallel # input batch_size is interpreted as size per DP group
Reviewer (Contributor): Can batch_size be renamed for clarity (here and throughout the demo)? E.g. global_batch_size for batch_size * data_parallel.


# uneven split of devices per DP group not supported
if data_parallel > num_devices or num_devices % data_parallel != 0:
pytest.skip(f"Invalid number of DP groups: {data_parallel}, for {num_devices} devices")

llama_dir = os.getenv("LLAMA_DIR")
if is_ci_env and num_devices == 32 and (data_parallel > 4 or (data_parallel == 4 and "3.1-70B" not in llama_dir)):
pytest.skip("CI runs only Llama3 70b DP = 4, TP = 8 on TG")
if is_ci_env and num_devices == 8 and data_parallel > 1 and not ("3.2-1B" in llama_dir or "3.1-8B" in llama_dir):
pytest.skip("CI runs only hybrid Llama3 1b and 8b on T3K")
Reviewer (Contributor), on lines +452 to +453: What about 3B?


if not stop_at_eos:
logger.info(f"The decode generation will only stop at the max_generated_tokens limit == {max_generated_tokens}")

@@ -366,18 +483,17 @@ def test_llama_demo_text(
for i in range(repeat_batches):
repeat_batch_prompts.append([input_prompts[(j + i) % len(input_prompts)] for j in range(len(input_prompts))])

model_args, model, page_table, tt_kv_cache = create_tt_model(
mesh_device,
model_args, model, page_table, tt_kv_cache, tokenizer = prepare_generator_args(
num_devices=num_devices,
data_parallel=data_parallel,
mesh_device=mesh_device,
instruct=instruct,
max_batch_size=batch_size,
batch_size=batch_size,
optimizations=optimizations,
max_seq_len=max_seq_len,
page_params=page_params,
dtype=ttnn.bfloat8_b,
use_paged_kv_cache=paged_attention,
paged_attention=paged_attention,
)

tokenizer = model_args.tokenizer
generator = LlamaGenerator(model, model_args, mesh_device, tokenizer=tokenizer)

num_tokens_generated_decode = []
@@ -447,8 +563,8 @@ def test_llama_demo_text(

user_done = [False] * batch_size # Keeps track when a user reaches EoD token

# TODO Argmax on device is only supported for batch_size=1
argmax_on_device = False if (batch_size > 1 or sampling_params["temperature"] != 0) else True
# TODO Argmax on device is only supported for batch_size=1 (per submesh)
argmax_on_device = False if (batch_size // data_parallel > 1 or sampling_params["temperature"] != 0) else True

# Initial positions
current_pos = torch.tensor([decoding_pos[b] for b in range(batch_size)])
@@ -545,7 +661,7 @@ def test_llama_demo_text(
for i, (output, prompt) in enumerate(zip(all_outputs, input_prompts)):
text = tokenizer.decode(output)
prompt_including_assistant_tags = tokenizer.decode(
model_args.encode_prompt(prompt, instruct=instruct)
model_args[0].encode_prompt(prompt, instruct=instruct)
)
text_after_prompt = text.replace(prompt_including_assistant_tags, "", 1)
if print_to_file:
@@ -648,9 +764,9 @@ def test_llama_demo_text(
supported_models = ["Llama3.2-1B", "Llama3.2-3B", "Llama3.1-8B", "Llama3.2-11B", "Llama3.1-70B"]
supported_devices = ["N150", "N300", "T3K", "TG"]

tt_device_name = model_args.device_name
tt_device_name = model_args[0].device_name

if model_args.base_model_name in supported_models:
if model_args[0].base_model_name in supported_models:
assert tt_device_name in supported_devices, f"Device {tt_device_name} not supported"

# Set the target times to first token for every combination of device and model
@@ -679,7 +795,7 @@ def test_llama_demo_text(
"N300_Llama3.1-70B": 1050, # TODO Update target
"T3K_Llama3.1-70B": 1050, # TODO Update target
"TG_Llama3.1-70B": 1050, # TODO Update target
}[f"{tt_device_name}_{model_args.base_model_name}"]
}[f"{tt_device_name}_{model_args[0].base_model_name}"]

# Set the target decode times for every combination of device and model
target_decode_tok_s_u = {
@@ -705,7 +821,7 @@ def test_llama_demo_text(
#
"T3K_Llama3.1-70B": 20, # TODO Update target
"TG_Llama3.1-70B": 20, # TODO Update target
}[f"{tt_device_name}_{model_args.base_model_name}"]
}[f"{tt_device_name}_{model_args[0].base_model_name}"]

target_decode_tok_s = target_decode_tok_s_u * batch_size
targets = {
@@ -714,7 +830,7 @@ def test_llama_demo_text(
"decode_t/s/u": target_decode_tok_s_u,
}
else:
logger.warning(f"Model {model_args.base_model_name} not does not have performance targets set")
logger.warning(f"Model {model_args[0].base_model_name} not does not have performance targets set")
targets = {}

# Save benchmark data for CI dashboard
@@ -752,9 +868,9 @@ def test_llama_demo_text(
benchmark_data.save_partial_run_json(
profiler,
run_type=f"{tt_device_name}-demo",
ml_model_name=model_args.base_model_name,
ml_model_name=model_args[0].base_model_name,
ml_model_type="llm",
num_layers=model_args.n_layers,
num_layers=model_args[0].n_layers,
batch_size=batch_size,
input_sequence_length=max(prefill_lens),
output_sequence_length=num_tokens_generated_decode[0],