diff --git a/tiktoken-rs/Cargo.toml b/tiktoken-rs/Cargo.toml index 6f8c99e..feca6cc 100644 --- a/tiktoken-rs/Cargo.toml +++ b/tiktoken-rs/Cargo.toml @@ -5,7 +5,7 @@ description = "Library for encoding and decoding with the tiktoken library in Ru include = ["assets/**/*", "src/**/*", "README.md"] edition = "2021" authors = ["Roger Zurawicki "] -rust-version = "1.61.0" +rust-version = "1.66.0" keywords = ["openai", "ai", "gpt", "bpe"] homepage = "https://github.com/zurawiki/tiktoken-rs" repository = "https://github.com/zurawiki/tiktoken-rs" @@ -18,14 +18,14 @@ debug = 1 [dependencies] anyhow = "1.0.76" -async-openai = { version = "0.14.2", optional = true } +async-openai = { version = "0.31.1", optional = true, features = ["chat-completion-types"] } base64 = "0.22.0" bstr = "1.6.2" dhat = { version = "0.3.2", optional = true } -fancy-regex = "0.13.0" +fancy-regex = "0.16.2" lazy_static = "1.4.0" regex = "1.10.3" -rustc-hash = "1.1.0" +rustc-hash = "2" [features] async-openai = ["dep:async-openai"] diff --git a/tiktoken-rs/README.md b/tiktoken-rs/README.md index bb314ad..e9f267c 100644 --- a/tiktoken-rs/README.md +++ b/tiktoken-rs/README.md @@ -74,31 +74,34 @@ println!("max_tokens: {}", max_tokens); Need to enable the `async-openai` feature in your `Cargo.toml` file. 
```rust -use tiktoken_rs::async_openai::get_chat_completion_max_tokens; -use async_openai::types::{ChatCompletionRequestMessage, Role}; - -let messages = vec![ - ChatCompletionRequestMessage { - content: Some("You are a helpful assistant that only speaks French.".to_string()), - role: Role::System, - name: None, - function_call: None, - }, - ChatCompletionRequestMessage { - content: Some("Hello, how are you?".to_string()), - role: Role::User, - name: None, - function_call: None, - }, - ChatCompletionRequestMessage { - content: Some("Parlez-vous francais?".to_string()), - role: Role::System, - name: None, - function_call: None, - }, -]; -let max_tokens = get_chat_completion_max_tokens("o1-mini", &messages).unwrap(); -println!("max_tokens: {}", max_tokens); +#[cfg(feature = "async-openai")] +{ + use tiktoken_rs::async_openai::get_chat_completion_max_tokens; + use async_openai::types::chat::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, + ChatCompletionRequestUserMessageArgs, + }; + + let messages: Vec<ChatCompletionRequestMessage> = vec![ + ChatCompletionRequestSystemMessageArgs::default() + .content("You are a helpful assistant that only speaks French.") + .build() + .unwrap() + .into(), + ChatCompletionRequestUserMessageArgs::default() + .content("Hello, how are you?") + .build() + .unwrap() + .into(), + ChatCompletionRequestSystemMessageArgs::default() + .content("Parlez-vous francais?") + .build() + .unwrap() + .into(), + ]; + let max_tokens = get_chat_completion_max_tokens("o1-mini", &messages).unwrap(); + println!("max_tokens: {}", max_tokens); +} ``` `tiktoken` supports these encodings used by OpenAI models: diff --git a/tiktoken-rs/src/api.rs b/tiktoken-rs/src/api.rs index 0d837d2..c6cf178 100644 --- a/tiktoken-rs/src/api.rs +++ b/tiktoken-rs/src/api.rs @@ -368,9 +368,91 @@ mod tests { #[cfg(feature = "async-openai")] pub mod async_openai { use anyhow::Result; + use async_openai::types::chat as aoc; - impl From<&async_openai::types::FunctionCall> for
super::FunctionCall { - fn from(f: &async_openai::types::FunctionCall) -> Self { + fn content_to_text(content: &aoc::ChatCompletionRequestDeveloperMessageContent) -> String { + match content { + aoc::ChatCompletionRequestDeveloperMessageContent::Text(text) => text.clone(), + aoc::ChatCompletionRequestDeveloperMessageContent::Array(parts) => parts + .iter() + .map(|part| match part { + aoc::ChatCompletionRequestDeveloperMessageContentPart::Text(part) => { + part.text.clone() + } + }) + .collect::<Vec<_>>() + .join(""), + } + } + + fn content_to_text_system(content: &aoc::ChatCompletionRequestSystemMessageContent) -> String { + match content { + aoc::ChatCompletionRequestSystemMessageContent::Text(text) => text.clone(), + aoc::ChatCompletionRequestSystemMessageContent::Array(parts) => parts + .iter() + .map(|part| match part { + aoc::ChatCompletionRequestSystemMessageContentPart::Text(part) => { + part.text.clone() + } + }) + .collect::<Vec<_>>() + .join(""), + } + } + + fn content_to_text_user(content: &aoc::ChatCompletionRequestUserMessageContent) -> String { + match content { + aoc::ChatCompletionRequestUserMessageContent::Text(text) => text.clone(), + aoc::ChatCompletionRequestUserMessageContent::Array(parts) => parts + .iter() + .filter_map(|part| match part { + aoc::ChatCompletionRequestUserMessageContentPart::Text(part) => { + Some(part.text.clone()) + } + _ => None, + }) + .collect::<Vec<_>>() + .join(""), + } + } + + fn content_to_text_assistant( + content: &aoc::ChatCompletionRequestAssistantMessageContent, + ) -> String { + match content { + aoc::ChatCompletionRequestAssistantMessageContent::Text(text) => text.clone(), + aoc::ChatCompletionRequestAssistantMessageContent::Array(parts) => parts + .iter() + .map(|part| match part { + aoc::ChatCompletionRequestAssistantMessageContentPart::Text(part) => { + part.text.clone() + } + aoc::ChatCompletionRequestAssistantMessageContentPart::Refusal(part) => { + part.refusal.clone() + } + }) + .collect::<Vec<_>>() + .join(""), + } + } + + fn
content_to_text_tool(content: &aoc::ChatCompletionRequestToolMessageContent) -> String { + match content { + aoc::ChatCompletionRequestToolMessageContent::Text(text) => text.clone(), + aoc::ChatCompletionRequestToolMessageContent::Array(parts) => parts + .iter() + .map(|part| match part { + aoc::ChatCompletionRequestToolMessageContentPart::Text(part) => { + part.text.clone() + } + }) + .collect::<Vec<_>>() + .join(""), + } + } + + impl From<&aoc::FunctionCall> for super::FunctionCall { + fn from(f: &aoc::FunctionCall) -> Self { Self { name: f.name.clone(), arguments: f.arguments.clone(), @@ -378,15 +460,54 @@ pub mod async_openai { } } - impl From<&async_openai::types::ChatCompletionRequestMessage> - for super::ChatCompletionRequestMessage - { - fn from(m: &async_openai::types::ChatCompletionRequestMessage) -> Self { - Self { - role: m.role.to_string(), - name: m.name.clone(), - content: m.content.clone(), - function_call: m.function_call.as_ref().map(|f| f.into()), + impl From<&aoc::ChatCompletionRequestMessage> for super::ChatCompletionRequestMessage { + fn from(m: &aoc::ChatCompletionRequestMessage) -> Self { + match m { + aoc::ChatCompletionRequestMessage::Developer(m) => Self { + role: "developer".to_string(), + name: m.name.clone(), + content: Some(content_to_text(&m.content)), + function_call: None, + }, + aoc::ChatCompletionRequestMessage::System(m) => Self { + role: "system".to_string(), + name: m.name.clone(), + content: Some(content_to_text_system(&m.content)), + function_call: None, + }, + aoc::ChatCompletionRequestMessage::User(m) => Self { + role: "user".to_string(), + name: m.name.clone(), + content: Some(content_to_text_user(&m.content)), + function_call: None, + }, + aoc::ChatCompletionRequestMessage::Assistant(m) => { + let content = m + .content + .as_ref() + .map(content_to_text_assistant) + .or_else(|| m.refusal.clone()); + #[allow(deprecated)] + let function_call = m.function_call.as_ref().map(|f| f.into()); + Self { + role: "assistant".to_string(),
+ name: m.name.clone(), + content, + function_call, + } + } + aoc::ChatCompletionRequestMessage::Tool(m) => Self { + role: "tool".to_string(), + name: None, + content: Some(content_to_text_tool(&m.content)), + function_call: None, + }, + aoc::ChatCompletionRequestMessage::Function(m) => Self { + role: "function".to_string(), + name: Some(m.name.clone()), + content: m.content.clone(), + function_call: None, + }, } } } @@ -396,14 +517,14 @@ pub mod async_openai { /// # Arguments /// /// * `model` - A string slice representing the name of the model. - /// * `messages` - A slice of `async_openai::types::ChatCompletionRequestMessage` instances. + /// * `messages` - A slice of `async_openai::types::chat::ChatCompletionRequestMessage` instances. /// /// # Returns /// /// * A `Result` containing the total number of tokens (`usize`) or an error if the calculation fails. pub fn num_tokens_from_messages( model: &str, - messages: &[async_openai::types::ChatCompletionRequestMessage], + messages: &[aoc::ChatCompletionRequestMessage], ) -> Result<usize> { let messages = messages.iter().map(|m| m.into()).collect::<Vec<_>>(); super::num_tokens_from_messages(model, &messages) @@ -414,14 +535,14 @@ pub mod async_openai { /// # Arguments /// /// * `model` - A string slice representing the name of the model. - /// * `messages` - A slice of `async_openai::types::ChatCompletionRequestMessage` instances. + /// * `messages` - A slice of `async_openai::types::chat::ChatCompletionRequestMessage` instances. /// /// # Returns /// /// * A `Result` containing the maximum number of tokens (`usize`) allowed for chat completions or an error if the retrieval fails.
pub fn get_chat_completion_max_tokens( model: &str, - messages: &[async_openai::types::ChatCompletionRequestMessage], + messages: &[aoc::ChatCompletionRequestMessage], ) -> Result<usize> { let messages = messages.iter().map(|m| m.into()).collect::<Vec<_>>(); super::get_chat_completion_max_tokens(model, &messages) } @@ -434,12 +555,13 @@ pub mod async_openai { #[test] fn test_num_tokens_from_messages() { let model = "gpt-3.5-turbo-0301"; - let messages = &[async_openai::types::ChatCompletionRequestMessage { - role: async_openai::types::Role::System, - name: None, - content: Some("You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string()), - function_call: None, - }]; + let messages = &[aoc::ChatCompletionRequestSystemMessageArgs::default() + .content( + "You are a helpful, pattern-following assistant that translates corporate jargon into plain English.", + ) + .build() + .unwrap() + .into()]; let num_tokens = num_tokens_from_messages(model, messages).unwrap(); assert!(num_tokens > 0); } @@ -447,12 +569,11 @@ pub mod async_openai { #[test] fn test_get_chat_completion_max_tokens() { let model = "gpt-3.5-turbo"; - let messages = &[async_openai::types::ChatCompletionRequestMessage { - content: Some("You are a helpful assistant that only speaks French.".to_string()), - role: async_openai::types::Role::System, - name: None, - function_call: None, - }]; + let messages = &[aoc::ChatCompletionRequestSystemMessageArgs::default() + .content("You are a helpful assistant that only speaks French.") + .build() + .unwrap() + .into()]; let max_tokens = get_chat_completion_max_tokens(model, messages).unwrap(); assert!(max_tokens > 0); }