Skip to content

Commit

Permalink
chore: generate ai image
Browse files Browse the repository at this point in the history
  • Loading branch information
appflowy committed Feb 26, 2025
1 parent 1c38cdd commit d825172
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 18 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions libs/appflowy-ai-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,19 @@ impl AppFlowyAIClient {
.into_data()
}

pub async fn regenerate_image(&self, source_metadata: Value) -> Result<(), AIError> {
let url = format!("{}/chat/image/regenerate", self.url);
let resp = self
.async_http_client(Method::POST, &url)?
.json(&source_metadata)
.timeout(Duration::from_secs(30))
.send()
.await?;
AIResponse::<()>::from_reqwest_response(resp)
.await?
.into_data()
}

pub async fn get_local_ai_package(
&self,
platform: &str,
Expand Down
29 changes: 27 additions & 2 deletions libs/database/src/pg_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,19 +202,40 @@ pub struct AFCollabMemberRow {
#[repr(i16)]
pub enum AFBlobStatus {
Ok = 0,
DallEContentPolicyViolation = 1,
PolicyViolation = 1,
Failed = 2,
Pending = 3,
}

impl From<i16> for AFBlobStatus {
fn from(value: i16) -> Self {
match value {
0 => AFBlobStatus::Ok,
1 => AFBlobStatus::DallEContentPolicyViolation,
1 => AFBlobStatus::PolicyViolation,
2 => AFBlobStatus::Failed,
3 => AFBlobStatus::Pending,
_ => AFBlobStatus::Ok,
}
}
}

#[derive(Serialize, Deserialize, Eq, PartialEq, Debug, Clone)]
#[repr(i16)]
pub enum AFBlobSource {
UserUpload = 0,
AIGen = 1,
}

impl From<i16> for AFBlobSource {
fn from(value: i16) -> Self {
match value {
0 => AFBlobSource::UserUpload,
1 => AFBlobSource::AIGen,
_ => AFBlobSource::UserUpload,
}
}
}

#[derive(Debug, FromRow, Serialize, Deserialize)]
pub struct AFBlobMetadataRow {
pub workspace_id: Uuid,
Expand All @@ -224,6 +245,10 @@ pub struct AFBlobMetadataRow {
pub modified_at: DateTime<Utc>,
#[serde(default)]
pub status: i16,
#[serde(default)]
pub source: i16,
#[serde(default)]
pub source_metadata: serde_json::Value,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
Expand Down
4 changes: 4 additions & 0 deletions migrations/20250226091933_blob_metadata_add_file_source.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Add migration script here
ALTER TABLE af_blob_metadata
ADD COLUMN source SMALLINT NOT NULL DEFAULT 0,
ADD COLUMN source_metadata JSONB DEFAULT '{}'::jsonb;
42 changes: 35 additions & 7 deletions src/api/file_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ use database_entity::file_dto::{
use crate::biz::data_import::LimitedPayload;
use crate::state::AppState;
use anyhow::anyhow;
use appflowy_ai_client::client::AppFlowyAIClient;
use aws_sdk_s3::primitives::ByteStream;
use collab_importer::util::FileId;
use database::pg_row::AFBlobStatus;
use database::pg_row::{AFBlobSource, AFBlobStatus};
use serde::Deserialize;
use shared_entity::dto::file_dto::PutFileResponse;
use shared_entity::dto::workspace_dto::{BlobMetadata, RepeatedBlobMetaData, WorkspaceSpaceUsage};
Expand All @@ -35,7 +36,7 @@ use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio_stream::StreamExt;
use tokio_util::io::StreamReader;
use tracing::{error, event, instrument, trace};
use tracing::{error, event, info, instrument, trace};

pub fn file_storage_scope() -> Scope {
web::scope("/api/file_storage")
Expand Down Expand Up @@ -370,12 +371,39 @@ async fn get_blob_by_object_key(
}

let metadata = result.unwrap();
match AFBlobStatus::from(metadata.status) {
AFBlobStatus::DallEContentPolicyViolation => {
return Ok(HttpResponse::UnprocessableEntity().finish());
let source = AFBlobSource::from(metadata.source);
trace!("blob metadata: {:?}", metadata);
match source {
AFBlobSource::UserUpload => {},
AFBlobSource::AIGen => {
let spawn_regenerate_image =
|client: AppFlowyAIClient, source_metadata: serde_json::Value| {
tokio::spawn(async move {
info!("Regenerate ai image: {:?}", source_metadata);
let _ = client.regenerate_image(source_metadata).await;
});
};
let source_metadata = metadata.source_metadata;
let status = AFBlobStatus::from(metadata.status);
trace!("AI image {}: {:?}", key.object_key(), status);
match status {
AFBlobStatus::PolicyViolation => {
return Ok(HttpResponse::UnprocessableEntity().finish());
},
AFBlobStatus::Pending => {
if metadata.modified_at + chrono::Duration::minutes(1) < chrono::Utc::now() {
spawn_regenerate_image(state.ai_client.clone(), source_metadata);
} else {
trace!("AI image is pending, wait for 1 minute");
}
},
AFBlobStatus::Failed => {
spawn_regenerate_image(state.ai_client.clone(), source_metadata);
},
_ => {},
};
},
AFBlobStatus::Ok => {},
};
}

// Check if the file is modified since the last time
if let Some(modified_since) = req
Expand Down
56 changes: 49 additions & 7 deletions tests/ai_test/chat_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,45 @@ async fn generate_chat_message_answer_test() {
.stream_answer_v2(&workspace_id, &chat_id, question.message_id)
.await
.unwrap();
let answer = collect_answer(answer_stream).await;
let answer = collect_answer(answer_stream, None).await;
assert!(!answer.is_empty());
}

// #[tokio::test]
// async fn stop_streaming_test() {
// if !ai_test_enabled() {
// return;
// }
// let test_client = TestClient::new_user_without_ws_conn().await;
// let workspace_id = test_client.workspace_id().await;
// let chat_id = uuid::Uuid::new_v4().to_string();
// let params = CreateChatParams {
// chat_id: chat_id.clone(),
// name: "Stop streaming test".to_string(),
// rag_ids: vec![],
// };
//
// test_client
// .api_client
// .create_chat(&workspace_id, params)
// .await
// .unwrap();
// let params = CreateChatMessageParams::new_user("when to use js");
// let question = test_client
// .api_client
// .create_question(&workspace_id, &chat_id, params)
// .await
// .unwrap();
// let answer_stream = test_client
// .api_client
// .stream_answer_v2(&workspace_id, &chat_id, question.message_id)
// .await
// .unwrap();
// let answer = collect_answer(answer_stream, Some(1)).await;
// println!("answer:\n{}", answer);
// assert!(!answer.is_empty());
// }

#[tokio::test]
async fn get_format_question_message_test() {
if !ai_test_enabled() {
Expand Down Expand Up @@ -325,7 +360,7 @@ async fn get_format_question_message_test() {
.stream_answer_v3(&workspace_id, query)
.await
.unwrap();
let answer = collect_answer(answer_stream).await;
let answer = collect_answer(answer_stream, None).await;
println!("answer:\n{}", answer);
assert!(!answer.is_empty());
}
Expand Down Expand Up @@ -380,7 +415,7 @@ async fn get_text_with_image_message_test() {
.stream_answer_v3(&workspace_id, query)
.await
.unwrap();
let answer = collect_answer(answer_stream).await;
let answer = collect_answer(answer_stream, None).await;
println!("answer:\n{}", answer);
let image_url = extract_image_url(&answer).unwrap();
let (workspace_id_2, chat_id_2, file_id_2) = test_client
Expand Down Expand Up @@ -502,15 +537,22 @@ async fn get_model_list_test() {
println!("models: {:?}", models);
}

async fn collect_answer(mut stream: QuestionStream) -> String {
async fn collect_answer(mut stream: QuestionStream, stop_when_num_of_char: Option<u8>) -> String {
let mut answer = String::new();
let mut num_of_char = 0;
while let Some(value) = stream.next().await {
match value.unwrap() {
num_of_char += match value.unwrap() {
QuestionStreamValue::Answer { value } => {
answer.push_str(&value);
value.len() as u8
},
QuestionStreamValue::Metadata { .. } => {},
QuestionStreamValue::KeepAlive => {},
QuestionStreamValue::Metadata { .. } => 0,
QuestionStreamValue::KeepAlive => 0,
};
if let Some(stop_when_num_of_char) = stop_when_num_of_char {
if num_of_char >= stop_when_num_of_char {
break;
}
}
}
answer
Expand Down

0 comments on commit d825172

Please sign in to comment.