diff --git a/src/app/api/resolve/providers/stabilityai/generateVideo.ts b/src/app/api/resolve/providers/stabilityai/generateVideo.ts index b3122000..4d97c5d9 100644 --- a/src/app/api/resolve/providers/stabilityai/generateVideo.ts +++ b/src/app/api/resolve/providers/stabilityai/generateVideo.ts @@ -1,23 +1,40 @@ -import { sleep } from '@/lib/utils/sleep' +import { base64DataUriToBlob } from '@/lib/utils/base64DataUriToBlob' import { ResolveRequest } from '@aitube/clapper-services' +import sharp from 'sharp' + +const TAG = `StabilityAI.generateVideo` + +type StabilityAIVImageToVideoStartGenerationResponse = { + id: string + name?: string + errors?: string[] +} + +enum StabilityAIVImageToVideoFetchhGenerationFinishReason { + SUCCESS = 'SUCCESS', + CONTENT_FILTERED = 'CONTENT_FILTERED', +} + +type StabilityAIVImageToVideoFetchGenerationResponse = { + video: string + finish_reason: StabilityAIVImageToVideoFetchhGenerationFinishReason + seed: number + errors?: string[] +} export async function generateVideo(request: ResolveRequest): Promise { if (!request.settings.stabilityAiApiKey) { - throw new Error( - `StabilityAI.generateVideo: cannot generate without a valid stabilityAiApiKey` - ) + throw new Error(`${TAG}: cannot generate without a valid stabilityAiApiKey`) } if (!request.settings.videoGenerationModel) { throw new Error( - `StabilityAI.generateVideo: cannot generate without a valid videoGenerationModel` + `${TAG}: cannot generate without a valid videoGenerationModel` ) } if (!request.prompts.video.image) { - throw new Error( - `StabilityAI.generateVideo: cannot generate without a valid image input` - ) + throw new Error(`${TAG}: cannot generate without a valid image input`) } // what's cool about the ultra model is its capacity to take in @@ -30,14 +47,13 @@ export async function generateVideo(request: ResolveRequest): Promise { // convey a sky that was blue and green, but more green than blue. const body = new FormData() - // Supported Formats: jpeg, png // Supported Dimensions: 1024x576, 576x1024, 768x768 // "Please ensure that the source image is in the correct format and dimensions" - body.set('image', `${request.prompts.video.image || ''}`) + body.set('image', await getRequestImage(request)) - const response = (await fetch( + const response = await fetch( `https://api.stability.ai/v2beta/image-to-video`, { method: 'POST', @@ -47,52 +63,120 @@ export async function generateVideo(request: ResolveRequest): Promise { body, cache: 'no-store', } - )) as unknown as { data: { id: number } } - - const generationId = response?.data?.id - if (!generationId) { - throw new Error(`StabilityAI failed to give us a valid response.data.id`) + ) + + if (response.status == 200) { + const { id }: StabilityAIVImageToVideoStartGenerationResponse = + await response.json() + console.log(TAG, `Generation ID: ${id}`) + const result = await pollGenerationResult( + id, + request.settings.stabilityAiApiKey + ) + console.log(TAG, 'Video was successfully generated.', result.length) + return result + } else { + const { errors }: StabilityAIVImageToVideoStartGenerationResponse = + await response.json() + if (errors) { + throw new Error(`${TAG}: ${errors.join('\n')}`) + } + throw new Error(`${TAG}: Unexpected error`) } +} - console.log('Generation ID:', generationId) - - let pollingCount = 0 - do { - // This is normally a fast model, so let's check every 4 seconds - await sleep(10000) - - const res = await fetch( - `https://api.stability.ai/v2beta/image-to-video/result/${generationId}`, - { - method: 'GET', - headers: { - Authorization: `Bearer ${request.settings.stabilityAiApiKey}`, - Accept: 'video/*', // Use 'application/json' to receive base64 encoded JSON - }, - cache: 'no-store', - } +/** + * Extracts the image from the request and resizes + * it based on the supported dimensions of StabilityAI + */ +async function getRequestImage(request: ResolveRequest) { + const supportedDimensions = [`1024x576`, `576x1024`, `768x768`] + let imageBlob = base64DataUriToBlob(`${request.prompts.video.image || ''}`) + const imageBuffer = Buffer.from(await imageBlob.arrayBuffer()) + const { width, height } = await sharp(imageBuffer).metadata() + const dimensions = `${width}x${height}` + if (!(dimensions in supportedDimensions)) { + console.log( + TAG, + `Unsupported dimensions ${width}x${height}, resizing to 1024x576 ...` ) + const resizedImageBuffer = await sharp(imageBuffer) + .resize({ + width: 1024, + height: 576, + fit: 'cover', + position: 'center', + }) + .toBuffer() + imageBlob = new Blob([resizedImageBuffer], { type: 'image/jpeg' }) + } + return imageBlob +} - if (res.status === 200) { +async function pollGenerationResult( + generationId: string, + apiKey: string, + maxPollingCount = 40, + intervalMs = 10000 +): Promise { + console.log(TAG, `Polling generation result width id = ${generationId} ...`) + return new Promise((resolve, reject) => { + let pollingCount = 0 + const intervalId = setInterval(async () => { try { - const response = (await res.json()) as any - const errors = `${response?.errors || ''}` - if (errors) { - throw new Error(errors) + const res = await fetch( + `https://api.stability.ai/v2beta/image-to-video/result/${generationId}`, + { + method: 'GET', + headers: { + Authorization: `Bearer ${apiKey}`, + Accept: 'application/json; type=video/mp4', // Use 'video/*' to receive raw bytes + }, + cache: 'no-store', + } + ) + + if (res.status === 202) { + return pollingCount++ } - return response.output.pop() - } catch (err) { - console.error('res.json() error:', err) - } - } - - pollingCount++ - // To prevent indefinite polling, we can stop after a certain number - if (pollingCount >= 40) { - throw new Error('Request timed out.') - } - } while (true) + try { + const { + video, + errors, + finish_reason, + }: StabilityAIVImageToVideoFetchGenerationResponse = await res.json() + if (res.status > 200) { + throw new Error(errors?.join('\n')) + } + if ( + finish_reason != + StabilityAIVImageToVideoFetchhGenerationFinishReason.SUCCESS + ) { + throw new Error('Content filtered') + } + resolve(`data:video/mp4;base64,${video}`) + } catch (err) { + console.error(TAG, err) + if (res.status < 500) { + reject(err) + } + } finally { + if (res.status < 500) { + return clearInterval(intervalId) + } else { + pollingCount++ + } + } - throw new Error('finish me') + if (pollingCount >= maxPollingCount) { + clearInterval(intervalId) + reject(new Error(`${TAG}: Request timed out.`)) + } + } catch (error) { + clearInterval(intervalId) + reject(error) + } + }, intervalMs) + }) } diff --git a/src/app/api/resolve/providers/stabilityai/index.ts b/src/app/api/resolve/providers/stabilityai/index.ts index 3e6159de..7633e11d 100644 --- a/src/app/api/resolve/providers/stabilityai/index.ts +++ b/src/app/api/resolve/providers/stabilityai/index.ts @@ -2,6 +2,7 @@ import { ClapSegmentCategory } from '@aitube/clap' import { TimelineSegment } from '@aitube/timeline' import { ResolveRequest } from '@aitube/clapper-services' import { generateImage } from './generateImage' +import { generateVideo } from './generateVideo' export async function resolveSegment( request: ResolveRequest @@ -12,11 +13,10 @@ export async function resolveSegment( const segment = request.segment - // for doc see: - // https://fal.ai/models/fal-ai/fast-sdxl/api - if (request.segment.category === ClapSegmentCategory.STORYBOARD) { segment.assetUrl = await generateImage(request) + } else if (request.segment.category === ClapSegmentCategory.VIDEO) { + segment.assetUrl = await generateVideo(request) } else { throw new Error( `Clapper doesn't support ${request.segment.category} generation for provider "Stability.ai". Please open a pull request with (working code) to solve this!`