Skip to content

Commit

Permalink
Merge pull request #25 from devniel/stabilityai-image-to-video
Browse files Browse the repository at this point in the history
feat: generate videos with stabilityai
  • Loading branch information
jbilcke-hf authored Jul 21, 2024
2 parents b99b820 + 2ef024a commit d3fb9b4
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 54 deletions.
186 changes: 135 additions & 51 deletions src/app/api/resolve/providers/stabilityai/generateVideo.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,40 @@
import { sleep } from '@/lib/utils/sleep'
import { base64DataUriToBlob } from '@/lib/utils/base64DataUriToBlob'
import { ResolveRequest } from '@aitube/clapper-services'
import sharp from 'sharp'

const TAG = `StabilityAI.generateVideo`

// JSON body parsed from the response to
// POST https://api.stability.ai/v2beta/image-to-video (start of a generation).
type StabilityAIVImageToVideoStartGenerationResponse = {
// generation id; handed to pollGenerationResult to fetch the finished video
id: string
name?: string
// read when the response status is not 200 and joined into the thrown Error
errors?: string[]
}

// Values of `finish_reason` in the generation result payload.
// The polling code only accepts SUCCESS; any other value is rejected
// with a "Content filtered" error.
enum StabilityAIVImageToVideoFetchhGenerationFinishReason {
SUCCESS = 'SUCCESS',
CONTENT_FILTERED = 'CONTENT_FILTERED',
}

// JSON body parsed from the response to
// GET https://api.stability.ai/v2beta/image-to-video/result/{id}.
type StabilityAIVImageToVideoFetchGenerationResponse = {
// embedded as-is into a `data:video/mp4;base64,` URI by the polling code,
// so presumably the base64-encoded MP4 (confirm against the API docs)
video: string
// must equal SUCCESS for the video to be returned to the caller
finish_reason: StabilityAIVImageToVideoFetchhGenerationFinishReason
seed: number
// joined into the thrown Error when the response status is above 200
errors?: string[]
}

export async function generateVideo(request: ResolveRequest): Promise<string> {
if (!request.settings.stabilityAiApiKey) {
throw new Error(
`StabilityAI.generateVideo: cannot generate without a valid stabilityAiApiKey`
)
throw new Error(`${TAG}: cannot generate without a valid stabilityAiApiKey`)
}

if (!request.settings.videoGenerationModel) {
throw new Error(
`StabilityAI.generateVideo: cannot generate without a valid videoGenerationModel`
`${TAG}: cannot generate without a valid videoGenerationModel`
)
}

if (!request.prompts.video.image) {
throw new Error(
`StabilityAI.generateVideo: cannot generate without a valid image input`
)
throw new Error(`${TAG}: cannot generate without a valid image input`)
}

// what's cool about the ultra model is its capacity to take in
Expand All @@ -30,14 +47,13 @@ export async function generateVideo(request: ResolveRequest): Promise<string> {
// convey a sky that was blue and green, but more green than blue.

const body = new FormData()

// Supported Formats: jpeg, png
// Supported Dimensions: 1024x576, 576x1024, 768x768

// "Please ensure that the source image is in the correct format and dimensions"
body.set('image', `${request.prompts.video.image || ''}`)
body.set('image', await getRequestImage(request))

const response = (await fetch(
const response = await fetch(
`https://api.stability.ai/v2beta/image-to-video`,
{
method: 'POST',
Expand All @@ -47,52 +63,120 @@ export async function generateVideo(request: ResolveRequest): Promise<string> {
body,
cache: 'no-store',
}
)) as unknown as { data: { id: number } }

const generationId = response?.data?.id
if (!generationId) {
throw new Error(`StabilityAI failed to give us a valid response.data.id`)
)

if (response.status == 200) {
const { id }: StabilityAIVImageToVideoStartGenerationResponse =
await response.json()
console.log(TAG, `Generation ID: ${id}`)
const result = await pollGenerationResult(
id,
request.settings.stabilityAiApiKey
)
console.log(TAG, 'Video was successfully generated.', result.length)
return result
} else {
const { errors }: StabilityAIVImageToVideoStartGenerationResponse =
await response.json()
if (errors) {
throw new Error(`${TAG}: ${errors.join('\n')}`)
}
throw new Error(`${TAG}: Unexpected error`)
}
}

console.log('Generation ID:', generationId)

let pollingCount = 0
do {
// This is normally a fast model, so let's check every 10 seconds
await sleep(10000)

const res = await fetch(
`https://api.stability.ai/v2beta/image-to-video/result/${generationId}`,
{
method: 'GET',
headers: {
Authorization: `Bearer ${request.settings.stabilityAiApiKey}`,
Accept: 'video/*', // Use 'application/json' to receive base64 encoded JSON
},
cache: 'no-store',
}
/**
 * Extracts the input image from the request and makes sure it has one of
 * the dimensions supported by the StabilityAI image-to-video endpoint
 * (1024x576, 576x1024 or 768x768).
 *
 * An image that already has a supported size is returned untouched.
 * Any other image is cover-cropped around its center to 1024x576 and
 * re-encoded as JPEG.
 *
 * @param request - resolve request whose `prompts.video.image` holds the
 *   source image as a base64 data URI
 * @returns the blob to upload as the `image` form field
 */
async function getRequestImage(request: ResolveRequest) {
  const supportedDimensions = [`1024x576`, `576x1024`, `768x768`]
  const imageBlob = base64DataUriToBlob(`${request.prompts.video.image || ''}`)
  const imageBuffer = Buffer.from(await imageBlob.arrayBuffer())
  const { width, height } = await sharp(imageBuffer).metadata()

  // `includes` compares values; the `in` operator would test array indices
  // ("0", "1", "2") and therefore never match a "WxH" string.
  if (supportedDimensions.includes(`${width}x${height}`)) {
    return imageBlob
  }

  console.log(
    TAG,
    `Unsupported dimensions ${width}x${height}, resizing to 1024x576 ...`
  )
  const resizedImageBuffer = await sharp(imageBuffer)
    .resize({
      width: 1024,
      height: 576,
      fit: 'cover',
      position: 'center',
    })
    // sharp keeps the input format unless told otherwise; force JPEG so the
    // bytes actually match the MIME type declared on the blob below.
    .jpeg()
    .toBuffer()
  return new Blob([resizedImageBuffer], { type: 'image/jpeg' })
}

if (res.status === 200) {
/**
 * Polls the StabilityAI image-to-video result endpoint until the generation
 * identified by `generationId` is finished, then returns the video as a
 * `data:video/mp4;base64,...` URI.
 *
 * Polls run strictly one after another (a new request is only sent once the
 * previous one has been handled), and every poll counts towards
 * `maxPollingCount`, whether the generation is still pending (HTTP 202)
 * or the server answered with a 5xx error.
 *
 * @param generationId - id returned when the generation was started
 * @param apiKey - StabilityAI API key, sent as a Bearer token
 * @param maxPollingCount - maximum number of polls before giving up
 * @param intervalMs - delay before each poll, in milliseconds
 * @returns the generated video as a base64 data URI
 * @throws Error when the request itself fails, when the API answers with a
 *   non-5xx error status, when `finish_reason` is not SUCCESS, when the
 *   payload contains no video, or when no result was obtained after
 *   `maxPollingCount` polls
 */
async function pollGenerationResult(
  generationId: string,
  apiKey: string,
  maxPollingCount = 40,
  intervalMs = 10000
): Promise<string> {
  console.log(TAG, `Polling generation result width id = ${generationId} ...`)

  for (let pollingCount = 0; pollingCount < maxPollingCount; pollingCount++) {
    // Wait before every poll (including the first one): the result is never
    // available right after the generation has been started.
    await new Promise<void>((resolve) => setTimeout(resolve, intervalMs))

    const res = await fetch(
      `https://api.stability.ai/v2beta/image-to-video/result/${generationId}`,
      {
        method: 'GET',
        headers: {
          Authorization: `Bearer ${apiKey}`,
          Accept: 'application/json; type=video/mp4', // Use 'video/*' to receive raw bytes
        },
        cache: 'no-store',
      }
    )

    // 202 = generation still in progress; keep polling until the budget is
    // exhausted (the loop bound is what enforces the timeout here).
    if (res.status === 202) {
      continue
    }

    try {
      const {
        video,
        errors,
        finish_reason,
      }: StabilityAIVImageToVideoFetchGenerationResponse = await res.json()

      if (res.status > 200) {
        throw new Error(
          Array.isArray(errors) && errors.length > 0
            ? errors.join('\n')
            : `${TAG}: Unexpected response status ${res.status}`
        )
      }
      if (
        finish_reason !=
        StabilityAIVImageToVideoFetchhGenerationFinishReason.SUCCESS
      ) {
        throw new Error('Content filtered')
      }
      if (!video) {
        throw new Error(`${TAG}: The response did not contain a video`)
      }
      return `data:video/mp4;base64,${video}`
    } catch (err) {
      console.error(TAG, err)
      // Client-side and payload errors are final; 5xx responses are treated
      // as transient and retried on the next iteration.
      if (res.status < 500) {
        throw err
      }
    }
  }

  throw new Error(`${TAG}: Request timed out.`)
}
6 changes: 3 additions & 3 deletions src/app/api/resolve/providers/stabilityai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { ClapSegmentCategory } from '@aitube/clap'
import { TimelineSegment } from '@aitube/timeline'
import { ResolveRequest } from '@aitube/clapper-services'
import { generateImage } from './generateImage'
import { generateVideo } from './generateVideo'

export async function resolveSegment(
request: ResolveRequest
Expand All @@ -12,11 +13,10 @@ export async function resolveSegment(

const segment = request.segment

// for doc see:
// https://fal.ai/models/fal-ai/fast-sdxl/api

if (request.segment.category === ClapSegmentCategory.STORYBOARD) {
segment.assetUrl = await generateImage(request)
} else if (request.segment.category === ClapSegmentCategory.VIDEO) {
segment.assetUrl = await generateVideo(request)
} else {
throw new Error(
`Clapper doesn't support ${request.segment.category} generation for provider "Stability.ai". Please open a pull request with (working code) to solve this!`
Expand Down

0 comments on commit d3fb9b4

Please sign in to comment.