@@ -26,6 +26,7 @@ def device_sync(device):
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
 
+default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -206,7 +207,7 @@ def generate(
     }
     return seq, generate_stats
 
-def encode_tokens(tokenizer, string, bos=True, device='cuda'):
+def encode_tokens(tokenizer, string, bos=True, device=default_device):
     tokens = tokenizer.encode(string)
     if bos:
         tokens = [tokenizer.bos_id()] + tokens
@@ -259,7 +260,7 @@ def main(
     profile: Optional[Path] = None,
     draft_checkpoint_path: Optional[Path] = None,
     speculate_k: int = 5,
-    device='cuda',
+    device=default_device,
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """
@@ -414,7 +415,7 @@ def callback(x):
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
     parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
     parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
-    parser.add_argument('--device', type=str, default="cuda", help='Device to use')
+    parser.add_argument('--device', type=str, default=default_device, help='Device to use')
 
     args = parser.parse_args()
     main(
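For reference, a minimal standalone sketch of the pattern this diff introduces (the encode_example helper is hypothetical; the real call sites are encode_tokens and main above): compute the device default once at import time, then reuse it both as a Python default argument and as the argparse default, so machines without a GPU fall back to CPU automatically.

# Illustrative sketch, not part of the commit.
import argparse
import torch

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

def encode_example(tokens, device=default_device):
    # The tensor lands on the GPU when one is present, otherwise on the
    # CPU, without the caller having to pass a device explicitly.
    return torch.tensor(tokens, dtype=torch.int, device=device)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default=default_device, help='Device to use')
    args = parser.parse_args()
    print(encode_example([1, 2, 3], device=args.device))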