support attn mask for l+g/t5
kohya-ss committed Aug 5, 2024
1 parent 231df19 commit da4d0fe
Showing 4 changed files with 107 additions and 24 deletions.
88 changes: 74 additions & 14 deletions library/strategy_sd3.py
@@ -37,34 +37,51 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")

l_attn_mask = l_tokens["attention_mask"]
g_attn_mask = g_tokens["attention_mask"]
t5_attn_mask = t5_tokens["attention_mask"]
l_tokens = l_tokens["input_ids"]
g_tokens = g_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]

return [l_tokens, g_tokens, t5_tokens]
return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]


class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass

def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
) -> List[torch.Tensor]:
"""
returned embeddings are not masked
"""
clip_l, clip_g, t5xxl = models

l_tokens, g_tokens, t5_tokens = tokens
l_tokens, g_tokens, t5_tokens = tokens[:3]
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None]
if l_tokens is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
else:
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
l_out, l_pooled = clip_l(l_tokens)
g_out, g_pooled = clip_g(g_tokens)
if apply_lg_attn_mask:
l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1)
g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1)
lg_out = torch.cat([l_out, g_out], dim=-1)

if t5xxl is not None and t5_tokens is not None:
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
if apply_t5_attn_mask:
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
else:
t5_out = None

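The new apply_lg_attn_mask / apply_t5_attn_mask options reduce to a broadcast multiply: padded token positions are zeroed in the encoder output while real tokens pass through unchanged. The following is a minimal, self-contained sketch of that pattern (the shapes and the cut-off at token 10 are illustrative, not taken from this diff):

import torch

batch, seq_len, dim = 2, 77, 768
hidden = torch.randn(batch, seq_len, dim)   # stand-in for a CLIP-L/G or T5-XXL output
attn_mask = torch.ones(batch, seq_len)      # 1 = real token, 0 = padding
attn_mask[:, 10:] = 0                       # pretend everything after token 10 is padding

# unsqueeze(-1) broadcasts the (batch, seq_len) mask over the hidden dimension,
# so padded positions become zero vectors
masked = hidden * attn_mask.unsqueeze(-1)
assert torch.all(masked[:, 10:] == 0)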
@@ -84,50 +101,81 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"

def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask

def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

def is_disk_cached_outputs_expected(self, abs_path: str):
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True

try:
npz = np.load(self.get_outputs_npz_path(abs_path))
if "clip_l" not in npz or "clip_g" not in npz:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "clip_l_pool" not in npz or "clip_g_pool" not in npz:
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
# t5xxl is optional
except Exception as e:
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
logger.error(f"Error loading file: {npz_path}")
raise e

return True

def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray:
l_out = lg_out[..., :768]
g_out = lg_out[..., 768:] # 1280
l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask.
g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask.
return np.concatenate([l_out, g_out], axis=-1)

def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
return t5_out * np.expand_dims(t5_attn_mask, -1)

def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"] if "t5_out" in data else None

if self.apply_lg_attn_mask:
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask)

if self.apply_t5_attn_mask and t5_out is not None:
t5_attn_mask = data["t5_attn_mask"]
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)

return [lg_out, t5_out, lg_pooled]

def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]

clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions)
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens(
tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens]
lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask
)

if lg_out.dtype == torch.bfloat16:
@@ -148,10 +196,22 @@ def cache_batch_outputs(
lg_pooled_i = lg_pooled[i]

if self.cache_to_disk:
clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6]
clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy()
clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy()
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None
kwargs = {}
if t5_out is not None:
kwargs["t5_out"] = t5_out_i
np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs)
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
clip_l_attn_mask=clip_l_attn_mask_i,
clip_g_attn_mask=clip_g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
**kwargs,
)
else:
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)

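For reference, the disk-cache round trip that the new mask_lg_attn path implements can be reproduced with plain NumPy: the .npz stores lg_out together with the CLIP-L/G attention masks, and loading optionally splits the 768/1280 halves, zeroes the padded rows, and re-concatenates. A small sketch under assumed shapes (the file name and the cut-off position are made up for illustration):

import numpy as np

lg_out = np.random.randn(77, 768 + 1280).astype(np.float32)  # concatenated CLIP-L (768) + CLIP-G (1280)
l_mask = np.ones(77, dtype=np.float32)
g_mask = np.ones(77, dtype=np.float32)
l_mask[10:] = 0
g_mask[10:] = 0

np.savez("example_sd3_te.npz", lg_out=lg_out, clip_l_attn_mask=l_mask, clip_g_attn_mask=g_mask)

data = np.load("example_sd3_te.npz")
out = data["lg_out"]
# split back into the CLIP-L and CLIP-G halves, mask each, then re-concatenate
l_out = out[..., :768] * np.expand_dims(data["clip_l_attn_mask"], -1)
g_out = out[..., 768:] * np.expand_dims(data["clip_g_attn_mask"], -1)
masked = np.concatenate([l_out, g_out], axis=-1)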
3 changes: 2 additions & 1 deletion library/train_util.py
@@ -646,7 +646,7 @@ def __init__(

# caching
self.caching_mode = None # None, 'latents', 'text'

self.tokenize_strategy = None
self.text_encoder_output_caching_strategy = None
self.latents_caching_strategy = None
@@ -1486,6 +1486,7 @@ def __getitem__(self, index):
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
image_info.text_encoder_outputs_npz
)
text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs]
else:
tokenization_required = True
text_encoder_outputs_list.append(text_encoder_outputs)
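The single functional line added here converts the cached NumPy arrays back to torch tensors when the dataset serves them. A trivial sketch with illustrative shapes (2048 = 768 + 1280 for the concatenated CLIP embeddings):

import numpy as np
import torch

cached = [
    np.zeros((77, 2048), dtype=np.float32),   # lg_out
    np.zeros((256, 4096), dtype=np.float32),  # t5_out
    np.zeros((2048,), dtype=np.float32),      # lg_pooled
]
tensors = [torch.FloatTensor(x) for x in cached]  # same conversion as in __getitem__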
10 changes: 6 additions & 4 deletions sd3_minimal_inference.py
@@ -146,6 +146,8 @@ def do_sample(
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
parser.add_argument("--apply_lg_attn_mask", action="store_true")
parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
parser.add_argument("--negative_prompt", type=str, default="")
@@ -323,15 +325,15 @@ def do_sample(
logger.info("Encoding prompts...")
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()

l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt)
tokens_and_masks = tokenize_strategy.tokenize(args.prompt)
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt)
tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt)
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

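Taken together, the inference-side change is a calling-convention change: tokenize() now returns a six-element list (three token tensors followed by three attention masks), and that whole list plus the two new flags is forwarded to encode_tokens(). A hedged sketch of the resulting usage, based only on the calls visible above (encode_prompt is a made-up wrapper, not a function in the repository):

def encode_prompt(tokenize_strategy, encoding_strategy, text_encoders, prompt,
                  apply_lg_attn_mask=False, apply_t5_attn_mask=False):
    # tokens_and_masks = [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]
    tokens_and_masks = tokenize_strategy.tokenize(prompt)
    lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
        tokenize_strategy, text_encoders, tokens_and_masks,
        apply_lg_attn_mask, apply_t5_attn_mask,
    )
    return encoding_strategy.concat_encodings(lg_out, t5_out, pooled)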
30 changes: 25 additions & 5 deletions sd3_train.py
@@ -172,6 +172,8 @@ def train(args):
args.text_encoder_batch_size,
False,
False,
False,
False,
)
)
train_dataset_group.set_current_strategies()
@@ -312,6 +314,8 @@ def train(args):
args.text_encoder_batch_size,
False,
train_clip_g or train_clip_l or args.use_t5xxl_cache_only,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

@@ -335,7 +339,11 @@
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_list = sd3_tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list
sd3_tokenize_strategy,
[clip_l, clip_g, t5xxl],
tokens_list,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
)

accelerator.wait_for_everyone()
@@ -748,21 +756,23 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):

if lg_out is None or (train_clip_l or train_clip_g):
# not cached or training, so get from text encoders
input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"]
input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"]
with torch.set_grad_enabled(args.train_text_encoder):
# TODO support weighted captions
input_ids_clip_l = input_ids_clip_l.to(accelerator.device)
input_ids_clip_g = input_ids_clip_g.to(accelerator.device)
lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None]
sd3_tokenize_strategy,
[clip_l, clip_g, None],
[input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None],
)

if t5_out is None:
_, _, input_ids_t5xxl = batch["input_ids_list"]
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
with torch.no_grad():
input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None
_, t5_out, _ = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl]
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
)

context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)
@@ -969,6 +979,16 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256",
)
parser.add_argument(
"--apply_lg_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する",
)
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
)

# TE training is disabled temporarily
# parser.add_argument(
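The two new argparse options are plain store_true flags, so they default to False and are enabled simply by passing them on the command line. A quick self-contained check of that behaviour (the parser here is illustrative and mirrors only the two flags added above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--apply_lg_attn_mask", action="store_true")
parser.add_argument("--apply_t5_attn_mask", action="store_true")

args = parser.parse_args(["--apply_lg_attn_mask"])
assert args.apply_lg_attn_mask is True
assert args.apply_t5_attn_mask is False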
