|
19 | 19 | def device_sync(device):
|
20 | 20 | if "cuda" in device:
|
21 | 21 | torch.cuda.synchronize(device)
|
| 22 | + elif "xpu" in device: |
| 23 | + torch.xpu.synchronize(device) |
22 | 24 | elif ("cpu" in device) or ("mps" in device):
|
23 | 25 | pass
|
24 | 26 | else:
|
@@ -288,7 +290,10 @@ def main(
|
288 | 290 |
|
289 | 291 | for i in range(start, num_samples):
|
290 | 292 | if i==0:
|
291 |
| - torch.cuda.reset_peak_memory_stats() |
| 293 | + if "cuda" in device: |
| 294 | + torch.cuda.reset_peak_memory_stats() |
| 295 | + elif "xpu" in device: |
| 296 | + torch.xpu.reset_peak_memory_stats() |
292 | 297 | device_sync(device=device) # MKG
|
293 | 298 | if i >= 0 and interactive:
|
294 | 299 | prompt = input("What is your prompt? ")
|
@@ -318,8 +323,15 @@ def callback(x):
|
318 | 323 | if (i != num_samples - 1 or not profile):
|
319 | 324 | prof = contextlib.nullcontext()
|
320 | 325 | else:
|
321 |
| - torch.profiler._utils._init_for_cuda_graphs() |
322 |
| - prof = torch.profiler.profile() |
| 326 | + if "cuda" in device: |
| 327 | + torch.profiler._utils._init_for_cuda_graphs() |
| 328 | + prof = torch.profiler.profile() |
| 329 | + elif "xpu" in device: |
| 330 | + prof = torch.profiler.profile( |
| 331 | + activities=[ |
| 332 | + torch.profiler.ProfilerActivity.CPU, |
| 333 | + torch.profiler.ProfilerActivity.XPU], |
| 334 | + ) |
323 | 335 | with prof:
|
324 | 336 | y = generate(
|
325 | 337 | model,
|
@@ -369,7 +381,8 @@ def callback(x):
|
369 | 381 |
|
370 | 382 | tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
|
371 | 383 | bandwidth = model_size * tokpersec
|
372 |
| - mem = torch.cuda.max_memory_reserved() /1e9 |
| 384 | + max_memory_reserved = torch.cuda.max_memory_reserved() if "cuda" in device else torch.xpu.max_memory_reserved() |
| 385 | + mem = max_memory_reserved / 1e9 |
373 | 386 | print(f"Average tokens/sec: {tokpersec:.2f}")
|
374 | 387 | print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
|
375 | 388 | print(f"Peak Memory Usage: {mem:.02f} GB")
|
|
0 commit comments