Skip to content

Commit b8aa7ee

Browse files
authored
Merge pull request #132 from pytorch-labs/defaultdeviceargs1
Set device to CPU if CUDA not available in some arguments
2 parents 48328fb + 02dfe6f commit b8aa7ee

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

generate.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ def device_sync(device):
2626
torch._inductor.config.triton.unique_kernel_names = True
2727
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
2828

29+
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
2930

3031
# support running without installing as a package
3132
wd = Path(__file__).parent.parent.resolve()
@@ -206,7 +207,7 @@ def generate(
206207
}
207208
return seq, generate_stats
208209

209-
def encode_tokens(tokenizer, string, bos=True, device='cuda'):
210+
def encode_tokens(tokenizer, string, bos=True, device=default_device):
210211
tokens = tokenizer.encode(string)
211212
if bos:
212213
tokens = [tokenizer.bos_id()] + tokens
@@ -259,7 +260,7 @@ def main(
259260
profile: Optional[Path] = None,
260261
draft_checkpoint_path: Optional[Path] = None,
261262
speculate_k: int = 5,
262-
device='cuda',
263+
device=default_device,
263264
) -> None:
264265
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
265266
"""
@@ -414,7 +415,7 @@ def callback(x):
414415
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
415416
parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
416417
parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
417-
parser.add_argument('--device', type=str, default="cuda", help='Device to use')
418+
parser.add_argument('--device', type=str, default=default_device, help='Device to use')
418419

419420
args = parser.parse_args()
420421
main(

quantize.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919

2020
from model import Transformer
2121

22+
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
23+
2224
##### Quantization Primitives ######
2325

2426
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
@@ -539,7 +541,7 @@ def quantize(
539541
percdamp: float = .01,
540542
blocksize: int = 128,
541543
label: str = '',
542-
device: str = 'cuda',
544+
device: str = default_device,
543545
) -> None:
544546
assert checkpoint_path.is_file(), checkpoint_path
545547

@@ -619,7 +621,7 @@ def quantize(
619621
parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
620622
parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
621623
parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
622-
parser.add_argument('--device', type=str, default='cuda', help='device to use')
624+
parser.add_argument('--device', type=str, default=default_device, help='device to use')
623625

624626
args = parser.parse_args()
625627
quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)

0 commit comments

Comments (0)