-
Notifications
You must be signed in to change notification settings - Fork 329
/
Copy pathtest_generate.py
307 lines (259 loc) · 9.4 KB
/
test_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Optional
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed._tensor import Replicate
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torchtitan import utils
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_device_memory_monitor
from torchtitan.models import model_name_to_tokenizer
from torchtitan.parallelisms import ParallelDims
from torchtitan.train_spec import get_train_spec
from torchtitan.utils import device_module, device_type
# support running w/o installing as package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from generate._generation import generate
def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh):
    """Parallelize ``model`` over ``tp_mesh`` with tensor parallelism only.

    The "minus sp" in the name refers to skipping Sequence Parallel: the
    embedding takes replicated inputs and the output projection produces
    replicated outputs. The model is modified in place via
    ``parallelize_module``; nothing is returned.

    Args:
        model: Module exposing ``tok_embeddings``, ``output`` and a
            ``layers`` mapping of transformer blocks, each with
            ``attention.{wq,wk,wv,wo}`` and ``feed_forward.{w1,w2,w3}``.
        tp_mesh: Device mesh to shard the parameters over.
    """
    # Model-level entry and exit points.
    embedding_and_head_plan = {
        "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
        "output": ColwiseParallel(output_layouts=Replicate()),
    }
    parallelize_module(model, tp_mesh, embedding_and_head_plan)

    # Per-block plan; fresh style objects are created for every block.
    for block in model.layers.values():
        block_plan = {
            "attention.wq": ColwiseParallel(),
            "attention.wk": ColwiseParallel(),
            "attention.wv": ColwiseParallel(),
            "attention.wo": RowwiseParallel(),
            "feed_forward.w1": ColwiseParallel(),
            "feed_forward.w2": RowwiseParallel(),
            "feed_forward.w3": ColwiseParallel(),
        }
        parallelize_module(block, tp_mesh, block_plan)
@record
def test_generate(
    config_path: str,
    checkpoint_path: str,
    prompt: str,
    *,
    temperature: float = 1.0,
    max_new_tokens: int = 32,
    batch_size: int = 1,
    top_k: Optional[int] = None,
    seed: Optional[int] = None,
    deterministic: bool = False,
    out: bool = False,
):
    """Load a checkpoint, generate from ``prompt`` and report the results.

    The model is built from the job config at ``config_path``. When
    ``WORLD_SIZE`` is greater than 1 the model is created on the meta
    device and sharded with tensor parallelism across all ranks before
    the checkpoint is loaded. Responses and statistics are logged on local
    rank 0 only.

    Args:
        config_path: Path to the TOML job config file.
        checkpoint_path: Checkpoint location passed to ``dcp.load``.
        prompt: Text to condition generation on; may be empty.
        temperature: Sampling temperature forwarded to ``generate``.
        max_new_tokens: Number of new tokens requested from ``generate``.
        batch_size: How many times the prompt is repeated in the batch.
        top_k: Optional top-k cutoff forwarded to ``generate``.
        seed: Optional random seed for determinism setup and generation.
        deterministic: Forwarded to ``utils.set_determinism``.
        out: If true, local rank 0 prints the JSON report to stdout.
    """
    init_logger()
    color = utils.Color

    # Load configuration from toml file
    config = JobConfig()
    config.parse_args([f"--job.config_file={config_path}"])
    config._validate_config()

    # Use the function argument rather than the CLI namespace so the
    # function also works when imported and called from other code.
    if not prompt:
        logger.warning(
            "The input prompt is empty, model will respond from a empty sequence."
        )

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"{device_type}:{local_rank}")
    device_module.set_device(device)
    device_memory_monitor = build_device_memory_monitor()

    train_spec = get_train_spec(config.model.name)

    logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}")

    # Tokenizer setup
    tokenizer = build_tokenizer(
        model_name_to_tokenizer[train_spec.name], config.model.tokenizer_path
    )

    model_config = train_spec.config[config.model.flavor]
    model_config.norm_type = config.model.norm_type
    model_config.max_seq_len = config.training.seq_len
    model_config.vocab_size = tokenizer.n_words

    model_cls = train_spec.cls

    # Multi-rank runs build the model on the meta device and materialize it
    # after parallelization; single-rank runs build directly on the device.
    init_device = "meta" if world_size > 1 else device
    with torch.device(init_device):
        logger.info(f"Init model on init_device: {init_device}")
        model = model_cls.from_model_args(model_config)

    world_mesh = None
    # Init distributed env
    if world_size > 1:
        utils.init_distributed(config)
        parallel_dims = ParallelDims(
            dp_replicate=1,
            dp_shard=-1,
            cp=1,
            tp=world_size,
            pp=1,
            world_size=world_size,
            enable_loss_parallel=False,
        )
        # Build world mesh for parallelism
        world_mesh = parallel_dims.build_mesh(device_type=device_type)

        # apply_tp (with Sequence Parallel) on unevenly sharded
        # sequences would require https://github.com/pytorch/torchtitan/pull/686
        apply_tp_minus_sp(model, world_mesh["tp"])

    utils.set_determinism(world_mesh, device, seed, deterministic)

    # materialize model
    model.to_empty(device=device_type)
    model.eval()

    state_dict = {"model": model.state_dict()}

    # Checkpoint Loading
    begin = time.monotonic()
    logger.info(f"Loading chkpt at: {checkpoint_path}")
    dcp.load(state_dict, checkpoint_id=checkpoint_path)
    logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.")

    device_mem_stats = device_memory_monitor.get_peak_stats()
    logger.info(
        f"{device_type.upper()} memory usage for model: "
        f"{device_mem_stats.max_reserved_gib:.2f}GiB"
        f"({device_mem_stats.max_reserved_pct:.2f}%)"
    )

    # Tokenize prompt and repeat batch_size times
    input_ids = (
        (
            torch.tensor(
                tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.long
            )
            .view(1, -1)
            .repeat(batch_size, 1)
        )
    ).to(device_type)

    # Reset so the peak stats reported below cover generation only.
    device_memory_monitor.reset_peak_stats()

    # Run generation
    t0 = time.monotonic()
    responses = generate(
        model,
        input_ids,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        seed=seed,
    )
    t1 = time.monotonic()
    elapsed_sec = t1 - t0

    # Post process
    B, T = responses.size()  # B: batch_size, T: total seq length
    input_n_tokens = input_ids.size(1)
    generated_n_tokens = T - input_n_tokens  # == max_new_tokens

    if local_rank == 0:
        logger.info(f"Generation completed in {elapsed_sec:.2f} seconds.")

        r, b = color.red, color.blue

        output_data = {
            "metadata": {},
            "responses": [],
        }

        for i, tokens in enumerate(responses):
            # Split each sequence back into its prompt and generated parts.
            inp_tok = tokens[:input_n_tokens].tolist()
            out_tok = tokens[input_n_tokens:].tolist()

            input_text = tokenizer.decode(inp_tok)
            output_text = tokenizer.decode(out_tok)

            _data = {
                "response_idx": i,
                "input_text": input_text,
                "output_text": output_text,
            }
            output_data["responses"].append(_data)

            logger.info(f"{r}\n{input_text}{b}{output_text}\n{color.reset}")

        device_mem_stats = device_memory_monitor.get_peak_stats()
        output_data["metadata"] = {
            "generated_n_tokens": generated_n_tokens,
            "input_n_tokens": input_n_tokens,
            "generation_time_sec": elapsed_sec,
            "tokens_per_sec": (B * T) / elapsed_sec,
            "batch_size": B,
            "seed": seed,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()),
            "memory/max_active(GiB)": device_mem_stats.max_active_gib,
            "memory/max_active(%)": device_mem_stats.max_active_pct,
            "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
            "memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
            "memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
            "memory/num_ooms": device_mem_stats.num_ooms,
            "world_size": world_size,
            "torch_version": torch.__version__,
        }

        if out:
            print(json.dumps(output_data, indent=4))


if __name__ == "__main__":
    # Command-line driver: parse the flags and forward them to test_generate.
    parser = argparse.ArgumentParser(description="Test generation")
    parser.add_argument(
        "--config", type=str, required=True, help="TOML config file path (required)"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint path to load (required)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature. Default is 1.0",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=32,
        help="Max number of tokens to generate. Default is 32",
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Number of samples to run in batch"
    )
    parser.add_argument(
        "--top_k", type=int, help="Prune to select from top_k probabilities. Optional"
    )
    parser.add_argument("--seed", type=int, help="Random seed for reproducibility")
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="Use deterministic algorithms wherever possible, may be slower",
    )

    parser.add_argument("--prompt", type=str, default="", help="Input prompt")

    parser.add_argument(
        "--out",
        action="store_true",
        default=False,
        help="If specified, prints the report to stdout. Defaults to no output.",
    )

    args = parser.parse_args()

    test_generate(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        prompt=args.prompt,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        batch_size=args.batch_size,
        top_k=args.top_k,
        seed=args.seed,
        deterministic=args.deterministic,
        out=args.out,
    )

    # Tear down the process group if the multi-rank path initialized one.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()