
Commit 2afbf1c

feat: Make tokenizer add_special_tokens option configurable
In particular, so that it can be disabled for chat/instruct models where an explicit prompt template is used that already includes these tokens (for example, the leading <s> token added by the llama and mixtral tokenizers).

Signed-off-by: Nick Hill <[email protected]>
1 parent f7d3c5f commit 2afbf1c
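
To illustrate the motivation: with a llama-family tokenizer, add_special_tokens=True prepends a BOS token, so a chat/instruct prompt whose template already starts with <s> ends up with it twice. A minimal sketch of that effect (the model id below is illustrative and not part of this commit; it assumes access to a llama-family tokenizer via transformers):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# A chat/instruct template that already begins with the BOS token.
prompt = "<s>[INST] Hello! [/INST]"

with_special = tok(prompt, add_special_tokens=True)["input_ids"]
without_special = tok(prompt, add_special_tokens=False)["input_ids"]

# With add_special_tokens=True the tokenizer prepends its own <s> on top of the
# one already in the template (a doubled BOS); with False the template's <s>
# is the only one.
print(with_special[:3])     # e.g. [1, 1, ...]
print(without_special[:3])  # e.g. [1, ...]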

File tree (8 files changed, +35 −10)

  launcher/src/main.rs
  router/src/main.rs
  router/src/server.rs
  router/src/tokenizer.rs
  server/text_generation_server/models/causal_lm.py
  server/text_generation_server/models/flash_causal_lm.py
  server/text_generation_server/models/model.py
  server/text_generation_server/models/seq2seq_lm.py

launcher/src/main.rs (+9)

@@ -89,6 +89,8 @@ struct Args {
     // Default for default_include_stop_seqs is true for now, for backwards compatibility
     #[clap(default_value = "true", long, env, action = clap::ArgAction::Set)]
     default_include_stop_seqs: bool,
+    #[clap(default_value = "true", long, env, action = clap::ArgAction::Set)]
+    add_special_tokens: bool,
 }
 
 fn main() -> ExitCode {
@@ -237,6 +239,7 @@ fn main() -> ExitCode {
         args.max_new_tokens,
         args.max_batch_size,
         args.batch_safety_margin,
+        args.add_special_tokens,
         args.shard_uds_path,
         args.cuda_process_memory_fraction,
         cuda_alloc_conf,
@@ -307,6 +310,8 @@ fn main() -> ExitCode {
         format!("{}-0", args.shard_uds_path),
         "--tokenizer-path".to_string(),
         tokenizer_path,
+        "--add-special-tokens".to_string(),
+        args.add_special_tokens.to_string(),
     ];
 
     if let Some(path) = args.tls_key_path {
@@ -541,6 +546,7 @@ fn shard_manager(
     max_new_tokens: usize,
     max_batch_size: usize,
     batch_safety_margin: usize,
+    add_special_tokens: bool,
     uds_path: String,
     cuda_process_memory_fraction: f32,
     cuda_alloc_conf: Option<&str>,
@@ -627,6 +633,9 @@ fn shard_manager(
         }
     }
 
+    // Add special tokens when tokenizing (e.g. leading <s> with llama tokenizer)
+    env.push(("ADD_SPECIAL_TOKENS".into(), add_special_tokens.to_string().into()));
+
     // Torch Distributed / DeepSpeed Env vars
     env.push(("RANK".into(), rank.to_string().into()));
     env.push(("LOCAL_RANK".into(), rank.to_string().into()));

router/src/main.rs (+3)

@@ -50,6 +50,8 @@ struct Args {
     output_special_tokens: bool,
     #[clap(long, env)]
     default_include_stop_seqs: bool,
+    #[clap(default_value = "true", long, env, action = clap::ArgAction::Set)]
+    add_special_tokens: bool,
 }
 
 fn main() -> Result<(), std::io::Error> {
@@ -149,6 +151,7 @@ fn main() -> Result<(), std::io::Error> {
         tls_client_ca_cert: args.tls_client_ca_cert_path,
         output_special_tokens: args.output_special_tokens,
         default_include_stop_seqs: args.default_include_stop_seqs,
+        add_special_tokens: args.add_special_tokens,
     })
     .await;
     Ok(())

router/src/server.rs (+4 −1)

@@ -247,6 +247,7 @@ pub struct ServerRunArgs {
     pub tls_client_ca_cert: Option<String>,
     pub output_special_tokens: bool,
     pub default_include_stop_seqs: bool,
+    pub add_special_tokens: bool,
 }
 
 async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
@@ -337,7 +338,9 @@ async fn do_run<B: BatchType>(
         args.max_sequence_length - 1
     };
 
-    let tokenizers = AsyncTokenizer::new(&args.tokenizer, args.tokenization_workers);
+    let tokenizers = AsyncTokenizer::new(
+        &args.tokenizer, args.add_special_tokens, args.tokenization_workers
+    );
 
     // Create state
     let generation_health = Arc::new(AtomicBool::new(false));

router/src/tokenizer.rs (+8 −4)

@@ -23,12 +23,14 @@ impl Debug for AsyncTokenizer {
 
 /// Uses pool of tokenizer threads to provide async tokenization methods
 impl AsyncTokenizer {
-    pub(crate) fn new(tokenizer: &Tokenizer, workers: usize) -> Self {
+    pub(crate) fn new(tokenizer: &Tokenizer, add_special_tokens: bool, workers: usize) -> Self {
         let (sender, receiver) = flume::unbounded();
         for _ in 0..workers {
             let tokenizer = tokenizer.clone();
             let receiver = receiver.clone();
-            tokio::task::spawn_blocking(move || tokenization_worker(tokenizer, receiver));
+            tokio::task::spawn_blocking(
+                move || tokenization_worker(tokenizer, receiver, add_special_tokens)
+            );
         }
         Self { sender }
     }
@@ -50,10 +52,12 @@ impl AsyncTokenizer {
     }
 }
 
-fn tokenization_worker(tokenizer: Tokenizer, receiver: Receiver<TokenizationRequest>) {
+fn tokenization_worker(
    tokenizer: Tokenizer, receiver: Receiver<TokenizationRequest>, add_special_tokens: bool
+) {
     while let Ok((input, with_encoding, sender)) = receiver.recv() {
         let result = tokenizer
-            .encode(&input[..], true)
+            .encode(&input[..], add_special_tokens)
             .map(|encoding| (input, encoding.len(), with_encoding.then_some(encoding)));
         sender.send(result).unwrap_or_default();
     }
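
For reference, the second argument that was previously hard-coded as true in encode(&input[..], true) is the Hugging Face tokenizers library's add_special_tokens flag. A minimal sketch of the same behavior through the library's Python binding (the tokenizer id is illustrative, not the model actually being served):

from tokenizers import Tokenizer

tok = Tokenizer.from_pretrained("bert-base-uncased")

enc_on = tok.encode("hello world", add_special_tokens=True)
enc_off = tok.encode("hello world", add_special_tokens=False)

# With the flag on, wrapper tokens such as [CLS]/[SEP] (or a leading <s> for
# llama-style tokenizers) are included in the encoding and its reported length.
print(len(enc_on.ids), len(enc_off.ids))  # e.g. 4 vs 2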

server/text_generation_server/models/causal_lm.py (+2 −1)

@@ -8,7 +8,7 @@
 from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Any
 
-from text_generation_server.models.model import Model, CUDA_PAD_TO_MULT_OF_8
+from text_generation_server.models.model import Model, ADD_SPECIAL_TOKENS, CUDA_PAD_TO_MULT_OF_8
 from text_generation_server.models.types import Batch, GenerateError
 from text_generation_server.pb import generate_pb2
 from text_generation_server.prompt_cache import PrefixCache
@@ -143,6 +143,7 @@ def from_pb(
             truncation=True,
             max_length=tokenize_length,
             return_token_type_ids=False,
+            add_special_tokens=ADD_SPECIAL_TOKENS,
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"]
 

server/text_generation_server/models/flash_causal_lm.py (+6 −3)

@@ -5,14 +5,13 @@
 import torch
 import torch.distributed
 
-from torch.nn import functional as F
-
 from dataclasses import dataclass
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Any
 
 from text_generation_server.inference_engine import get_inference_engine_class
 from text_generation_server.models import Model
+from text_generation_server.models.model import ADD_SPECIAL_TOKENS
 
 from text_generation_server.models.types import Batch, GenerateError
 from text_generation_server.pb import generate_pb2
@@ -123,7 +122,11 @@ def from_pb(
         # return as lists to avoid unnecessary padding;
         # sequences will be concatenated across the batch
         batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_seqlen, return_token_type_ids=False
+            batch_inputs,
+            truncation=True,
+            max_length=max_seqlen,
+            return_token_type_ids=False,
+            add_special_tokens=ADD_SPECIAL_TOKENS,
         )["input_ids"]
 
         # Process inputs to generate the needed tensors
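
The batched call above deliberately returns plain Python lists (no padding) so sequences can be concatenated across the batch. A minimal sketch of how the new add_special_tokens keyword interacts with that call, using an illustrative tokenizer rather than the model actually being served:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
batch_inputs = ["first prompt", "a somewhat longer second prompt"]

input_ids = tok(
    batch_inputs,
    truncation=True,
    max_length=8,
    return_token_type_ids=False,
    add_special_tokens=False,  # what ADD_SPECIAL_TOKENS=false resolves to
)["input_ids"]

# Ragged lists, one per sequence; with add_special_tokens=False no [CLS]/[SEP]
# (or leading <s>) is added, so none of the max_length budget is spent on them.
print([len(ids) for ids in input_ids])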

server/text_generation_server/models/model.py (+1)

@@ -23,6 +23,7 @@
 
 CUDA_PAD_TO_MULT_OF_8 = os.getenv("CUDA_PAD_TO_MULT_OF_8", "true").lower() != "false"
 PT2_COMPILE = os.getenv("PT2_COMPILE", "false").lower() != "false"
+ADD_SPECIAL_TOKENS = os.getenv("ADD_SPECIAL_TOKENS", "true").lower() != "false"  # defaults to true
 
 if PT2_COMPILE:
     import torch._dynamo
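
Note the parsing convention shared with CUDA_PAD_TO_MULT_OF_8: any value other than the literal string "false" (case-insensitive) counts as true, so only ADD_SPECIAL_TOKENS=false disables special tokens; values like "0" or "no" do not. A small sketch of the resulting behavior:

import os

for value in ["true", "false", "FALSE", "0", "no"]:
    os.environ["ADD_SPECIAL_TOKENS"] = value
    add_special_tokens = os.getenv("ADD_SPECIAL_TOKENS", "true").lower() != "false"
    print(f"ADD_SPECIAL_TOKENS={value!r} -> {add_special_tokens}")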

server/text_generation_server/models/seq2seq_lm.py (+2 −1)

@@ -10,7 +10,7 @@
 
 from transformers.modeling_outputs import BaseModelOutput
 
-from text_generation_server.models.model import Model, CUDA_PAD_TO_MULT_OF_8, PT2_COMPILE
+from text_generation_server.models.model import Model, ADD_SPECIAL_TOKENS, CUDA_PAD_TO_MULT_OF_8, PT2_COMPILE
 from text_generation_server.models.types import Batch, GenerateError
 from text_generation_server.pb import generate_pb2
 from text_generation_server.prompt_cache import PrefixCache
@@ -148,6 +148,7 @@ def from_pb(
             truncation=True,
             max_length=tokenize_length,
             return_token_type_ids=False,
+            add_special_tokens=ADD_SPECIAL_TOKENS,
         ).to(device)
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]
