diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 5fb905dbf9..1076bf1aeb 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -19,6 +19,8 @@ def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -288,7 +290,10 @@ def main(
 
     for i in range(start, num_samples):
         if i==0:
-            torch.cuda.reset_peak_memory_stats()
+            if "cuda" in device:
+                torch.cuda.reset_peak_memory_stats()
+            elif "xpu" in device:
+                torch.xpu.reset_peak_memory_stats()
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
@@ -318,8 +323,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y = generate(
                 model,
@@ -369,7 +381,8 @@ def callback(x):
 
     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
     bandwidth = model_size * tokpersec
-    mem = torch.cuda.max_memory_reserved() /1e9
+    max_memory_reserved = torch.cuda.max_memory_reserved() if "cuda" in device else torch.xpu.max_memory_reserved()
+    mem = max_memory_reserved / 1e9
     print(f"Average tokens/sec: {tokpersec:.2f}")
     print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
     print(f"Peak Memory Usage: {mem:.02f} GB")
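
For context, here is a minimal standalone sketch (not part of the patch) of the backend-dispatch pattern the diff applies in each of its four hunks: device sync, peak-memory reset, profiler activities, and the peak-memory query. The helper names reset_peak_memory and peak_memory_gb are hypothetical, introduced only for illustration; the patch itself inlines these branches at each call site.

import torch

def device_sync(device: str) -> None:
    # Block until all queued kernels on the target backend have finished.
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "xpu" in device:
        torch.xpu.synchronize(device)
    elif ("cpu" in device) or ("mps" in device):
        pass  # no explicit synchronization needed for these backends here

def reset_peak_memory(device: str) -> None:
    # Hypothetical helper: clear peak-memory counters before a timed run.
    if "cuda" in device:
        torch.cuda.reset_peak_memory_stats()
    elif "xpu" in device:
        torch.xpu.reset_peak_memory_stats()

def peak_memory_gb(device: str) -> float:
    # Hypothetical helper: query whichever backend is active.
    if "cuda" in device:
        return torch.cuda.max_memory_reserved() / 1e9
    elif "xpu" in device:
        return torch.xpu.max_memory_reserved() / 1e9
    return 0.0

Note that the XPU profiler branch passes activities explicitly (torch.profiler.ProfilerActivity.CPU and .XPU), since the CUDA-specific _init_for_cuda_graphs() warm-up does not apply on Intel GPUs.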