Missing <|begin_of_text|> Token in Llama3Tokenizer #2300

Open
seungjun-green opened this issue Jan 27, 2025 · 3 comments

Comments

@seungjun-green

🐛 Describe the bug

When initializing the Llama3Tokenizer with a special_tokens_map.json file downloaded from Hugging Face, a ValueError is raised because the <|begin_of_text|> token is missing. The Llama3Tokenizer implementation appears to require this token by default, but it is not included in the special_tokens_map.json file provided by Hugging Face.

Steps to Reproduce

  1. Install the required packages:

    !pip install torchtune
    !pip install torchao
  2. Download the model and tokenizer from Hugging Face:

    !tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf-token <API_KEY>
  3. Attempt to initialize the Llama3Tokenizer:

    from torchtune.models.llama3 import Llama3Tokenizer
    
    tokenizer = Llama3Tokenizer(
        path="/tmp/Llama-3.2-1B-Instruct/original/tokenizer.model",
        special_tokens="/tmp/Llama-3.2-1B-Instruct/special_tokens_map.json"
    )

Error Message

The following error occurs during initialization:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-21-b65979cf00ad> in <cell line: 0>()
----> 1 tokenizer = Llama3Tokenizer(
      2     path="/tmp/Llama-3.2-1B-Instruct/original/tokenizer.model",
      3     special_tokens="/tmp/Llama-3.2-1B-Instruct/special_tokens_map.json"
      4 )

1 frames
/usr/local/lib/python3.11/dist-packages/torchtune/models/llama3/_tokenizer.py in _validate_special_tokens(self)
    137         ]:
    138             if token not in self.special_tokens:
--> 139                 raise ValueError(f"{token} missing from special_tokens")
    140 
    141     def _remove_special_tokens(self, text: str) -> str:

ValueError: <|begin_of_text|> missing from special_tokens
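For illustration, the membership check in the traceback can be approximated in plain Python. This is a minimal sketch, not torchtune's code: `REQUIRED_TOKENS` and `validate_special_tokens` are hypothetical names, the required-token list is an assumed subset (only `<|begin_of_text|>` is confirmed by the error), and the IDs are placeholders. Assuming `special_tokens` is meant to be a token-to-ID mapping, the sketch also shows that a file path string passed in its place would fail the same test, because `in` on a string is a substring check.

```python
# Hypothetical stand-in for the check shown in the traceback; not torchtune's code.
REQUIRED_TOKENS = ["<|begin_of_text|>", "<|end_of_text|>"]  # assumed subset


def validate_special_tokens(special_tokens):
    """Raise ValueError if any required token is absent from `special_tokens`."""
    for token in REQUIRED_TOKENS:
        if token not in special_tokens:
            raise ValueError(f"{token} missing from special_tokens")


# A mapping that contains the required tokens passes (IDs are placeholders).
validate_special_tokens({"<|begin_of_text|>": 0, "<|end_of_text|>": 1})

# A path string fails: `in` on a str is a substring test, and the path
# does not contain the token text.
path = "/tmp/Llama-3.2-1B-Instruct/special_tokens_map.json"
try:
    validate_special_tokens(path)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```

If this reading is right, the error would appear regardless of what the JSON file contains, since the file is never opened in this sketch.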

Versions

Collecting environment information...
PyTorch version: 2.5.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.31.4
Libc version: glibc-2.35

Python version: 3.11.11 (main, Dec 4 2024, 08:55:07) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.85+-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 2
On-line CPU(s) list: 0,1
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family: 6
Model: 79
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
Stepping: 0
BogoMIPS: 4399.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32 KiB (1 instance)
L1i cache: 32 KiB (1 instance)
L2 cache: 256 KiB (1 instance)
L3 cache: 55 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0,1
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Not affected; BHI: Vulnerable (Syscall hardening enabled)
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] nvtx==0.2.10
[pip3] optree==0.14.0
[pip3] pynvjitlink-cu12==0.4.0
[pip3] torch==2.5.1+cu121
[pip3] torchao==0.8.0
[pip3] torchaudio==2.5.1+cu121
[pip3] torchsummary==1.5.1
[pip3] torchtune==0.5.0
[pip3] torchvision==0.20.1+cu121
[pip3] triton==3.1.0
[conda] Could not collect

@malfet

malfet commented Jan 27, 2025

Moving to torchtune (as it does not sound like a PyTorch issue)

@malfet malfet transferred this issue from pytorch/pytorch Jan 27, 2025
@RdoubleA
Contributor

Hi @seungjun-green, are you trying to use a custom special_tokens_map or the default one? If the default one, you do not need to specify the special tokens path, and it will find all the correct special tokens.

Either way, this is indeed a bug. The <|begin_of_text|> token is defined in a different config file in the HF repo, so we should either 1) assume the defaults when a special tokens map is passed in, or 2) require that all tokenizer-related JSON files be passed in.
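As a rough sketch of how a complete token-to-ID mapping might be assembled from another file, the snippet below reads a JSON file with a top-level `added_tokens` list. The helper name `load_added_tokens` is hypothetical, the file layout (entries with `id` and `content` keys, as in a Hugging Face `tokenizer.json`) is an assumption about the repo, and the IDs in the demo are illustrative only.

```python
import json
import tempfile
from pathlib import Path


def load_added_tokens(tokenizer_json_path):
    """Return {token_text: token_id} from a JSON file with an `added_tokens` list.

    Assumes entries look like {"id": 128000, "content": "<|begin_of_text|>"}.
    """
    data = json.loads(Path(tokenizer_json_path).read_text(encoding="utf-8"))
    return {entry["content"]: entry["id"] for entry in data.get("added_tokens", [])}


# Demo on a tiny stand-in file; the IDs here are illustrative.
sample = {
    "added_tokens": [
        {"id": 128000, "content": "<|begin_of_text|>"},
        {"id": 128001, "content": "<|end_of_text|>"},
    ]
}
with tempfile.TemporaryDirectory() as tmp:
    sample_path = Path(tmp) / "tokenizer.json"
    sample_path.write_text(json.dumps(sample), encoding="utf-8")
    special_tokens = load_added_tokens(sample_path)

print(special_tokens)
```

Whether the resulting dictionary can be passed directly as `special_tokens` depends on the constructor accepting a mapping, which should be checked against the installed torchtune version.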

@seungjun-green
Author

seungjun-green commented Jan 28, 2025

tokenizer.special_tokens does include special tokens such as <|begin_of_text|>, <|end_of_text|>, and <|finetune_right_pad_id|>. But if I try to encode a special token such as <|finetune_right_pad_id|>, the tokenizer does not recognize it as a special token and encodes it as if it were an ordinary string. Since the Llama3 tokenizer does not have an 'add_special_tokens' method, I tried adding the token through the 'special_tokens' argument, but ran into this bug.
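One way to make the behavior described above checkable is a small helper that compares the encoder's output against the ID registered for the token. This is a sketch under stated assumptions: `encodes_as_single_special_token` is a hypothetical helper, the two toy encoders stand in for a real tokenizer, and the ID 128004 is illustrative.

```python
def encodes_as_single_special_token(encode, special_tokens, text):
    """Return True if `encode(text)` yields exactly the one ID mapped to `text`."""
    return encode(text) == [special_tokens[text]]


# Illustrative mapping; the ID is a placeholder for this demo.
special_tokens = {"<|finetune_right_pad_id|>": 128004}


def naive_encode(text):
    """Toy encoder that ignores special tokens and emits one ID per character."""
    return [ord(ch) for ch in text]


def aware_encode(text):
    """Toy encoder that maps a registered special token to its single ID."""
    if text in special_tokens:
        return [special_tokens[text]]
    return naive_encode(text)


token = "<|finetune_right_pad_id|>"
naive_result = encodes_as_single_special_token(naive_encode, special_tokens, token)
aware_result = encodes_as_single_special_token(aware_encode, special_tokens, token)
print(naive_result, aware_result)
```

With a real tokenizer, `encode` would need to be wrapped so that it takes only the text and does not prepend or append BOS/EOS IDs; otherwise the comparison fails for an unrelated reason.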
