Commit

refactoring to use workflows

jbilcke-hf committed Aug 8, 2024
1 parent 8722988 · commit a8ff2d1
Showing 70 changed files with 1,768 additions and 2,608 deletions.
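The diffs below all follow the same pattern: flat settings fields such as imageGenerationModel or assistantModel (bare model-id strings) are replaced by per-task *Workflow objects, whose data field carries the model or space id and whose provider field names the backend. A minimal sketch of the shape the call sites assume (illustrative only; the real ClapWorkflow type ships with @aitube/clap 0.2.0 and may carry more fields than shown here):

// Sketch of the workflow object read by the call sites in this commit.
// Both field names appear in the diffs below; everything else is assumed.
interface WorkflowSketch {
  provider: string // a ClapWorkflowProvider enum value
  data: string // provider-specific model or space id, e.g. 'fal-ai/flux-pro'
}

// Old: settings.assistantProvider + settings.assistantModel
// New: const { provider, data: modelName } = settings.assistantWorkflow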
837 changes: 113 additions & 724 deletions package-lock.json

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions package.json
@@ -36,15 +36,15 @@
     "electron:make": "npm run build && electron-forge make"
   },
   "dependencies": {
-    "@aitube/broadway": "0.1.3-1",
-    "@aitube/clap": "0.1.2",
-    "@aitube/clapper-services": "0.1.6",
-    "@aitube/engine": "0.1.2",
-    "@aitube/timeline": "0.1.3",
+    "@aitube/broadway": "0.2.0",
+    "@aitube/clap": "0.2.0",
+    "@aitube/clapper-services": "0.2.0-2",
+    "@aitube/engine": "0.2.0",
+    "@aitube/timeline": "0.2.0",
     "@fal-ai/serverless-client": "^0.13.0",
     "@ffmpeg/ffmpeg": "^0.12.10",
     "@ffmpeg/util": "^0.12.1",
-    "@gradio/client": "^1.4.0",
+    "@gradio/client": "^1.5.0",
     "@huggingface/hub": "^0.15.1",
     "@huggingface/inference": "^2.8.0",
     "@langchain/anthropic": "^0.2.12",
31 changes: 16 additions & 15 deletions src/app/api/assistant/askAnyAssistant.ts
@@ -1,6 +1,6 @@
 'use server'

-import { ClapSegmentCategory } from '@aitube/clap'
+import { ClapWorkflowProvider } from '@aitube/clap'
 import { RunnableLike } from '@langchain/core/runnables'
 import { ChatPromptValueInterface } from '@langchain/core/prompt_values'
 import {
@@ -28,7 +28,6 @@ import {
   AssistantRequest,
   AssistantSceneSegment,
   AssistantStoryBlock,
-  ComputeProvider,
   ChatEventVisibility,
 } from '@aitube/clapper-services'

@@ -65,7 +64,9 @@ export async function askAnyAssistant({

   history = [],
 }: AssistantRequest): Promise<AssistantMessage> {
-  const provider = settings.assistantProvider
+  const workflow = settings.assistantWorkflow
+  const provider = workflow.provider
+  const modelName = workflow.data

   if (!provider) {
     throw new Error(`Missing assistant provider`)
@@ -74,40 +75,40 @@
   let coerceable:
     | undefined
     | RunnableLike<ChatPromptValueInterface, AIMessageChunk> =
-    provider === ComputeProvider.GROQ
+    provider === ClapWorkflowProvider.GROQ
       ? new ChatGroq({
           apiKey: settings.groqApiKey,
-          modelName: settings.assistantModel,
+          modelName,
           // temperature: 0.7,
         })
-      : provider === ComputeProvider.OPENAI
+      : provider === ClapWorkflowProvider.OPENAI
        ? new ChatOpenAI({
            openAIApiKey: settings.openaiApiKey,
-            modelName: settings.assistantModel,
+            modelName,
            // temperature: 0.7,
          })
-        : provider === ComputeProvider.ANTHROPIC
+        : provider === ClapWorkflowProvider.ANTHROPIC
          ? new ChatAnthropic({
              anthropicApiKey: settings.anthropicApiKey,
-              modelName: settings.assistantModel,
+              modelName,
              // temperature: 0.7,
            })
-          : provider === ComputeProvider.COHERE
+          : provider === ClapWorkflowProvider.COHERE
            ? new ChatCohere({
                apiKey: settings.cohereApiKey,
-                model: settings.assistantModel,
+                model: modelName,
                // temperature: 0.7,
              })
-            : provider === ComputeProvider.MISTRALAI
+            : provider === ClapWorkflowProvider.MISTRALAI
              ? new ChatMistralAI({
                  apiKey: settings.mistralAiApiKey,
-                  modelName: settings.assistantModel,
+                  modelName,
                  // temperature: 0.7,
                })
-              : provider === ComputeProvider.GOOGLE
+              : provider === ClapWorkflowProvider.GOOGLE
                ? new ChatVertexAI({
                    apiKey: settings.googleApiKey,
-                    modelName: settings.assistantModel,
+                    modelName,
                    // temperature: 0.7,
                  })
                : undefined
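Under the new scheme, callers configure the assistant by picking a workflow instead of a provider/model pair. A hedged usage sketch (field names as in the diff above; the model id and settings literal are illustrative, not the real ClapperSettings type):

import { ClapWorkflowProvider } from '@aitube/clap'

// Illustrative settings fragment; only the fields read by askAnyAssistant are shown.
const settings = {
  assistantWorkflow: {
    provider: ClapWorkflowProvider.OPENAI, // selects the ChatOpenAI branch
    data: 'gpt-4o', // consumed as `modelName` (illustrative value)
  },
  openaiApiKey: process.env.OPENAI_API_KEY || '',
}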
36 changes: 23 additions & 13 deletions src/app/api/resolve/providers/falai/index.ts
@@ -22,10 +22,14 @@ export async function resolveSegment(

   const segment: TimelineSegment = request.segment

+  let model = request.settings.imageGenerationWorkflow.data || ''
+
   // for doc see:
   // https://fal.ai/models/fal-ai/fast-sdxl/api

   if (request.segment.category === ClapSegmentCategory.STORYBOARD) {
+    model = request.settings.imageGenerationWorkflow.data || ''
+
     if (!request.prompts.image.positive) {
       console.error(
         `resolveSegment: cannot resolve a storyboard with an empty prompt`
@@ -42,18 +46,18 @@

     let result: FalAiImageResponse | undefined = undefined

-    if (request.settings.imageGenerationModel === 'fal-ai/pulid') {
+    if (model === 'fal-ai/pulid') {
       if (!request.prompts.image.identity) {
         // throw new Error(`you selected model ${request.settings.falAiModelForImage}, but no character was found, so skipping`)
         // console.log(`warning: user selected model ${request.settings.falAiModelForImage}, but no character was found. Falling back to fal-ai/flux-pro`)

         // dirty fix to fallback to a non-face model
-        request.settings.imageGenerationModel = 'fal-ai/flux-pro'
+        model = 'fal-ai/flux-pro'
       }
     }

-    if (request.settings.imageGenerationModel === 'fal-ai/pulid') {
-      result = (await fal.run(request.settings.imageGenerationModel, {
+    if (model === 'fal-ai/pulid') {
+      result = (await fal.run(model, {
         input: {
           reference_images: [
             {
@@ -68,16 +72,13 @@
         },
       })) as FalAiImageResponse
     } else {
-      result = (await fal.run(request.settings.imageGenerationModel, {
+      result = (await fal.run(model, {
         input: {
           prompt: request.prompts.image.positive,
           image_size: imageSize,
           sync_mode: true,
           num_inference_steps:
-            request.settings.imageGenerationModel ===
-            'fal-ai/stable-diffusion-v3-medium'
-              ? 40
-              : 25,
+            model === 'fal-ai/stable-diffusion-v3-medium' ? 40 : 25,
           num_images: 1,
           enable_safety_checker:
             request.settings.censorNotForAllAudiencesContent,
@@ -95,8 +96,10 @@

     segment.assetUrl = result.images[0]?.url || ''
   } else if (request.segment.category === ClapSegmentCategory.VIDEO) {
+    model = request.settings.videoGenerationWorkflow.data || ''
+
     // console.log(`request.settings.falAiModelForVideo = `, request.settings.falAiModelForVideo)
-    if (request.settings.videoGenerationModel !== 'fal-ai/stable-video') {
+    if (model !== 'fal-ai/stable-video') {
       throw new Error(
         `only "fal-ai/stable-video" is supported by Clapper for the moment`
       )
@@ -110,7 +113,7 @@
         `cannot generate a video without a storyboard (the concept of Clapper is to use storyboards)`
       )
     }
-    const result = (await fal.run(request.settings.videoGenerationModel, {
+    const result = (await fal.run(model, {
       input: {
         image_url: storyboard.assetUrl,

@@ -141,7 +144,12 @@
     request.segment.category === ClapSegmentCategory.SOUND ||
     request.segment.category === ClapSegmentCategory.MUSIC
   ) {
-    const result = (await fal.run(request.settings.soundGenerationModel, {
+    model =
+      request.segment.category === ClapSegmentCategory.MUSIC
+        ? request.settings.musicGenerationWorkflow.data
+        : request.settings.soundGenerationWorkflow.data
+
+    const result = (await fal.run(model, {
       input: {
         // note how we use the *segment* prompt for music or sound
         prompt: request.segment.prompt,
@@ -153,7 +161,9 @@

     segment.assetUrl = result?.audio_file?.url || ''
   } else if (request.segment.category === ClapSegmentCategory.DIALOGUE) {
-    const result = (await fal.run(request.settings.voiceGenerationModel, {
+    model = request.settings.voiceGenerationWorkflow.data || ''
+
+    const result = (await fal.run(model, {
       input: {
         text: request.segment.prompt,
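The fal.ai resolver now re-derives model from a different workflow field in each branch. A hypothetical helper (not part of this commit) that expresses the same category-to-workflow mapping in one place, using only field names that appear in the diff above:

import { ClapSegmentCategory } from '@aitube/clap'

type WorkflowSlot = { data?: string }

// Hypothetical sketch: pick the model id for a segment category,
// falling back to an empty string like the code above does.
function pickModel(
  settings: {
    imageGenerationWorkflow: WorkflowSlot
    videoGenerationWorkflow: WorkflowSlot
    musicGenerationWorkflow: WorkflowSlot
    soundGenerationWorkflow: WorkflowSlot
    voiceGenerationWorkflow: WorkflowSlot
  },
  category: ClapSegmentCategory
): string {
  switch (category) {
    case ClapSegmentCategory.STORYBOARD:
      return settings.imageGenerationWorkflow.data || ''
    case ClapSegmentCategory.VIDEO:
      return settings.videoGenerationWorkflow.data || ''
    case ClapSegmentCategory.MUSIC:
      return settings.musicGenerationWorkflow.data || ''
    case ClapSegmentCategory.SOUND:
      return settings.soundGenerationWorkflow.data || ''
    case ClapSegmentCategory.DIALOGUE:
      return settings.voiceGenerationWorkflow.data || ''
    default:
      return ''
  }
}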
6 changes: 3 additions & 3 deletions src/app/api/resolve/providers/huggingface/generateImage.ts
@@ -4,9 +4,9 @@ import { decodeOutput } from '@/lib/utils/decodeOutput'
 import { ResolveRequest } from '@aitube/clapper-services'

 export async function generateImage(request: ResolveRequest): Promise<string> {
-  if (!request.settings.imageGenerationModel) {
+  if (!request.settings.imageGenerationWorkflow.data) {
     throw new Error(
-      `HuggingFace.generateImage: cannot generate without a valid imageGenerationModel`
+      `HuggingFace.generateImage: cannot generate without a valid imageGenerationWorkflow`
     )
   }

@@ -27,7 +27,7 @@ export async function generateImage(request: ResolveRequest): Promise<string> {
   )

   const blob: Blob = await hf.textToImage({
-    model: request.settings.imageGenerationModel,
+    model: request.settings.imageGenerationWorkflow.data,
     inputs: request.prompts.image.positive,
     parameters: {
       height: request.meta.height,
6 changes: 3 additions & 3 deletions src/app/api/resolve/providers/huggingface/generateVideo.ts
@@ -2,9 +2,9 @@ import { ResolveRequest } from '@aitube/clapper-services'
 import { callGradioApi } from '@/lib/hf/callGradioApi'

 export async function generateVideo(request: ResolveRequest): Promise<string> {
-  if (!request.settings.videoGenerationModel) {
+  if (!request.settings.videoGenerationWorkflow.data) {
     throw new Error(
-      `HuggingFace.generateVideo: cannot generate without a valid videoGenerationModel`
+      `HuggingFace.generateVideo: cannot generate without a valid videoGenerationWorkflow.data`
     )
   }

@@ -22,7 +22,7 @@

   // TODO pass a type to the template function
   const assetUrl = await callGradioApi<string>({
-    url: request.settings.videoGenerationModel,
+    url: request.settings.videoGenerationWorkflow.data,
     inputs: request.prompts.video,
     apiKey: request.settings.huggingFaceApiKey,
   })
6 changes: 3 additions & 3 deletions src/app/api/resolve/providers/huggingface/generateVoice.ts
@@ -3,9 +3,9 @@ import { HfInference, HfInferenceEndpoint } from '@huggingface/inference'
 import { ResolveRequest } from '@aitube/clapper-services'

 export async function generateVoice(request: ResolveRequest): Promise<string> {
-  if (!request.settings.voiceGenerationModel) {
+  if (!request.settings.voiceGenerationWorkflow.data) {
     throw new Error(
-      `HuggingFace.generateVoice: cannot generate without a valid voiceGenerationModel`
+      `HuggingFace.generateVoice: cannot generate without a valid voiceGenerationWorkflow`
     )
   }

@@ -26,7 +26,7 @@
   )

   const blob: Blob = await hf.textToSpeech({
-    model: request.settings.voiceGenerationModel,
+    model: request.settings.voiceGenerationWorkflow.data,
     inputs: request.prompts.voice.positive,
   })

2 changes: 1 addition & 1 deletion src/app/api/resolve/providers/huggingface/index.ts
@@ -31,7 +31,7 @@ export async function resolveSegment(
     segment.assetUrl = await generateVideo(request)
   } else {
     throw new Error(
-      `Clapper doesn't support ${request.segment.category} generation for provider "Hugging Face" with model (or space) "${request.settings.videoGenerationModel}". Please open a pull request with (working code) to solve this!`
+      `Clapper doesn't support ${request.segment.category} generation for provider "Hugging Face" with model (or space) "${request.settings.videoGenerationWorkflow.data}". Please open a pull request with (working code) to solve this!`
     )
   }
   return segment
14 changes: 9 additions & 5 deletions src/app/api/resolve/providers/replicate/index.ts
@@ -24,12 +24,16 @@ export async function resolveSegment(
   // like we are doing for Hugging Face (match the fields etc)
   if (request.segment.category === ClapSegmentCategory.STORYBOARD) {
     let params: object = {}
-    if (request.settings.imageGenerationModel === 'fofr/pulid-lightning') {
+    if (
+      request.settings.imageGenerationWorkflow.data === 'fofr/pulid-lightning'
+    ) {
       params = {
         prompt: request.prompts.image.positive,
         face_image: request.prompts.image.identity,
       }
-    } else if (request.settings.imageGenerationModel === 'zsxkib/pulid') {
+    } else if (
+      request.settings.imageGenerationWorkflow.data === 'zsxkib/pulid'
+    ) {
       params = {
         prompt: request.prompts.image.positive,
         main_face_image: request.prompts.image.identity,
@@ -40,13 +44,13 @@
       }
     }
     const response = (await replicate.run(
-      request.settings.imageGenerationModel as any,
+      request.settings.imageGenerationWorkflow.data as any,
       { input: params }
     )) as any
     segment.assetUrl = `${response.output || ''}`
   } else if (request.segment.category === ClapSegmentCategory.DIALOGUE) {
     const response = (await replicate.run(
-      request.settings.voiceGenerationModel as any,
+      request.settings.voiceGenerationWorkflow.data as any,
       {
         input: {
           text: request.prompts.voice.positive,
@@ -57,7 +61,7 @@
     )) as any
     segment.assetUrl = `${response.output || ''}`
   } else if (request.segment.category === ClapSegmentCategory.VIDEO) {
     const response = (await replicate.run(
-      request.settings.videoGenerationModel as any,
+      request.settings.videoGenerationWorkflow.data as any,
       {
         input: {
           image: request.prompts.video.image,
4 changes: 2 additions & 2 deletions src/app/api/resolve/providers/stabilityai/generateImage.ts
@@ -9,7 +9,7 @@ export async function generateImage(request: ResolveRequest): Promise<string> {
     )
   }

-  if (!request.settings.imageGenerationModel) {
+  if (!request.settings.imageGenerationWorkflow.data) {
     throw new Error(
       `StabilityAI.generateImage: cannot generate without a valid stabilityAiModelForImage`
     )
@@ -45,7 +45,7 @@
   body.set('aspect_ratio', `${aspectRatio || ''}`)

   const response = await fetch(
-    `https://api.stability.ai/v2beta/${request.settings.imageGenerationModel}`,
+    `https://api.stability.ai/v2beta/${request.settings.imageGenerationWorkflow.data}`,
     {
       method: 'POST',
       headers: {
4 changes: 2 additions & 2 deletions src/app/api/resolve/providers/stabilityai/generateVideo.ts
@@ -27,9 +27,9 @@ export async function generateVideo(request: ResolveRequest): Promise<string> {
   throw new Error(`${TAG}: cannot generate without a valid stabilityAiApiKey`)
   }

-  if (!request.settings.videoGenerationModel) {
+  if (!request.settings.videoGenerationWorkflow.data) {
     throw new Error(
-      `${TAG}: cannot generate without a valid videoGenerationModel`
+      `${TAG}: cannot generate without a valid videoGenerationWorkflow`
     )
   }
