Skip to content

Commit

Permalink
extensions/openai: +Array input (batched) , +Fixes (#3309)
Browse files Browse the repository at this point in the history
  • Loading branch information
matatonic authored Aug 2, 2023
1 parent 40038fd commit 9ae0eab
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 63 deletions.
3 changes: 2 additions & 1 deletion extensions/openai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ print(text)
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, doesn't support array input, variable quality based on the model |
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
| /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint |
| /v1/engines/*/generate | openai engines.generate | Legacy endpoint |
Expand Down Expand Up @@ -204,6 +204,7 @@ Some hacky mappings:
| 1.0 | typical_p | hardcoded to 1.0 |
| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
| messages.name | - | not supported yet |
| suffix | - | not supported yet |
| user | - | not supported yet |
| functions/function_call | - | function calls are not supported yet |

Expand Down
127 changes: 66 additions & 61 deletions extensions/openai/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> to
top_tokens = [ decode(tok) for tok in top_indices[0] ]
top_probs = [ float(x) for x in top_values[0] ]
self.token_alternatives = dict(zip(top_tokens, top_probs))
debug_msg(f"{self.__class__.__name__}(logprobs+1={self.logprobs+1}, token_alternatives={self.token_alternatives})")
debug_msg(repr(self))
return logits

def __repr__(self):
Expand All @@ -63,7 +63,8 @@ def convert_logprobs_to_tiktoken(model, logprobs):
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
# except KeyError:
# # assume native tokens if we can't find the tokenizer
return logprobs
# return logprobs
return logprobs


def marshal_common_params(body):
Expand Down Expand Up @@ -271,16 +272,16 @@ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
req_params['max_new_tokens'] = req_params['truncation_length']

# format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']

# set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count

stopping_strings = req_params.pop('stopping_strings', [])

# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
logprob_proc = req_params.pop('logprob_proc', None)
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

answer = ''
Expand Down Expand Up @@ -347,7 +348,7 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
req_params['max_new_tokens'] = req_params['truncation_length']

# format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']

# set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
Expand Down Expand Up @@ -441,16 +442,9 @@ def completions(body: dict, is_legacy: bool = False):
if not prompt_str in body:
raise InvalidRequestError("Missing required input", param=prompt_str)

prompt = body[prompt_str]
if isinstance(prompt, list):
if prompt and isinstance(prompt[0], int):
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
else:
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
prompt_arg = body[prompt_str]
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
prompt_arg = [prompt_arg]

# common params
req_params = marshal_common_params(body)
Expand All @@ -460,59 +454,75 @@ def completions(body: dict, is_legacy: bool = False):
req_params['max_new_tokens'] = max_tokens
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

token_count = len(encode(prompt)[0])
resp_list_data = []
total_completion_token_count = 0
total_prompt_token_count = 0

if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
for idx, prompt in enumerate(prompt_arg, start=0):
if isinstance(prompt[0], int):
# token lists
if requested_model == shared.model_name:
prompt = decode(prompt)[0]
else:
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]

req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count

# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)

answer = ''
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''

for a in generator:
answer = a
for a in generator:
answer = a

# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]

completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"

respi = {
"index": idx,
"finish_reason": stop_reason,
"text": answer,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}

resp_list_data.extend([respi])

resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name, # TODO: add Lora info?
resp_list: [{
"index": 0,
"finish_reason": stop_reason,
"text": answer,
}],
resp_list: resp_list_data,
"usage": {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
"prompt_tokens": total_prompt_token_count,
"completion_tokens": total_completion_token_count,
"total_tokens": total_prompt_token_count + total_completion_token_count
}
}

if logprob_proc and logprob_proc.token_alternatives:
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
else:
resp[resp_list][0]["logprobs"] = None

return resp


Expand Down Expand Up @@ -550,6 +560,10 @@ def stream_completions(body: dict, is_legacy: bool = False):
req_params['max_new_tokens'] = max_tokens
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

token_count = len(encode(prompt)[0])

Expand All @@ -558,9 +572,6 @@ def stream_completions(body: dict, is_legacy: bool = False):
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)

req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

def text_streaming_chunk(content):
# begin streaming
chunk = {
Expand All @@ -572,22 +583,16 @@ def text_streaming_chunk(content):
"index": 0,
"finish_reason": None,
"text": content,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}],
}
if logprob_proc:
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
else:
chunk[resp_list][0]["logprobs"] = None

return chunk

yield text_streaming_chunk('')

# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
logprob_proc = req_params.pop('logprob_proc', None)
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

answer = ''
Expand Down
2 changes: 1 addition & 1 deletion extensions/openai/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def do_GET(self):
resp = OAImodels.list_models(is_legacy)
else:
model_name = self.path[len('/v1/models/'):]
resp = OAImodels.model_info()
resp = OAImodels.model_info(model_name)

self.return_json(resp)

Expand Down

0 comments on commit 9ae0eab

Please sign in to comment.