Commit de4ac30

feat(xpu): enable XPU for Llama
Signed-off-by: dbyoung18 <[email protected]>
1 parent c0a81f9 commit de4ac30

File tree

1 file changed: +17 -4 lines

1 file changed

+17
-4
lines changed

torchao/_models/llama/generate.py

+17 -4
@@ -19,6 +19,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
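Why the synchronize matters: both CUDA and XPU launch kernels asynchronously, so any wall-clock measurement must drain the device queue first, which is exactly what the device_sync hook patched above does for either backend. A minimal sketch of how such a hook is typically used for timing; the timed helper below is illustrative only and not part of this commit, and it assumes device_sync from the hunk above is in scope:

import time

def timed(fn, device: str):
    # Illustrative helper, not from this commit: wrap a callable with
    # device_sync (defined above in generate.py) so the measured interval
    # covers every kernel fn() launches on the device.
    device_sync(device=device)            # drain work queued before the timer starts
    start = time.perf_counter()
    result = fn()
    device_sync(device=device)            # wait for the kernels fn() launched
    return result, time.perf_counter() - start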
@@ -288,7 +290,10 @@ def main(
 
     for i in range(start, num_samples):
         if i==0:
-            torch.cuda.reset_peak_memory_stats()
+            if "cuda" in device:
+                torch.cuda.reset_peak_memory_stats()
+            elif "xpu" in device:
+                torch.xpu.reset_peak_memory_stats()
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
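The same cuda/xpu branching can also be expressed as a single dispatch on the backend module, since torch.cuda and torch.xpu expose matching memory-stats APIs. A hedged sketch, not the commit's code, guarded with hasattr in case a backend lacks the call:

import torch

def reset_peak_memory_stats(device: str) -> None:
    # Sketch only: dispatch on the backend prefix of the device string
    # ("cuda:0" -> torch.cuda, "xpu:0" -> torch.xpu) instead of branching.
    backend = getattr(torch, device.split(":")[0], None)
    if backend is not None and hasattr(backend, "reset_peak_memory_stats"):
        backend.reset_peak_memory_stats()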
@@ -318,8 +323,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y = generate(
                 model,
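The profiler branch above selects activities per backend: CUDA keeps the default activity set (after the cuda-graphs init hook), while XPU lists its activities explicitly via torch.profiler.ProfilerActivity.XPU, which this commit relies on. A sketch of the same pattern factored into a helper; the make_profiler name and enabled flag are illustrative, not from the file:

import contextlib
import torch

def make_profiler(device: str, enabled: bool):
    # Sketch of the branch above as a helper: return a no-op context when
    # profiling is off, otherwise a profiler configured for the backend.
    if not enabled:
        return contextlib.nullcontext()
    if "cuda" in device:
        torch.profiler._utils._init_for_cuda_graphs()
        return torch.profiler.profile()
    if "xpu" in device:
        return torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.XPU,
            ],
        )
    return contextlib.nullcontext()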
@@ -369,7 +381,8 @@ def callback(x):
 
     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
     bandwidth = model_size * tokpersec
-    mem = torch.cuda.max_memory_reserved() /1e9
+    max_memory_reserved = torch.cuda.max_memory_reserved() if "cuda" in device else torch.xpu.max_memory_reserved()
+    mem = max_memory_reserved / 1e9
     print(f"Average tokens/sec: {tokpersec:.2f}")
     print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
     print(f"Peak Memory Usage: {mem:.02f} GB")
