Skip to content

Commit

Permalink
add face swap for character consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
jbilcke-hf committed Aug 25, 2024
1 parent 553016c commit 3e6b235
Show file tree
Hide file tree
Showing 19 changed files with 518 additions and 113 deletions.
66 changes: 59 additions & 7 deletions packages/app/src/app/api/resolve/providers/falai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { ClapMediaOrientation, ClapSegmentCategory } from '@aitube/clap'
import {
FalAiAudioResponse,
FalAiImageResponse,
FalAiImagesResponse,
FalAiSpeechResponse,
FalAiVideoResponse,
} from './types'
Expand Down Expand Up @@ -40,15 +41,19 @@ export async function resolveSegment(
return segment
}

let result: FalAiImageResponse | undefined = undefined
let result: FalAiImagesResponse | undefined = undefined

let isUsingIntegratedFaceId = false

if (model === 'fal-ai/pulid') {
if (!request.prompts.image.identity) {
// throw new Error(`you selected model ${request.settings.falAiModelForImage}, but no character was found, so skipping`)
// console.log(`warning: user selected model ${request.settings.falAiModelForImage}, but no character was found. Falling back to fal-ai/flux-pro`)

// dirty fix to fallback to a non-face model
model = 'fal-ai/flux-pro'
model = 'fal-ai/flux/schnell'
} else {
isUsingIntegratedFaceId = true
}
}

Expand Down Expand Up @@ -100,7 +105,7 @@ export async function resolveSegment(
enable_safety_checker:
request.settings.censorNotForAllAudiencesContent,
},
})) as FalAiImageResponse
})) as FalAiImagesResponse
} else if (model === 'fal-ai/flux-general') {
// note: this isn't the right place to do this, because maybe the LoRAs are dynamic
const loraModel = getWorkflowLora(
Expand Down Expand Up @@ -135,7 +140,7 @@ export async function resolveSegment(
enable_safety_checker:
request.settings.censorNotForAllAudiencesContent,
},
})) as FalAiImageResponse
})) as FalAiImagesResponse
} else {
result = (await fal.run(model, {
input: {
Expand All @@ -150,18 +155,62 @@ export async function resolveSegment(
enable_safety_checker:
request.settings.censorNotForAllAudiencesContent,
},
})) as FalAiImageResponse
})) as FalAiImagesResponse
}

if (request.settings.censorNotForAllAudiencesContent) {
if (result.has_nsfw_concepts.includes(true)) {
if (
Array.isArray(result.has_nsfw_concepts) &&
result.has_nsfw_concepts.includes(true)
) {
throw new Error(
`The generated content has been filtered according to your safety settings`
)
}
}

segment.assetUrl = result.images[0]?.url || ''

const imageFaceswapWorkflowModel =
request.settings.imageFaceswapWorkflow.data || ''

// Optional face-swap pass for character consistency.
// It only runs when:
//  - the image model did not already apply the identity itself,
//  - a face-swap workflow is configured,
//  - and there is an identity image to swap in (otherwise the request would
//    be sent without a swap image).
// This pass is best-effort: any failure is logged and the image generated
// above is kept as the segment asset.
if (
  !isUsingIntegratedFaceId &&
  imageFaceswapWorkflowModel &&
  request.prompts.image.identity
) {
  try {
    const faceSwapResult = (await fal.run(imageFaceswapWorkflowModel, {
      input: {
        base_image_url: segment.assetUrl,
        swap_image_url: request.prompts.image.identity,

        sync_mode: true,
        num_images: 1,
        enable_safety_checker:
          request.settings.censorNotForAllAudiencesContent,
      },
    })) as FalAiImageResponse

    // note how it is `image` (singular) here, not `images`
    const imageResult = faceSwapResult.image?.url || ''

    if (!imageResult) {
      throw new Error(`the generate image is empty`)
    }

    if (request.settings.censorNotForAllAudiencesContent) {
      // check the face-swap output itself: the base generation result
      // has already been checked before reaching this point
      if (
        Array.isArray(faceSwapResult.has_nsfw_concepts) &&
        faceSwapResult.has_nsfw_concepts.includes(true)
      ) {
        throw new Error(
          `The generated content has been filtered according to your safety settings`
        )
      }
    }

    segment.assetUrl = imageResult
  } catch (err) {
    console.error(`failed to run a face-swap using Fal.ai:`, err)
  }
}
} else if (request.segment.category === ClapSegmentCategory.VIDEO) {
model = request.settings.videoGenerationWorkflow.data || ''

Expand Down Expand Up @@ -190,7 +239,10 @@ export async function resolveSegment(
})) as FalAiVideoResponse

if (request.settings.censorNotForAllAudiencesContent) {
if (result.has_nsfw_concepts.includes(true)) {
if (
Array.isArray(result.has_nsfw_concepts) &&
result.has_nsfw_concepts.includes(true)
) {
throw new Error(
`The generated content has been filtered according to your safety settings`
)
Expand Down
19 changes: 18 additions & 1 deletion packages/app/src/app/api/resolve/providers/falai/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export type FalAiImageResponse = {
export type FalAiImagesResponse = {
prompt: string
timings: { inference: number }
has_nsfw_concepts: boolean[]
Expand All @@ -7,10 +7,27 @@ export type FalAiImageResponse = {
url: string
width: number
height: number
file_name: string
file_size: string
content_type: string
}[]
}

/**
 * Response shape for Fal.ai calls that return a single image under the
 * `image` key (singular), rather than an array of images.
 *
 * NOTE(review): the field meanings below are inferred from their names;
 * confirm them against the Fal.ai API reference of the model being called.
 */
export type FalAiImageResponse = {
prompt: string
timings: { inference: number }
// presumably one flag per generated image — TODO confirm
has_nsfw_concepts: boolean[]
seed: number
// the single generated image
image: {
url: string
width: number
height: number
file_name: string
// NOTE(review): declared as a string — confirm the API does not return a number here
file_size: string
content_type: string
}
}

export type FalAiVideoResponse = {
video: {
url: string
Expand Down
40 changes: 18 additions & 22 deletions packages/app/src/app/api/resolve/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
ClapAssetSource,
ClapWorkflowEngine,
} from '@aitube/clap'
import { TimelineSegment } from '@aitube/timeline'

import {
resolveSegmentUsingHuggingFace,
Expand All @@ -32,7 +33,7 @@ import { ResolveRequest } from '@aitube/clapper-services'
import { decodeOutput } from '@/lib/utils/decodeOutput'
import { getTypeAndExtension } from '@/lib/utils/getTypeAndExtension'
import { getMediaInfo } from '@/lib/ffmpeg/getMediaInfo'
import { TimelineSegment } from '@aitube/timeline'
import { getSegmentWorkflowProviderAndEngine } from '@/services/editors/workflow-editor/getSegmentWorkflowProviderAndEngine'

type ProviderFn = (request: ResolveRequest) => Promise<TimelineSegment>

Expand All @@ -44,34 +45,27 @@ export async function POST(req: NextRequest) {
// await throwIfInvalidToken(req.headers.get("Authorization"))
const request = (await req.json()) as ResolveRequest

const workflow: ClapWorkflow | undefined =
request.segment.category === ClapSegmentCategory.STORYBOARD
? request.settings.imageGenerationWorkflow
: request.segment.category === ClapSegmentCategory.VIDEO
? request.settings.videoGenerationWorkflow
: request.segment.category === ClapSegmentCategory.DIALOGUE
? request.settings.voiceGenerationWorkflow
: request.segment.category === ClapSegmentCategory.SOUND
? request.settings.soundGenerationWorkflow
: request.segment.category === ClapSegmentCategory.MUSIC
? request.settings.musicGenerationWorkflow
: undefined
const { workflow, provider, engine } =
getSegmentWorkflowProviderAndEngine(request)

/*
console.log(`Resolving a ${request.segment.category} segment using:`, {
workflow,
provider,
engine,
})
*/

if (!workflow) {
throw new Error(`request to /api/resolve is missing the .workflow field`)
throw new Error(`cannot resolve a segment without a valid workflow`)
}

const provider: ClapWorkflowProvider | undefined =
workflow.provider || undefined

if (!provider) {
throw new Error(`request to /api/resolve is missing the .provider field`)
if (!provider || provider === ClapWorkflowProvider.NONE) {
throw new Error(`cannot resolve a segment without a valid provider`)
}

const engine: ClapWorkflowEngine | undefined = workflow.engine || undefined

if (!engine) {
throw new Error(`request to /api/resolve is missing the .engine field`)
throw new Error(`cannot resolve a segment without a valid engine`)
}

const comfyProviders: Partial<Record<ClapWorkflowProvider, ProviderFn>> = {
Expand Down Expand Up @@ -102,6 +96,7 @@ export async function POST(req: NextRequest) {
: providers[provider] || undefined

if (!resolveSegment || typeof resolveSegment !== 'function') {
// console.log('invalid resolveSegment:', request)
throw new Error(
`Engine "${engine}" is not supported by "${provider}" yet. If you believe this is a mistake, please open a Pull Request (with working code) to fix it. Thank you!`
)
Expand All @@ -110,6 +105,7 @@ export async function POST(req: NextRequest) {
let segment = request.segment

try {
// console.log('calling resolveSegment', request)
segment = await resolveSegment(request)

// we clean-up and parse the output from all the resolvers:
Expand Down
78 changes: 42 additions & 36 deletions packages/app/src/components/tasks/useTasks.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -178,36 +178,36 @@ export const useTasks = create<TasksStore>((set, get) => ({
const status = t.task.status || 'deleted'
const progress = t.task.progress || 0

console.log(
`useTasks[${id}]: checkStatus: checking task, current status is: "${status}"`
)
// console.log(
// `useTasks[${id}]: checkStatus: checking task, current status is: "${status}"`
// )
if (
status === TaskStatus.ERROR ||
status === TaskStatus.SUCCESS ||
status === TaskStatus.DELETED ||
status === TaskStatus.CANCELLED
) {
console.log(
`useTasks[${id}]: checkStatus: status is "${status}", interrupting task loop..`
)
// console.log(
// `useTasks[${id}]: checkStatus: status is "${status}", interrupting task loop..`
// )

// this call might be redundant
if (status === TaskStatus.SUCCESS) {
get().setProgress(id, { isFinished: true })
}
resolve(status)
} else if (progress >= 100) {
console.log(
`useTasks[${id}]: checkStatus: task is completed at 100%, interrupting task loop..`
)
// console.log(
// `useTasks[${id}]: checkStatus: task is completed at 100%, interrupting task loop..`
// )
// this call might be redundant
get().setProgress(id, { isFinished: true })
// get().setStatus(TaskStatus.SUCCESS, id)
resolve(TaskStatus.SUCCESS)
} else {
console.log(
`useTasks[${id}]: checkStatus: status is "${status}", continuing task loop..`
)
// console.log(
// `useTasks[${id}]: checkStatus: status is "${status}", continuing task loop..`
// )
setTimeout(checkStatus, 1000)
}
} catch (err) {
Expand All @@ -218,15 +218,17 @@ export const useTasks = create<TasksStore>((set, get) => ({
checkStatus()
})

toast.promise<TaskStatus>(task.promise, {
loading: <TaskStatusUpdate taskId={id} />,
success: (finalStatus) => {
return finalStatus === TaskStatus.SUCCESS
? task.successMessage
: `Task ended`
},
error: 'Task aborted',
})
if (task.visibility === TaskVisibility.BACKGROUND) {
toast.promise<TaskStatus>(task.promise, {
loading: <TaskStatusUpdate taskId={id} />,
success: (finalStatus) => {
return finalStatus === TaskStatus.SUCCESS
? task.successMessage
: `Task ended`
},
error: 'Task aborted',
})
}

const { tasks } = get()
set({
Expand All @@ -243,22 +245,22 @@ export const useTasks = create<TasksStore>((set, get) => ({
}
// oh, one last thing: let's launch-and-forget the actual task

console.log(
`useTasks[${id}]: launching the task runner in the background..`
)
// console.log(
// `useTasks[${id}]: launching the task runner in the background..`
// )

// we provide the task runner with a way to get the current status,
// that way long-running jobs will know when they have been cancelled and are no longer needed
const result = await task.run(() => {
const remoteControl = get().get(id)!
const status = remoteControl?.task?.status
console.log(
`useTasks[${id}]: task runner asked for current status (which is: "${status || 'deleted'}")`
)
// console.log(
// `useTasks[${id}]: task runner asked for current status (which is: "${status || 'deleted'}")`
// )
return status || 'deleted'
})

console.log(`useTasks[${id}]: task runner ended with status: "${result}"`)
// console.log(`useTasks[${id}]: task runner ended with status: "${result}"`)
get().setProgress(id, { isFinished: true })
// get().setStatus(result, id)
}, 100)
Expand All @@ -280,21 +282,21 @@ export const useTasks = create<TasksStore>((set, get) => ({
const { tasks } = get()
const task = get().get(taskId)?.task

console.log(`useTasks[${taskId}]:setStatus("${status}")`)
// console.log(`useTasks[${taskId}]:setStatus("${status}")`)
if (task) {
console.log(
`useTasks[${taskId}]:setStatus("${status}") -> setting one task to ${status}`
)
// console.log(
// `useTasks[${taskId}]:setStatus("${status}") -> setting one task to ${status}`
// )
set({
tasks: {
...tasks,
[task.id]: { ...task, status: statusTransition(task.status, status) },
},
})
} else {
console.log(
`useTasks[${taskId}]:setStatus("${status}") -> setting all tasks to ${status}`
)
// console.log(
// `useTasks[${taskId}]:setStatus("${status}") -> setting all tasks to ${status}`
// )
const newTasks = {} as Record<string, Task>
for (const [id, t] of Object.entries(tasks)) {
newTasks[id] = { ...t, status: statusTransition(t.status, status) }
Expand Down Expand Up @@ -379,12 +381,16 @@ export const useTasks = create<TasksStore>((set, get) => ({
get().setStatus(TaskStatus.SUCCESS, taskId)
},
fail: (taskId: string, reason?: string) => {
const message = reason || 'unknown failure'

get().setProgress(taskId, {
message: reason || 'unknown failure',
message,
isFinished: true,
hasFailed: true,
})
get().setStatus(TaskStatus.ERROR, taskId)

toast.error(message)
},
cancel: (taskId?: string) => {
get().setStatus(TaskStatus.CANCELLED, taskId)
Expand Down
Loading

0 comments on commit 3e6b235

Please sign in to comment.