Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Version = 0.0.5_2025-22-12
| Creepy Vibes... | Unacceptable. Words and flirts CAN hurt. End coercion. |
| Users vs Developers | Everyone involved, anywhere. Skill DIVERSITY, not division. |

\*More behavior guidelines
https://www.recurse.com/social-rules

## Constructive Criticism Guide:

- Ask consent first. Don't forget to wait for the answer!
Expand Down
7 changes: 0 additions & 7 deletions divisor/acestep/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +0,0 @@
"""
ACE-Step: A Step Towards Music Generation Foundation Model

https://github.com/ace-step/ACE-Step

Apache 2.0 License
"""
18 changes: 6 additions & 12 deletions divisor/acestep/apg_guidance.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# code from https://github.com/ace-step/ACE-Step

import torch


Expand Down Expand Up @@ -25,9 +28,7 @@ def project(
v1 = torch.nn.functional.normalize(v1, dim=dims)
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to(
device_type
)
return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to(device_type)


def apg_forward(
Expand Down Expand Up @@ -67,15 +68,10 @@ def cfg_double_condition_forward(
guidance_scale_text,
guidance_scale_lyric,
):
return (
(1 - guidance_scale_text) * uncond_output
+ (guidance_scale_text - guidance_scale_lyric) * only_text_cond_output
+ guidance_scale_lyric * cond_output
)
return (1 - guidance_scale_text) * uncond_output + (guidance_scale_text - guidance_scale_lyric) * only_text_cond_output + guidance_scale_lyric * cond_output


def optimized_scale(positive_flat, negative_flat):

# Calculate dot product
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)

Expand Down Expand Up @@ -104,7 +100,5 @@ def cfg_zero_star(
if (i <= zero_steps) and use_zero_init:
noise_pred = noise_pred_with_cond * 0.0
else:
noise_pred = noise_pred_uncond * alpha + guidance_scale * (
noise_pred_with_cond - noise_pred_uncond * alpha
)
noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_with_cond - noise_pred_uncond * alpha)
return noise_pred
25 changes: 15 additions & 10 deletions divisor/acestep/cpu_offload.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# adapted from https://github.com/ace-step/ACE-Step

import torch
import functools
from typing import Callable, TypeVar
from divisor.registry import gfx_sync, empty_cache


class CpuOffloader:
def __init__(self, model, device="cpu"):
self.model = model
self.original_device = device
self.original_dtype = model.dtype

def __enter__(self):
if not hasattr(self.model,"torchao_quantized"):
if not hasattr(self.model, "torchao_quantized"):
self.model.to(self.original_device, dtype=self.original_dtype)
return self.model

def __exit__(self, *args):
if not hasattr(self.model,"torchao_quantized"):
if not hasattr(self.model, "torchao_quantized"):
self.model.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
gfx_sync
empty_cache


T = TypeVar("T")

T = TypeVar('T')

def cpu_offload(model_attr: str):
def decorator(func: Callable[..., T]) -> Callable[..., T]:
Expand All @@ -35,9 +39,10 @@ def wrapper(self, *args, **kwargs):
device = self.device
# Get the model from the class attribute
model = getattr(self, model_attr)

with CpuOffloader(model, device):
return func(self, *args, **kwargs)

return wrapper

return decorator
8 changes: 2 additions & 6 deletions divisor/acestep/gradio.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
"""
ACE-Step: A Step Towards Music Generation Foundation Model
# SPDX-License-Identifier: Apache-2.0
# adapted from https://github.com/ace-step/ACE-Step

https://github.com/ace-step/ACE-Step

Apache 2.0 License
"""

import os

Expand Down
25 changes: 12 additions & 13 deletions divisor/acestep/pipeline_ace_step.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# adapted from https://github.com/ace-step/ACE-Step

"""
ACE-Step: A Step Towards Music Generation Foundation Model

https://github.com/ace-step/ACE-Step

Apache 2.0 License
"""

import json
Expand All @@ -14,16 +13,15 @@
import time
from typing import Literal

import torch
import torchaudio
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
retrieve_timesteps,
)
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from diffusers.utils.torch_utils import randn_tensor
from huggingface_hub import snapshot_download
from nnll.console import nfo
from nnll.init_gpu import clear_cache, device
import torch
import torchaudio
from tqdm import tqdm
from transformers import AutoTokenizer, UMT5EncoderModel

Expand All @@ -46,14 +44,15 @@
FlowMatchHeunDiscreteScheduler,
)
from divisor.acestep.schedulers.scheduling_flow_match_pingpong import FlowMatchPingPongScheduler
from divisor.registry import gfx_device, empty_cache

if device.type == "cuda":
if gfx_device.type == "cuda":
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.deterministic = True
torch.backends.cuda.matmul.allow_tf32 = False
os.environ["TOKENIZERS_PARALLELISM"] = "false"
elif device.type == "mps":
elif gfx_device.type == "mps":
os.environ["DYLD_FALLBACK_LIBRARY_PATH"] = "/opt/homebrew/lib"

SUPPORT_LANGUAGES = {
Expand Down Expand Up @@ -116,13 +115,13 @@ def __init__(
self.lora_path = "none"
self.lora_weight = 1
self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
if device.type == "mps":
if gfx_device.type == "mps":
if self.dtype == torch.bfloat16:
self.dtype = torch.float16

if "ACE_PIPELINE_DTYPE" in os.environ and len(os.environ["ACE_PIPELINE_DTYPE"]):
self.dtype = getattr(torch, os.environ["ACE_PIPELINE_DTYPE"])
self.device = device
self.device: torch.device = gfx_device
self.loaded = False
self.torch_compile = torch_compile
self.cpu_offload = cpu_offload
Expand Down Expand Up @@ -205,8 +204,8 @@ def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False):
if self.torch_compile:
if export_quantized_weights:
from torch.ao.quantization import (
quantize_,
Int4WeightOnlyConfig,
quantize_,
)

group_size = 128
Expand Down Expand Up @@ -1475,7 +1474,7 @@ def __call__(
)

# Clean up memory after generation
clear_cache()
empty_cache

end_time = time.time()
latent2audio_time_cost = end_time - start_time
Expand Down
10 changes: 5 additions & 5 deletions divisor/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@
Allows users to manually increment through timesteps one at a time.
"""

from dataclasses import asdict
import json
from dataclasses import asdict
from typing import Any, Callable, Optional

import torch
from nnll.console import nfo
from nnll.hyperchain import HyperChain
from nnll.init_gpu import device
from nnll.random import RNGState
import torch

from divisor.interaction_context import InteractionContext
from divisor.registry import gfx_device
from divisor.state import MenuState, StepState

rng = RNGState(device=device.type)
variation_rng = RNGState(device=device.type)
rng = RNGState(device=gfx_device.type)
variation_rng = RNGState(device=gfx_device.type)


def time_shift(
Expand Down
4 changes: 2 additions & 2 deletions divisor/denoise_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Callable, Optional

from einops import repeat
from nnll.init_gpu import device
from divisor.registry import gfx_device
import torch
from torch import Tensor

Expand Down Expand Up @@ -193,7 +193,7 @@ def get_prediction(
except (TypeError, StopIteration, AttributeError):
# Fallback: use sample dtype if we can't get model dtype (for Mock objects in tests)
model_dtype = sample.dtype
use_autocast = device.type == "cuda"
use_autocast = gfx_device.type == "cuda"

# Ensure sample is in correct dtype before any operations
if not use_autocast:
Expand Down
4 changes: 2 additions & 2 deletions divisor/dimoo/inference_mmu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
import sys
import time

from huggingface_hub import snapshot_download
from nnll.init_gpu import device
import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from divisor.dimoo.config import SPECIAL_TOKENS
from divisor.dimoo.prompt_utils import generate_text_prompt
from divisor.dimoo.text_understanding_generator import generate_text_understanding
from divisor.mmada.modeling_llada import LLaDAModelLM
from divisor.registry import gfx_device

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

Expand Down
5 changes: 2 additions & 3 deletions divisor/dimoo/text_understanding_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

from typing import Optional

from nnll.init_gpu import device
import numpy as np
import torch
import torch.nn.functional as F

from divisor.contents import get_dtype
from divisor.registry import gfx_dtype, gfx_device
from divisor.mmada.live_token import get_num_transfer_tokens
from divisor.noise import add_gumbel_noise

Expand Down Expand Up @@ -46,7 +45,7 @@ def generate_text_understanding(
code_start: Prediction text token start index
"""
device = next(model.parameters()).device or device
precision = get_dtype(device)
precision = gfx_dtype
x = prompt

prompt_index = x != mask_id
Expand Down
31 changes: 17 additions & 14 deletions divisor/flux1/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,29 @@
import os
from pathlib import Path

import torch
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from huggingface_hub import snapshot_download
from nnll.console import nfo
from nnll.init_gpu import device
from safetensors.torch import load_file as load_sft
import torch

from divisor.flux1.autoencoder import AutoEncoder as AutoEncoder1, AutoEncoderParams
from divisor.flux1.autoencoder import AutoEncoder as AutoEncoder1
from divisor.flux1.autoencoder import AutoEncoderParams
from divisor.flux1.model import Flux, FluxLoraWrapper
from divisor.flux1.text_embedder import HFEmbedder
from divisor.flux2.autoencoder import (
AutoEncoder as AutoEncoder2,
)
from divisor.flux2.autoencoder import (
AutoEncoderParams as AutoEncoder2Params,
)
from divisor.flux2.model import Flux2, Flux2Params
from divisor.flux2.text_encoder import Mistral3SmallEmbedder
from divisor.spec import ModelSpec, CompatibilitySpec, optionally_expand_state_dict
from divisor.mmada.modeling_mmada import MMadaConfig as MMaDAParams
from divisor.mmada.modeling_mmada import MMadaModelLM as MMaDAModelLM
from divisor.registry import gfx_device, gfx_dtype
from divisor.spec import ModelSpec, optionally_expand_state_dict
from divisor.xflux1.model import XFlux, XFluxParams
from divisor.contents import get_dtype
from divisor.mmada.modeling_mmada import MMadaConfig as MMaDAParams, MMadaModelLM as MMaDAModelLM


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
Expand Down Expand Up @@ -112,7 +115,7 @@ def load_lora_weights(
model: FluxLoraWrapper,
lora_repo_id: str,
lora_filename: str,
device: str | torch.device = device,
device: str | torch.device = gfx_device,
verbose: bool = True,
) -> None:
"""Load LoRA weights into a FluxLoraWrapper model.\n
Expand All @@ -131,7 +134,7 @@ def load_lora_weights(

def load_flow_model(
model_spec: ModelSpec,
device: torch.device = device,
device: torch.device = gfx_device,
verbose: bool = True,
lora_repo_id: str | None = None,
lora_filename: str | None = None,
Expand Down Expand Up @@ -173,7 +176,7 @@ def load_flow_model(

def load_ae(
model_spec: ModelSpec,
device: torch.device = device,
device: torch.device = gfx_device,
) -> AutoEncoder1 | AutoEncoder2 | AutoencoderTiny:
"""Load the autoencoder model.\n
:param mir_id: Model ID (e.g., "model.vae.flux1-dev" or "model.taesd.flux1-dev")
Expand Down Expand Up @@ -203,7 +206,7 @@ def load_ae(

def load_mmada_model(
model_spec: ModelSpec,
device: torch.device = device,
device: torch.device = gfx_device,
) -> MMaDAModelLM:
"""Load a MMaDA model\n
:param model_spec: ModelSpec object containing model details
Expand All @@ -213,7 +216,7 @@ def load_mmada_model(
:returns: Loaded MMaDA model
:raises: TypeError if model_spec.params is not a MMaDAParams
"""
precision = get_dtype(device)
precision = gfx_dtype
if isinstance(model_spec.params, MMaDAParams):
model_spec.params.llm_model_path = model_spec.repo_id
model = MMaDAModelLM.from_pretrained(model_spec.repo_id, dtype=precision) # type: ignore
Expand All @@ -224,14 +227,14 @@ def load_mmada_model(
raise TypeError(f"MMaDA params not found for: {model_spec.repo_id} with params type {type(model_spec.params).__name__}")


def load_t5(device: str | torch.device = device, max_length: int = 512) -> HFEmbedder:
def load_t5(device: str | torch.device = gfx_device, max_length: int = 512) -> HFEmbedder:
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = device) -> HFEmbedder:
def load_clip(device: str | torch.device = gfx_device) -> HFEmbedder:
return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=torch.bfloat16).to(device)


def load_mistral_small_embedder(device: str | torch.device = device) -> Mistral3SmallEmbedder:
def load_mistral_small_embedder(device: str | torch.device = gfx_device) -> Mistral3SmallEmbedder:
return Mistral3SmallEmbedder().to(device)
Loading