Skip to content

Commit

Permalink
feat: add Azure OpenAI Content Filter Support (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
rasharab authored Jun 17, 2024
1 parent 85a0f47 commit 61f5759
Show file tree
Hide file tree
Showing 13 changed files with 224 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ internal class ChatMessageAssembler {
private val chatContent = StringBuilder()
private var chatRole: ChatRole? = null
private val toolCallsAssemblers = mutableMapOf<Int, ToolCallAssembler>()
private var chatContentFilterOffsets = mutableListOf<ContentFilterOffsets>()
private var chatContentFilterResults = mutableListOf<ContentFilterResults>()

/**
* Merges a chat chunk into the chat message being assembled.
*/
fun merge(chunk: ChatChunk): ChatMessageAssembler {
chunk.delta.run {
chunk.delta?.run {
role?.let { chatRole = it }
content?.let { chatContent.append(it) }
functionCall?.let { call ->
Expand All @@ -30,6 +32,12 @@ internal class ChatMessageAssembler {
assembler.merge(toolCall)
}
}
chunk.contentFilterOffsets?.also {
chatContentFilterOffsets.add(it)
}
chunk.contentFilterResults?.also {
chatContentFilterResults.add(it)
}
return this
}

Expand All @@ -39,6 +47,8 @@ internal class ChatMessageAssembler {
fun build(): ChatMessage = chatMessage {
this.role = chatRole
this.content = chatContent.toString()
this.contentFilterOffsets = chatContentFilterOffsets
this.contentFilterResults = chatContentFilterResults
if (chatFuncName.isNotEmpty() || chatFuncArgs.isNotEmpty()) {
this.functionCall = FunctionCall(chatFuncName.toString(), chatFuncArgs.toString())
this.name = chatFuncName.toString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import com.aallam.openai.api.chat.ChatChunk
import com.aallam.openai.api.chat.ChatDelta
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.chat.ContentFilterOffsets
import com.aallam.openai.api.chat.ContentFilterResult
import com.aallam.openai.api.chat.ContentFilterResults
import com.aallam.openai.api.core.FinishReason
import com.aallam.openai.client.extension.mergeToChatMessage
import kotlin.test.Test
Expand All @@ -20,6 +23,8 @@ class TestChatChunk {
role = ChatRole(role = "assistant"),
content = ""
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -28,6 +33,8 @@ class TestChatChunk {
role = null,
content = "The"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -36,6 +43,8 @@ class TestChatChunk {
role = null,
content = " World"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -44,6 +53,8 @@ class TestChatChunk {
role = null,
content = " Series"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -52,6 +63,8 @@ class TestChatChunk {
role = null,
content = " in"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -60,6 +73,8 @@ class TestChatChunk {
role = null,
content = " "
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -68,6 +83,8 @@ class TestChatChunk {
role = null,
content = "202"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -76,6 +93,8 @@ class TestChatChunk {
role = null,
content = "0"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -84,6 +103,8 @@ class TestChatChunk {
role = null,
content = " is"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -92,6 +113,8 @@ class TestChatChunk {
role = null,
content = " being held"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -100,6 +123,8 @@ class TestChatChunk {
role = null,
content = " in"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -108,6 +133,8 @@ class TestChatChunk {
role = null,
content = " Texas"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -116,6 +143,8 @@ class TestChatChunk {
role = null,
content = "."
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -124,6 +153,24 @@ class TestChatChunk {
role = null,
content = null
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = FinishReason(value = "stop")
),
ChatChunk(
index = 0,
delta = null,
contentFilterOffsets = ContentFilterOffsets(
checkOffset = 1,
startOffset = 1,
endOffset = 1,
),
contentFilterResults = ContentFilterResults(
hate = ContentFilterResult(
filtered = false,
severity = "high",
)
),
finishReason = FinishReason(value = "stop")
)
)
Expand All @@ -132,6 +179,21 @@ class TestChatChunk {
role = ChatRole.Assistant,
content = "The World Series in 2020 is being held in Texas.",
name = null,
contentFilterResults = listOf(
ContentFilterResults(
hate = ContentFilterResult(
filtered = false,
severity = "high",
)
)
),
contentFilterOffsets = listOf(
ContentFilterOffsets(
checkOffset = 1,
startOffset = 1,
endOffset = 1,
)
),
)
assertEquals(chatMessage, message)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.aallam.openai.client

import com.aallam.openai.api.chat.ChatCompletionChunk
import com.aallam.openai.api.file.FileSource
import com.aallam.openai.client.internal.JsonLenient
import com.aallam.openai.client.internal.TestFileSystem
import com.aallam.openai.client.internal.testFilePath
import kotlin.test.Test
import okio.buffer

class TestChatCompletionChunk {
@Test
fun testContentFilterDeserialization() {
val json = FileSource(path = testFilePath("json/azureContentFilterChunk.json"), fileSystem = TestFileSystem)
val actualJson = json.source.buffer().readByteArray().decodeToString()
JsonLenient.decodeFromString<ChatCompletionChunk>(actualJson)
}

@Test
fun testDeserialization() {
val json = FileSource(path = testFilePath("json/chatChunk.json"), fileSystem = TestFileSystem)
val actualJson = json.source.buffer().readByteArray().decodeToString()
JsonLenient.decodeFromString<ChatCompletionChunk>(actualJson)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"choices": [
{
"content_filter_offsets": {
"check_offset": 33188,
"start_offset": 33188,
"end_offset": 33557
},
"content_filter_results": {
"hate": {
"filtered": false,
"severity": "safe"
},
"self_harm": {
"filtered": false,
"severity": "safe"
},
"sexual": {
"filtered": false,
"severity": "safe"
},
"violence": {
"filtered": false,
"severity": "safe"
}
},
"finish_reason": null,
"index": 0
}
],
"created": 0,
"id": "",
"model": "",
"object": ""
}
16 changes: 16 additions & 0 deletions openai-client/src/commonTest/resources/json/chatChunk.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"choices": [
{
"delta": {
"content": " engineering"
},
"finish_reason": null,
"index": 0
}
],
"created": 1716855566,
"id": "chatcmpl-9TeqkT3BJs5zXQq12b204deXcY5nj",
"model": "gpt-4o-2024-05-13",
"object": "chat.completion.chunk",
"system_fingerprint": "fp_5f4bad809a"
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.aallam.openai.api.chat;

import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.core.FinishReason
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
Expand All @@ -19,7 +18,17 @@ public data class ChatChunk(
/**
* The generated chat message.
*/
@SerialName("delta") public val delta: ChatDelta,
@SerialName("delta") public val delta: ChatDelta? = null,

/**
* Azure content filter offsets
*/
@SerialName("content_filter_offsets") public val contentFilterOffsets: ContentFilterOffsets? = null,

/**
* Azure content filter results
*/
@SerialName("content_filter_results") public val contentFilterResults: ContentFilterResults? = null,

/**
* The reason why OpenAI stopped generating.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ public data class ChatMessage(
* Tool call ID.
*/
@SerialName("tool_call_id") public val toolCallId: ToolId? = null,

/**
* Azure Content Filter Results
*/
@SerialName("content_filter_results") public val contentFilterResults: List<ContentFilterResults>? = null,

/**
* Azure Content Filter Offsets
*/
@SerialName("content_filter_offsets") public val contentFilterOffsets: List<ContentFilterOffsets>? = null,
) {

public constructor(
Expand All @@ -54,13 +64,17 @@ public data class ChatMessage(
functionCall: FunctionCall? = null,
toolCalls: List<ToolCall>? = null,
toolCallId: ToolId? = null,
contentFilterResults: List<ContentFilterResults>? = null,
contentFilterOffsets: List<ContentFilterOffsets>? = null,
) : this(
role = role,
messageContent = content?.let { TextContent(it) },
name = name,
functionCall = functionCall,
toolCalls = toolCalls,
toolCallId = toolCallId,
contentFilterOffsets = contentFilterOffsets,
contentFilterResults = contentFilterResults,
)

public constructor(
Expand All @@ -70,13 +84,17 @@ public data class ChatMessage(
functionCall: FunctionCall? = null,
toolCalls: List<ToolCall>? = null,
toolCallId: ToolId? = null,
contentFilterResults: List<ContentFilterResults>? = null,
contentFilterOffsets: List<ContentFilterOffsets>? = null,
) : this(
role = role,
messageContent = content?.let { ListContent(it) },
name = name,
functionCall = functionCall,
toolCalls = toolCalls,
toolCallId = toolCallId,
contentFilterOffsets = contentFilterOffsets,
contentFilterResults = contentFilterResults,
)

val content: String?
Expand Down Expand Up @@ -282,6 +300,16 @@ public class ChatMessageBuilder {
*/
public var toolCalls: List<ToolCall>? = null

/**
* Azure content filter offsets
*/
public var contentFilterOffsets: List<ContentFilterOffsets>? = null

/**
* Azure content filter results
*/
public var contentFilterResults: List<ContentFilterResults>? = null

/**
* Tool call ID.
*/
Expand Down Expand Up @@ -313,6 +341,8 @@ public class ChatMessageBuilder {
functionCall = functionCall,
toolCalls = toolCalls,
toolCallId = toolCallId,
contentFilterOffsets = contentFilterOffsets,
contentFilterResults = contentFilterResults,
)
}
}
Expand Down
Loading

0 comments on commit 61f5759

Please sign in to comment.