Skip to content

Commit a238aa8

Browse files
committed
instantid
1 parent 174f5d4 commit a238aa8

File tree

9 files changed

+1617
-2
lines changed

9 files changed

+1617
-2
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ projects/*
164164
!projects/.gitkeep
165165
models/*
166166
!models/.gitkeep
167+
checkpoints/*
168+
!checkpoints/.gitkeep
167169
uploads/*
168170
!uploads/.gitkeep
169171
.DS_Store

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.PHONY: start
22
start:
3-
poetry run python main.py
3+
poetry run uvicorn app.main:app --port 9000 --workers 4
44

55
.PHONY: dev
66
dev:

app/brain.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from app.llms.tools.dalle import DalleImage
2121
from app.llms.tools.describeimage import DescribeImage
2222
from app.llms.tools.drawimage import DrawImage
23+
from app.llms.tools.instantid import InstantID
2324
from app.llms.tools.refineimage import RefineImage
2425
from app.llms.tools.stablediffusion import StableDiffusionImage
2526
from app.model import Model
@@ -375,6 +376,7 @@ def entryVision(self, projectName, visionInput, isprivate, db: Session):
375376
RefineImage(),
376377
DrawImage(),
377378
DescribeImage(),
379+
InstantID(),
378380
]
379381

380382
if isprivate:

app/llms/instantid/pipeline_stable_diffusion_xl_instantid.py

Lines changed: 771 additions & 0 deletions
Large diffs are not rendered by default.

app/llms/instantid/worker.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import base64
2+
import io
3+
from diffusers.models import ControlNetModel
4+
from huggingface_hub import hf_hub_download
5+
6+
import cv2
7+
import torch
8+
import numpy as np
9+
import random
10+
from PIL import Image
11+
from insightface.app import FaceAnalysis
12+
13+
from app.llms.instantid.pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
14+
15+
def instantid_worker(prompt, sharedmem):
    """Generate an InstantID avatar and publish it through ``sharedmem``.

    Decodes the base64 image stored under ``sharedmem["input_image"]``, picks
    the largest detected face, runs the InstantID SDXL pipeline conditioned on
    that face and on ``prompt``, and stores the result as a base64-encoded
    JPEG string under ``sharedmem["output_image"]``.

    Args:
        prompt: Text prompt; a fixed quality suffix is appended to it.
        sharedmem: Dict-like object shared with the caller. Must contain
            ``"input_image"`` (base64 text); receives ``"output_image"``.

    Raises:
        ValueError: If no face is detected in the input image.
    """
    import logging

    checkpoint_files = (
        "ControlNetModel/config.json",
        "ControlNetModel/diffusion_pytorch_model.safetensors",
        "ip-adapter.bin",
    )
    try:
        for filename in checkpoint_files:
            hf_hub_download(repo_id="InstantX/InstantID", filename=filename, local_dir="./checkpoints")
    except Exception:
        # Best effort: previously downloaded checkpoints may still be on disk,
        # so keep going, but leave a trace instead of failing silently.
        logging.getLogger(__name__).warning(
            "Could not download InstantID checkpoints; trying local copies", exc_info=True
        )

    img_data = base64.b64decode(sharedmem["input_image"])
    # Normalise to 3-channel RGB: RGBA / greyscale / palette images would
    # otherwise break the RGB->BGR conversion and the (h, w, c) unpacking below.
    face_image = Image.open(io.BytesIO(img_data)).convert("RGB")

    prompt_default = ", (detailed) (intricate) (8k) (HDR) (cinematic lighting) (sharp focus)"
    prompt = prompt + prompt_default

    negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green"

    device = "cuda"

    # Detect the face first so an unusable input fails before the (much more
    # expensive) diffusion pipeline is loaded.
    face_analyzer = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    face_analyzer.prepare(ctx_id=0, det_size=(640, 640))

    face_image_cv2 = cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)
    height, width, _ = face_image_cv2.shape

    face_info = _largest_face(face_analyzer.get(face_image_cv2))
    face_emb = face_info['embedding']
    face_kps = draw_kps(face_image, face_info['kps'])

    # White rectangle over the face bounding box on a black canvas.
    x1, y1, x2, y2 = (int(coord) for coord in face_info["bbox"])
    control_mask = np.zeros([height, width, 3])
    control_mask[y1:y2, x1:x2] = 255
    control_mask = Image.fromarray(control_mask.astype(np.uint8))

    face_adapter = './checkpoints/ip-adapter.bin'
    controlnet_path = './checkpoints/ControlNetModel'

    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)

    pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
        "wangqixun/YamerMIX_v8", controlnet=controlnet, torch_dtype=torch.float16
    )
    pipe.to(device)

    pipe.load_ip_adapter_instantid(face_adapter)

    # The LCM LoRA is loaded but immediately disabled, so it does not affect
    # this generation.
    pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
    pipe.disable_lora()

    pipe.set_ip_adapter_scale(0.8)

    generator = torch.Generator(device=device).manual_seed(random.randint(0, np.iinfo(np.int32).max))

    image = pipe(
        prompt,
        image_embeds=face_emb,
        image=face_kps,
        control_mask=control_mask,
        num_inference_steps=50,
        controlnet_conditioning_scale=0.8,
        negative_prompt=negative_prompt,
        generator=generator,
        # NOTE(review): diffusers pipelines normally call this `guidance_scale`;
        # confirm the InstantID pipeline actually consumes `guide_scale`.
        guide_scale=0,
        height=height,
        width=width,
    ).images[0]

    output_img_data = io.BytesIO()
    image.save(output_img_data, format="JPEG")
    image_base64 = base64.b64encode(output_img_data.getvalue()).decode('utf-8')

    sharedmem["output_image"] = image_base64


def _largest_face(faces):
    """Return the detected face whose bounding box has the largest area."""
    if not faces:
        raise ValueError("No face detected in the input image")

    def bbox_area(face):
        box = face['bbox']
        # width * height; both differences must be parenthesised.
        return (box[2] - box[0]) * (box[3] - box[1])

    return max(faces, key=bbox_area)

app/llms/tools/instantid.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from torch.multiprocessing import Process, set_start_method, Manager
2+
3+
from app.llms.instantid.worker import instantid_worker
4+
try:
5+
set_start_method('spawn')
6+
except RuntimeError:
7+
pass
8+
from langchain.tools import BaseTool
9+
from langchain.chains import LLMChain
10+
from langchain_community.chat_models import ChatOpenAI
11+
from langchain.prompts import PromptTemplate
12+
13+
from typing import Optional
14+
from langchain.callbacks.manager import (
15+
CallbackManagerForToolRun,
16+
)
17+
18+
19+
class InstantID(BaseTool):
    """LangChain tool that draws an avatar from an input image and a description.

    The generation itself is delegated to ``instantid_worker``, which is run in
    a separate process and exchanges data through a manager-backed dict.
    """

    name = "Avatar Generator"
    description = "use this tool when you need to draw an avatar from an image and a descripton."
    return_direct = True

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> dict:
        """Generate an avatar and return it together with the prompt used.

        Args:
            query: Image description supplied by the agent; only used to build
                the prompt when boosting is enabled.
            run_manager: Callback manager whose first tag carries the request;
                it is read for its ``boost``, ``question`` and ``image``
                attributes.

        Returns:
            A dict with keys ``"type"`` (always ``"instantid"``), ``"image"``
            (base64 string produced by the worker) and ``"prompt"``.

        Raises:
            RuntimeError: If the worker process exits without producing an
                image.
        """
        request = run_manager.tags[0]

        if request.boost == True:
            # Let the LLM expand the short description into a detailed prompt.
            llm = ChatOpenAI(temperature=0.9, model_name="gpt-3.5-turbo")
            prompt = PromptTemplate(
                input_variables=["image_desc"],
                template="Generate a detailed prompt to generate an image based on the following description: {image_desc}",
            )
            chain = LLMChain(llm=llm, prompt=prompt)

            fprompt = chain.run(query)
        else:
            fprompt = request.question

        # The context manager shuts the manager's server process down when we
        # are done; without it every call leaks one helper process.
        with Manager() as manager:
            sharedmem = manager.dict()
            sharedmem["input_image"] = request.image

            p = Process(target=instantid_worker, args=(fprompt, sharedmem))
            p.start()
            p.join()

            # Read the result while the manager is still alive.
            output_image = sharedmem.get("output_image")

        if output_image is None:
            # The worker crashed or bailed out (e.g. no face found); report that
            # explicitly instead of an opaque KeyError on the shared dict.
            raise RuntimeError(
                f"InstantID worker produced no image (exit code {p.exitcode})"
            )

        return {"type": "instantid", "image": output_image, "prompt": fprompt}

    async def _arun(self, query: str) -> str:
        """Asynchronous execution is not supported by this tool."""
        raise NotImplementedError("N/A")

checkpoints/.gitkeep

Whitespace-only changes.

0 commit comments

Comments
 (0)