diff --git a/packages/app/src/app/api/resolve/providers/comfyui/index.ts b/packages/app/src/app/api/resolve/providers/comfyui/index.ts index 9b2c3d19..7bb9a0b2 100644 --- a/packages/app/src/app/api/resolve/providers/comfyui/index.ts +++ b/packages/app/src/app/api/resolve/providers/comfyui/index.ts @@ -18,6 +18,7 @@ import { import { getWorkflowInputValues } from '../getWorkflowInputValues' import { decodeOutput } from '@/lib/utils/decodeOutput' +import { ComfyUIWorkflowApiUtils } from './utils' export async function resolveSegment( request: ResolveRequest @@ -54,50 +55,21 @@ export async function resolveSegment( request.settings.imageGenerationWorkflow.data ) - const txt2ImgPrompt = new PromptBuilder( - comfyApiWorkflow, - // TODO: this list should be detect/filled automatically (see line 86) - [ - 'positive', - 'negative', - 'checkpoint', - 'seed', - 'batch', - 'step', - 'cfg', - 'sampler', - 'sheduler', - 'width', - 'height', - ], - // TODO: this list should be detect/filled automatically (see line 86) - ['images'] - ) - // TODO: those input sets should be detect/filled automatically (see line 86) - .setInputNode('checkpoint', '4.inputs.ckpt_name') - .setInputNode('seed', '3.inputs.seed') - .setInputNode('batch', '5.inputs.batch_size') - .setInputNode('negative', '7.inputs.text') - .setInputNode('positive', '6.inputs.text') - .setInputNode('cfg', '3.inputs.cfg') - .setInputNode('sampler', '3.inputs.sampler_name') - .setInputNode('sheduler', '3.inputs.scheduler') - .setInputNode('step', '3.inputs.steps') - .setInputNode('width', '5.inputs.width') - .setInputNode('height', '5.inputs.height') - .setOutputNode('images', '9') + const txt2ImgPrompt = new ComfyUIWorkflowApiUtils( + comfyApiWorkflow + ).createPromptBuilder() const workflow = txt2ImgPrompt // TODO: this mapping should be detect/filled automatically (see line 86) - .input('checkpoint', 'SDXL/realvisxlV40_v40LightningBakedvae.safetensors') + .input('ckpt_name', 'SDXL/realvisxlV40_v40LightningBakedvae.safetensors') .input('seed', generateSeed()) - .input('step', 6) + .input('steps', 6) .input('cfg', 1) - .input('sampler', 'dpmpp_2m_sde_gpu') - .input('sheduler', 'sgm_uniform') + .input('sampler_name', 'dpmpp_2m_sde_gpu') + .input('scheduler', 'sgm_uniform') .input('width', request.meta.width) .input('height', request.meta.height) - .input('batch', 1) + .input('batch_size', 1) .input('positive', request.prompts.image.positive) // for the moment we only have non-working "mock" sample code, @@ -154,7 +126,7 @@ export async function resolveSegment( throw new Error(`failed to run the pipeline (no output)`) } - const imagePaths = rawOutput.images?.images.map((img: any) => + const imagePaths = rawOutput.output?.images.map((img: any) => api.getPathImage(img) ) diff --git a/packages/app/src/app/api/resolve/providers/comfyui/utils.spec.ts b/packages/app/src/app/api/resolve/providers/comfyui/utils.spec.ts new file mode 100644 index 00000000..51097c9a --- /dev/null +++ b/packages/app/src/app/api/resolve/providers/comfyui/utils.spec.ts @@ -0,0 +1,442 @@ +import { expect, test } from 'vitest' +import { ComfyUIWorkflowApiUtils } from './utils' + +// Default workflow used by ComfyUI, downloaded for API +const workflowRaw = { + '3': { + inputs: { + seed: 156680208700286, + steps: 20, + cfg: 8, + sampler_name: 'euler', + scheduler: 'normal', + denoise: 1, + model: ['4', 0], + positive: ['6', 0], + negative: ['7', 0], + latent_image: ['5', 0], + }, + class_type: 'KSampler', + _meta: { + title: 'KSampler', + }, + }, + '4': { + inputs: { + ckpt_name: 'v1-5-pruned-emaonly.ckpt', + }, + class_type: 'CheckpointLoaderSimple', + _meta: { + title: 'Load Checkpoint', + }, + }, + '5': { + inputs: { + width: 512, + height: 512, + batch_size: 1, + }, + class_type: 'EmptyLatentImage', + _meta: { + title: 'Empty Latent Image', + }, + }, + '6': { + inputs: { + text: 'beautiful scenery nature glass bottle landscape, , purple galaxy bottle,', + clip: ['4', 1], + }, + class_type: 'CLIPTextEncode', + _meta: { + title: 'CLIP Text Encode (Prompt)', + }, + }, + '7': { + inputs: { + text: 'text, watermark', + clip: ['4', 1], + }, + class_type: 'CLIPTextEncode', + _meta: { + title: 'CLIP Text Encode (Prompt)', + }, + }, + '8': { + inputs: { + samples: ['3', 0], + vae: ['4', 2], + }, + class_type: 'VAEDecode', + _meta: { + title: 'VAE Decode', + }, + }, + '9': { + inputs: { + filename_prefix: 'ComfyUI', + images: ['8', 0], + }, + class_type: 'SaveImage', + _meta: { + title: 'Save Image', + }, + }, +} + +// Example workflow object using @clapper/- tokens +const workflowRawWithTokens = { + '3': { + inputs: { + seed: 156680208700286, + steps: 20, + cfg: 8, + sampler_name: 'euler', + scheduler: 'normal', + denoise: 1, + model: ['4', 0], + positive: ['6', 0], + negative: ['7', 0], + latent_image: ['5', 0], + }, + class_type: 'KSampler', + _meta: { + title: 'KSampler', + }, + }, + '4': { + inputs: { + ckpt_name: 'v1-5-pruned-emaonly.ckpt', + }, + class_type: 'CheckpointLoaderSimple', + _meta: { + title: 'Load Checkpoint', + }, + }, + '5': { + inputs: { + width: 512, + height: 512, + batch_size: 1, + }, + class_type: 'EmptyLatentImage', + _meta: { + title: 'Empty Latent Image', + }, + }, + '6': { + inputs: { + text: '@clapper/prompt', + clip: ['4', 1], + }, + class_type: 'CLIPTextEncode', + _meta: { + title: 'CLIP Text Encode (Prompt)', + }, + }, + '7': { + inputs: { + text: '@clapper/negative', + clip: ['4', 1], + }, + class_type: 'CLIPTextEncode', + _meta: { + title: 'CLIP Text Encode (Prompt)', + }, + }, + '8': { + inputs: { + samples: ['3', 0], + vae: ['4', 2], + }, + class_type: 'VAEDecode', + _meta: { + title: 'VAE Decode', + }, + }, + '9': { + inputs: { + filename_prefix: 'ComfyUI', + images: ['8', 0], + }, + class_type: 'SaveImage', + _meta: { + title: 'Save Image', + }, + }, +} + +test('should return all nodes that have inputs', () => { + const nodesWithInputs = new ComfyUIWorkflowApiUtils( + workflowRaw + ).getNodesWithInputs() + + // Expect nodes 3, 4, 5, 6, 7, 8, and 9 to have inputs + expect(nodesWithInputs).toHaveLength(7) + expect(nodesWithInputs).toEqual( + expect.arrayContaining([ + expect.objectContaining({ id: '3' }), + expect.objectContaining({ id: '4' }), + expect.objectContaining({ id: '5' }), + expect.objectContaining({ id: '6' }), + expect.objectContaining({ id: '7' }), + expect.objectContaining({ id: '8' }), + expect.objectContaining({ id: '9' }), + ]) + ) +}) + +test('should return the correct output node', () => { + const outputNode = new ComfyUIWorkflowApiUtils(workflowRaw).getOutputNode() + + expect(outputNode).toEqual({ + id: '9', + inputs: { + filename_prefix: 'ComfyUI', + images: ['8', 0], + }, + class_type: 'SaveImage', + _meta: { + title: 'Save Image', + }, + }) +}) + +test('should build the correct graph from the workflow', () => { + const { adjList, dependencyList, inDegree } = new ComfyUIWorkflowApiUtils( + workflowRaw + ).getGraphData() + + expect(adjList['3']).toEqual(['8']) + expect(adjList['8']).toEqual(['9']) + expect(inDegree['3']).toBe(4) + expect(dependencyList['3']).toEqual([ + { + from: '4', + inputName: 'model', + }, + { + from: '6', + inputName: 'positive', + }, + { from: '7', inputName: 'negative' }, + { from: '5', inputName: 'latent_image' }, + ]) + expect(inDegree['9']).toBe(1) +}) + +test('should return the correct inputs by node id', () => { + const workflow = new ComfyUIWorkflowApiUtils(workflowRaw) + + expect(workflow.getInputsByNodeId('3')).toEqual([ + { + type: 'number', + name: 'seed', + value: 156680208700286, + key: '3.inputs.seed', + nodeId: '3', + }, + { + type: 'number', + name: 'steps', + value: 20, + key: '3.inputs.steps', + nodeId: '3', + }, + { type: 'number', name: 'cfg', value: 8, key: '3.inputs.cfg', nodeId: '3' }, + { + type: 'string', + name: 'sampler_name', + value: 'euler', + key: '3.inputs.sampler_name', + nodeId: '3', + }, + { + type: 'string', + name: 'scheduler', + value: 'normal', + key: '3.inputs.scheduler', + nodeId: '3', + }, + { + type: 'number', + name: 'denoise', + value: 1, + key: '3.inputs.denoise', + nodeId: '3', + }, + ]) + + expect(workflow.getInputsByNodeId('4')).toEqual([ + { + type: 'string', + name: 'ckpt_name', + value: 'v1-5-pruned-emaonly.ckpt', + key: '4.inputs.ckpt_name', + nodeId: '4', + }, + ]) + + expect(workflow.getInputsByNodeId('5')).toEqual([ + { + type: 'number', + name: 'width', + value: 512, + key: '5.inputs.width', + nodeId: '5', + }, + { + type: 'number', + name: 'height', + value: 512, + key: '5.inputs.height', + nodeId: '5', + }, + { + type: 'number', + name: 'batch_size', + value: 1, + key: '5.inputs.batch_size', + nodeId: '5', + }, + ]) + + expect(workflow.getInputsByNodeId('6')).toEqual([ + { + type: 'string', + name: 'text', + value: + 'beautiful scenery nature glass bottle landscape, , purple galaxy bottle,', + key: '6.inputs.text', + nodeId: '6', + }, + ]) + + const nonExistentNodeInputs = workflow.getInputsByNodeId('99') + expect(nonExistentNodeInputs).toBeNull() +}) + +test('should detect the correct main inputs', () => { + const mainInputs = new ComfyUIWorkflowApiUtils(workflowRaw).detectMainInputs() + + expect(mainInputs).toEqual([ + { + nodeId: '6', + name: 'text', + value: + 'beautiful scenery nature glass bottle landscape, , purple galaxy bottle,', + }, + { nodeId: '7', name: 'text', value: 'text, watermark' }, + ]) + + expect(mainInputs).not.toEqual( + expect.arrayContaining([ + expect.objectContaining({ name: 'width' }), + expect.objectContaining({ name: 'cfg' }), + ]) + ) +}) + +test('should detect the correct positive and negative prompt inputs', () => { + const workflow = new ComfyUIWorkflowApiUtils(workflowRaw) + const positivePrompts = workflow.detectPositivePromptInput() + const negativePrompts = workflow.detectNegativePromptInput() + + expect(positivePrompts).toEqual([ + { + nodeId: '6', + name: 'text', + type: 'positive', + value: + 'beautiful scenery nature glass bottle landscape, , purple galaxy bottle,', + }, + ]) + + expect(negativePrompts).toEqual([ + { nodeId: '7', name: 'text', type: 'negative', value: 'text, watermark' }, + ]) + + expect(positivePrompts).not.toEqual( + expect.arrayContaining([expect.objectContaining({ name: 'steps' })]) + ) + expect(negativePrompts).not.toEqual( + expect.arrayContaining([expect.objectContaining({ name: 'cfg' })]) + ) +}) + +test('should detect the correct positive and negative prompt inputs using clapper tokens', () => { + const workflow = new ComfyUIWorkflowApiUtils(workflowRawWithTokens) + const positivePrompts = workflow.detectPositivePromptInput() + const negativePrompts = workflow.detectNegativePromptInput() + + expect(positivePrompts).toEqual([ + { + nodeId: '6', + type: 'positive', + name: 'text', + value: '@clapper/prompt', + }, + ]) + + expect(negativePrompts).toEqual([ + { nodeId: '7', type: 'negative', name: 'text', value: '@clapper/negative' }, + ]) + + expect(positivePrompts).not.toEqual( + expect.arrayContaining([expect.objectContaining({ name: 'steps' })]) + ) + expect(negativePrompts).not.toEqual( + expect.arrayContaining([expect.objectContaining({ name: 'cfg' })]) + ) +}) + +test('should create the PromptBuilder', () => { + const promptBuilder = new ComfyUIWorkflowApiUtils( + workflowRaw + ).createPromptBuilder() + expect(promptBuilder.mapOutputKeys).toEqual({ + output: '9', + }) + expect(promptBuilder.mapInputKeys).toEqual({ + seed: '3.inputs.seed', + steps: '3.inputs.steps', + cfg: '3.inputs.cfg', + sampler_name: '3.inputs.sampler_name', + scheduler: '3.inputs.scheduler', + denoise: '3.inputs.denoise', + ckpt_name: '4.inputs.ckpt_name', + width: '5.inputs.width', + height: '5.inputs.height', + batch_size: '5.inputs.batch_size', + text: '7.inputs.text', + filename_prefix: '9.inputs.filename_prefix', + positive: '6.inputs.text', + negative: '7.inputs.text', + }) + expect(promptBuilder.prompt).toEqual(workflowRaw) +}) + +/** + * Error handling + */ + +// TODO: More corrupted workflows +const workflowRawWithCycles = { + a: { + inputs: { + text: ['b', 0], + }, + }, + b: { + inputs: { + text: ['a', 0], + }, + }, +} + +test('should fail if workflow has cycles', () => { + expect(() => { + new ComfyUIWorkflowApiUtils(workflowRawWithCycles) + }).toThrow( + 'The provided workflow has cycles, impossible to get the output node.' + ) +}) diff --git a/packages/app/src/app/api/resolve/providers/comfyui/utils.ts b/packages/app/src/app/api/resolve/providers/comfyui/utils.ts new file mode 100644 index 00000000..86d31c71 --- /dev/null +++ b/packages/app/src/app/api/resolve/providers/comfyui/utils.ts @@ -0,0 +1,374 @@ +import { PromptBuilder } from '@saintno/comfyui-sdk' + +type NodeRawData = { + inputs?: Record + class_type?: string + _meta?: { + title: string + } +} + +type NodeData = NodeRawData & { + id: string +} + +type ComfyUIWorkflowApiRaw = Record + +type INPUT_TYPES = 'string' | 'number' + +type NodeInput = { + // Infered primitive type of the input based on its value + type: INPUT_TYPES + name: string + value: any + key: string + nodeId: string +} + +type PromptClapperInput = { + // Infered clapper input type based on input value, input node relationships, etc + type?: 'positive' | 'negative' + nodeId: string + name: string + value: any +} + +/** + * Utils to query ComfyUI workflow-api nodes data + */ +export class ComfyUIWorkflowApiUtils { + private workflow: ComfyUIWorkflowApiRaw + private adjList: Record + private dependencyList: Record + private dependantList: Record + private inDegree: Record + + constructor(workflow: ComfyUIWorkflowApiRaw) { + this.workflow = workflow + const { adjList, dependencyList, dependantList, inDegree } = + this.buildGraphData() + this.adjList = adjList + this.dependencyList = dependencyList + this.dependantList = dependantList + this.inDegree = inDegree + const hasCycles = this.detectCycles() + if (hasCycles) { + throw new Error( + 'The provided workflow has cycles, impossible to get the output node.' + ) + } + } + + /** + * Create a graph structure in an adjacent list + * representation with additional data arrays + * for dev purposes. + * @param workflow + * @returns + */ + private buildGraphData() { + const adjList: Record = {} + const dependencyList: Record< + string, + { from: string; inputName: string }[] + > = {} + const dependantList: Record = + {} + const inDegree: Record = {} + + for (const nodeId of Object.keys(this.workflow)) { + adjList[nodeId] = [] + dependencyList[nodeId] = [] + dependantList[nodeId] = [] + inDegree[nodeId] = 0 + } + + for (const [nodeId, nodeData] of Object.entries(this.workflow)) { + const completeNodeData: NodeData = { id: nodeId, ...nodeData } + if (completeNodeData.inputs) { + for (const [inputName, value] of Object.entries( + completeNodeData.inputs + )) { + if (Array.isArray(value)) { + const dependency = value[0] as string + adjList[dependency].push(nodeId) + dependencyList[nodeId].push({ from: dependency, inputName }) + dependantList[dependency].push({ to: nodeId, inputName }) + inDegree[nodeId] += 1 + } + } + } + } + + return { adjList, dependencyList, dependantList, inDegree } + } + + private detectCycles(): boolean { + const visited: Record = {} + const recursionStack: Record = {} + const dfs = (nodeId: string): boolean => { + if (!visited[nodeId]) { + visited[nodeId] = true + recursionStack[nodeId] = true + for (const neighbor of this.adjList[nodeId]) { + if (!visited[neighbor] && dfs(neighbor)) { + return true + } else if (recursionStack[neighbor]) { + return true + } + } + } + + recursionStack[nodeId] = false + return false + } + + for (const nodeId of Object.keys(this.adjList)) { + if (dfs(nodeId)) { + return true + } + } + + return false + } + + getGraphData() { + const { adjList, dependencyList, dependantList, inDegree } = this + return { adjList, dependencyList, dependantList, inDegree } + } + + /** + * Get all nodes that have inputs. + */ + getNodesWithInputs(): NodeData[] { + const nodesWithInputs: NodeData[] = [] + for (const [nodeId, nodeData] of Object.entries(this.workflow)) { + if (nodeData.inputs && Object.keys(nodeData.inputs).length > 0) { + const completeNodeData: NodeData = { id: nodeId, ...nodeData } + nodesWithInputs.push(completeNodeData) + } + } + return nodesWithInputs + } + + /** + * Get all inputs in the workflow. + */ + getInputs(): NodeInput[] { + const nodesWithInputs = this.getNodesWithInputs() + let inputs: any[] = [] + + for (const node of nodesWithInputs) { + const inputSchemas = this.getInputsByNodeId(node.id) + if (inputSchemas?.length) { + inputs = inputs.concat(...inputSchemas) + } + } + + return inputs + } + + /** + * Topological sort of the graph (Kahn's Algorithm) to get the final output node. + * TODO: multiple outputs. + */ + getOutputNode(): NodeData | null { + const { adjList, inDegree } = this + const queue: string[] = [] + const sortedOrder: string[] = [] + + for (const nodeId of Object.keys(inDegree)) { + if (inDegree[nodeId] === 0) { + queue.push(nodeId) + } + } + + while (queue.length > 0) { + const currentNode = queue.shift()! + sortedOrder.push(currentNode) + for (const neighbor of adjList[currentNode]) { + inDegree[neighbor] -= 1 + if (inDegree[neighbor] === 0) { + queue.push(neighbor) + } + } + } + + // Last node in sortedOrder is the output node + // TODO: handle multiple outputs + if (sortedOrder.length === Object.keys(this.workflow).length) { + const outputNodeId = sortedOrder[sortedOrder.length - 1] + return { id: outputNodeId, ...this.workflow[outputNodeId] } + } else { + // If there are cycles, fail + throw new Error( + 'The provided workflow has cycles, impossible to get the output node.' + ) + } + } + + /** + * Get all value inputs of a given node in the workflow. + * Ignore input connections (e.g. inputs with value ['3', 0]) + * @param nodeId the id of the node + */ + getInputsByNodeId(nodeId: string): NodeInput[] | null { + const nodeData = this.workflow[nodeId] + if (!nodeData || !nodeData.inputs) { + return null + } + + const inputs: NodeInput[] = [] + + for (const [name, value] of Object.entries(nodeData.inputs)) { + if (Array.isArray(value)) continue + + // TODO: Handle more types + let inputType: INPUT_TYPES = + typeof value === 'string' ? 'string' : 'number' + + inputs.push({ + type: inputType, + name: name, + value: value, + key: `${nodeId}.inputs.${name}`, + nodeId, + }) + } + + return inputs.length > 0 ? inputs : null + } + + /** + * Search for the main inputs for Clapper + * e.g. prompt, negative prompt + */ + detectMainInputs(): PromptClapperInput[] { + const nodesWithInputs = this.getNodesWithInputs() + const mainInputs: PromptClapperInput[] = [] + + for (const node of nodesWithInputs) { + const { id: nodeId, inputs, class_type, _meta } = node + const nodeInputs = this.getInputsByNodeId(node.id) + + if (nodeInputs) { + for (const nodeInput of nodeInputs) { + // Based on the type or input name + const isStringInput = nodeInput.type === 'string' + const nameContainsTextOrPrompt = + nodeInput.name.includes('text') || nodeInput.name.includes('prompt') + // Based on the node type + const classIsCLIPTextEncode = class_type === 'CLIPTextEncode' + // Based on the node title + const titleContainsPrompt = _meta?.title + ?.toLowerCase() + .includes('prompt') + // Based on Clapper string tokens + const hasClapperTokens = + isStringInput && + nodeInput.value?.toLowerCase().includes('@clapper/') + + if ( + (isStringInput && nameContainsTextOrPrompt) || + classIsCLIPTextEncode || + titleContainsPrompt || + hasClapperTokens + ) { + mainInputs.push({ + name: nodeInput.name, + value: nodeInput.value, + nodeId: nodeInput.nodeId, + }) + } + } + } + } + + return mainInputs + } + + /** + * Detect positive prompt inputs in the workflow + */ + detectPositivePromptInput(): PromptClapperInput[] { + const mainInputs = this.detectMainInputs() + const positivePromptInputs = mainInputs + .filter((input) => { + const deps = this.dependantList[input.nodeId] + return deps.some((dep) => dep.inputName === 'positive') + }) + .map((input) => { + input.type = 'positive' + return input + }) + + return positivePromptInputs + } + + /** + * Detect negative prompt inputs in the workflow + */ + detectNegativePromptInput(): PromptClapperInput[] { + const mainInputs = this.detectMainInputs() + const negativePromptInputs = mainInputs + .filter((input) => { + const deps = this.dependantList[input.nodeId] + return deps.some((dep) => dep.inputName === 'negative') + }) + .map((input) => { + input.type = 'negative' + return input + }) + + return negativePromptInputs + } + + /** + * Takes a workflow and converts it to PromptBuilder + */ + createPromptBuilder(): PromptBuilder { + const positivePrompts = this.detectPositivePromptInput() + const negativePrompts = this.detectNegativePromptInput() + const inputs = this.getInputs() + const outputNode = this.getOutputNode() + + const promptBuilder = new PromptBuilder( + this.workflow, + inputs.map((input) => input.name), + ['output'] + ) + + const processed: Record = {} + + positivePrompts.forEach((input) => { + processed['positive'] = true + promptBuilder.setInputNode( + 'positive', + `${input.nodeId}.inputs.${input.name}` + ) + promptBuilder.input('positive', input.value) + }) + + negativePrompts.forEach((input) => { + processed['negative'] = true + promptBuilder.setInputNode( + 'negative', + `${input.nodeId}.inputs.${input.name}` + ) + promptBuilder.input('negative', input.value) + }) + + inputs.forEach((input) => { + promptBuilder.setInputNode(input.name, input.key) + if (!processed[input.key]) { + promptBuilder.input(input.name, input.value) + } + }) + + if (outputNode) { + promptBuilder.setOutputNode('output', outputNode.id) + } + + return promptBuilder + } +}