Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions nnll/embeds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# SPDX-License-Identifier:Apache-2.0
# original BFL Flux code from https://github.com/black-forest-labs/flux

import torch
from torch import Tensor, nn
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel, T5TokenizerFast, T5EncoderModel


class HFEmbedder(nn.Module):
    """Thin wrapper around a Hugging Face text encoder (CLIP or T5).

    Chooses the tokenizer/encoder pair from the checkpoint name, freezes the
    weights, and exposes a single ``forward`` that maps a list of prompts to
    a bfloat16 embedding tensor.
    """

    def __init__(self, version: str, max_length: int, **hf_kwargs):
        """:param version: Hugging Face checkpoint id; "openai…" selects CLIP, anything else T5
        :param max_length: token budget used for truncation/padding
        :param hf_kwargs: forwarded verbatim to ``from_pretrained`` of the encoder
        """
        super().__init__()
        # CLIP checkpoints are published under the "openai" org — everything else is treated as T5.
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        # CLIP is read via its pooled sentence vector; T5 via per-token hidden states.
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        tokenizer_cls, encoder_cls = (
            (CLIPTokenizer, CLIPTextModel) if self.is_clip else (T5TokenizerFast, T5EncoderModel)
        )
        self.tokenizer = tokenizer_cls.from_pretrained(version, max_length=max_length)
        self.hf_module = encoder_cls.from_pretrained(version, **hf_kwargs)

        # Inference only: eval mode and no gradients.
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        """Tokenize *text* (pad/truncate to ``max_length``) and return bfloat16 embeddings."""
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        output = self.hf_module(
            input_ids=encoded["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return output[self.output_key].bfloat16()


def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)

img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)

img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)

vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}


def encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length=512,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
    return_index=-1,
):
    """Encode *prompt* with a T5 text encoder and tile the result per image.

    :param text_encoder: T5 encoder; called with input ids, must expose
        ``.dtype`` and (for intermediate layers) ``.encoder.final_layer_norm``
        and ``.encoder.dropout``
    :param tokenizer: Hugging Face tokenizer; only used when *text_input_ids*
        is not supplied
    :param max_sequence_length: pad/truncate budget for tokenization
    :param prompt: a single prompt string or a list of prompts
    :param num_images_per_prompt: how many images each prompt will generate;
        embeddings are duplicated accordingly
    :param device: target device for the input ids and returned embeddings
    :param text_input_ids: pre-tokenized input ids; when given, tokenization
        is skipped entirely
    :param return_index: index into ``hidden_states``; ``-1`` is the final
        (already layer-normed) output, any other index is re-normalized
    :return: embeddings of shape (batch * num_images_per_prompt, seq_len, dim)
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Fix: the original overwrote text_input_ids unconditionally, silently
    # discarding a caller-supplied value. Tokenize only when none was given.
    if text_input_ids is None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    batch_size = len(prompt) if prompt is not None else text_input_ids.shape[0]

    outputs = text_encoder(text_input_ids.to(device), return_dict=True, output_hidden_states=True)

    prompt_embeds = outputs.hidden_states[return_index]
    if return_index != -1:
        # hidden_states[-1] is already passed through the final layer norm;
        # intermediate layers are not, so apply norm + dropout to match.
        prompt_embeds = text_encoder.encoder.final_layer_norm(prompt_embeds)
        prompt_embeds = text_encoder.encoder.dropout(prompt_embeds)

    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # Duplicate text embeddings once per image generated from each prompt.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
57 changes: 1 addition & 56 deletions nnll/init_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
# <!-- // /* d a r k s h a p e s */ -->

# pylint: disable=import-outside-toplevel
from nnll.console import nfo

from typing import Callable, Union, Literal, Optional, Any
from typing import Literal, Optional, Any


def soft_random(size: int = 0x100000000) -> int: # previously 0x2540BE3FF
Expand Down Expand Up @@ -35,11 +34,6 @@ def hard_random(hardness: int = 5) -> int:
return int(secrets.token_hex(hardness), 16) # make hex secret be int


#
# def c(dtype: str) -> torch.dtype:
# return {}.get(dtype)


def random_int_from_gpu(input_seed: int = soft_random()) -> int:
"""
Generate a random number via pytorch
Expand Down Expand Up @@ -118,52 +112,3 @@ def clear_cache(device_override: Optional[Literal["cuda", "mps", "cpu"]] = None)
torch.cuda.empty_cache()
if device.type == "mps" or device_override == "mps":
torch.mps.empty_cache()


# def first_available(processor: str = None, assign: bool = True, clean: bool = False, init: bool = True) -> Union[Callable, str]:
# """Return first available processor of the highest capacity\n
# :param processor: Name of an existing processing device, defaults to None (autodetect)
# :param assign: Direct torch to use the detected device, defaults to True
# :param clean: Clear any previous cache, defaults to False
# :param init: Initialize the device with a test tensor and discard, defaults to True\n
# :return: The torch device handler, or the name of the processor

# """
# from functools import reduce

# import torch

# torch.set_num_threads(1)
# if not processor:
# processor = reduce(
# lambda acc, check: check() if acc == "cpu" else acc,
# [
# lambda: "cuda" if torch.cuda.is_available() else "cpu",
# lambda: "mps" if torch.backends.mps.is_available() else "cpu",
# lambda: "xpu" if torch.xpu.is_available() else "cpu",
# lambda: "mtia" if torch.mtia.is_available() else "cpu",
# ],
# "cpu",
# )

# if clean:
# import gc

# gc.collect()
# if processor == "cuda":
# torch.cuda.empty_cache()
# if processor == "mps":
# torch.mps.empty_cache()
# if processor == "xpu":
# torch.xpu.empty_cache()
# if processor == "mtia":
# torch.mtia.empty_cache()

# if init:
# tensor = random_tensor_from_gpu(device=processor)
# if tensor:
# del tensor
# tensor = None

# nfo(f"Available torch devices: {processor}")
# return torch.device(processor) if assign else processor
10 changes: 5 additions & 5 deletions nnll/save_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ def name_save_file_as(extension: ExtensionType, save_folder_path=".output") -> P


# do not log here
def write_to_disk(content: Any, metadata: dict[str], extension: str = None, **kwargs) -> None:
def write_to_disk(content: Any, metadata: dict[str], extension: ExtensionType | None = None, **kwargs) -> None:
"""Save Image to File\n
:param content: File data in memory
:param pipe_data: Pipe metadata to write into the file
```
name [ header: type : medium: type ]
# name [ header: type : medium: type ]

,-pipe dict
# ,-pipe dict
# \-model str ,-text string
# \-prompt dict___\-audio array
# `-kwargs dict \-image array
Expand All @@ -56,8 +56,8 @@ def write_to_disk(content: Any, metadata: dict[str], extension: str = None, **kw
import numpy as np
from PIL import Image

if not file_path_absolute.endswith(".webp"):
filename = file_path_absolute.rsplit(".", 1)[0] + ".webp"
if not file_path_absolute.endswith(ExtensionType.WEBP):
filename = file_path_absolute.rsplit(".", 1)[0] + ExtensionType.WEBP

# Convert tensor to PIL Image
if content.dim() == 4:
Expand Down