Skip to content

Commit 792e6d8

Browse files
authored
Merge pull request #14 from darkshapes/x/alt-model-loading
~selective model downloading, dynamic prompt
2 parents 553d1c3 + 7eddcde commit 792e6d8

8 files changed

Lines changed: 317 additions & 276 deletions

File tree

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
Generate:
2-
31
```
42
uvx --from "divisor @ git+https://github.com/darkshapes/divisor" divisor
53
```
64

7-
Develop:
5+
or
86

97
```
108
git clone https://github.com/darkshapes/divisor
119
cd divisor
1210
uv sync --dev
1311
dvzr
1412
```
13+
14+
[![dvzr pytest](https://github.com/darkshapes/divisor/actions/workflows/divisor.yml/badge.svg)](https://github.com/darkshapes/divisor/actions/workflows/divisor.yml)<br>
15+
[<img src="https://img.shields.io/badge/me-__?logo=kofi&logoColor=white&logoSize=auto&label=feed&labelColor=maroon&color=grey&link=https%3A%2F%2Fko-fi.com%2Fdarkshapes">](https://ko-fi.com/darkshapes)<br>

divisor/app.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Routes to different inference modes based on flags.
77
"""
88

9+
import argparse
910
import sys
1011
from fire import Fire
1112

@@ -14,21 +15,51 @@ def main():
1415
"""Main entry point that routes to appropriate inference function.
1516
1617
Usage:
17-
dvzr # Default: Flux image generation mode
18-
dvzr -o / --omni # DiMOO multimodal understanding mode
18+
dvzr # Default: Flux image generation mode
19+
dvzr -o / --omni # DiMOO multimodal understanding mode
20+
dvzr --model-type dev # Use flux1-dev model (default)
21+
dvzr --model-type schnell # Use flux1-schnell model
22+
dvzr -m dev # Short form for model type
1923
"""
20-
# Check for --omni or -o flag (as standalone arguments, not part of other args)
21-
has_omni_flag = any(arg in ["-o", "--omni"] for arg in sys.argv)
24+
parser = argparse.ArgumentParser(description="Divisor CLI - Flux image generation and multimodal understanding")
25+
parser.add_argument(
26+
"-o",
27+
"--omni",
28+
action="store_true",
29+
help="Enable DiMOO multimodal understanding mode",
30+
)
31+
parser.add_argument(
32+
"-m",
33+
"--model-type",
34+
choices=["dev", "schnell"],
35+
default="dev",
36+
help="Model type to use: 'dev' (flux1-dev) or 'schnell' (flux1-schnell). Default: dev",
37+
)
2238

23-
if has_omni_flag:
24-
original_argv = sys.argv.copy()
25-
filtered_argv = [arg for arg in original_argv if arg not in ["-o", "--omni"]]
26-
sys.argv = filtered_argv
39+
# Parse known args to separate our args from Fire's args
40+
args, remaining_argv = parser.parse_known_args()
2741

28-
from divisor.flux_modules.prompt import main as flux_main
42+
if args.omni:
43+
# Remove --omni/-o from argv and route to omni mode
44+
filtered_argv = [arg for arg in sys.argv[1:] if arg not in ["-o", "--omni"]]
45+
sys.argv = [sys.argv[0]] + filtered_argv
46+
# TODO: Import and call omni main function when implemented
47+
# from divisor.omni_modules.prompt import main as omni_main
48+
# Fire(omni_main)
49+
raise NotImplementedError("Omni mode not yet implemented")
50+
else:
51+
# Route to Flux mode
52+
from divisor.flux_modules.prompt import main as flux_main
2953

30-
# Flux uses Fire, which automatically handles sys.argv
31-
Fire(flux_main)
54+
# Add model_id argument to remaining argv for Fire to parse
55+
# Fire converts underscores to hyphens, so model_id becomes --model-id
56+
model_id = f"flux1-{args.model_type}"
57+
# Insert model_id argument before other arguments
58+
remaining_argv = ["--model-id", model_id] + remaining_argv
59+
sys.argv = [sys.argv[0]] + remaining_argv
60+
61+
# Flux uses Fire, which automatically handles sys.argv
62+
Fire(flux_main)
3263

3364

3465
if __name__ == "__main__":

divisor/commands.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ def process_choice(
7373
if state.prompt is not None:
7474
prompt_display = state.prompt[:60] + "..." if len(state.prompt) > 60 else state.prompt
7575
nfo(f"[P]rompt: {prompt_display}")
76-
77-
choice = input(": [BDGLSRVXP] advance with Enter: ").lower().strip()
76+
nfo(f"[E]dit Mode (REPL): {prompt_display}")
7877

7978
choice_handlers = {
8079
"": lambda: (
@@ -84,16 +83,18 @@ def process_choice(
8483
controller.current_state,
8584
),
8685
"g": lambda: change_guidance(controller, state, clear_prediction_cache),
87-
"l": lambda: change_layer_dropout(controller, state, current_layer_dropout, clear_prediction_cache),
88-
"r": lambda: change_resolution(controller, state, clear_prediction_cache),
8986
"s": lambda: change_seed(controller, state, rng, clear_prediction_cache),
87+
"r": lambda: change_resolution(controller, state, clear_prediction_cache),
88+
"l": lambda: change_layer_dropout(controller, state, current_layer_dropout, clear_prediction_cache),
9089
"b": lambda: toggle_buffer_mask(controller, state),
9190
"a": lambda: change_vae_offset(controller, state, ae, clear_prediction_cache),
9291
"v": lambda: change_variation(controller, state, variation_rng, clear_prediction_cache),
9392
"d": lambda: toggle_deterministic(controller, state, clear_prediction_cache),
94-
"e": lambda: edit_mode(clear_prediction_cache),
9593
"p": lambda: change_prompt(controller, state, clear_prediction_cache, recompute_text_embeddings),
94+
"e": lambda: edit_mode(clear_prediction_cache),
9695
}
96+
prompt = "".join(key.upper() for key in choice_handlers if key)
97+
choice = input(f": [{prompt}] or advance with Enter: ").lower().strip()
9798

9899
if choice in choice_handlers:
99100
result = choice_handlers[choice]()

divisor/flux_modules/image_embedder.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,15 @@
1515
SiglipVisionModel,
1616
)
1717

18-
from divisor.flux_modules.util import print_load_warning
18+
from divisor.flux_modules.loading import print_load_warning
1919

2020

2121
class DepthImageEncoder:
2222
depth_model_name = "LiheYoung/depth-anything-large-hf"
2323

2424
def __init__(self, device):
2525
self.device = device
26-
self.depth_model = AutoModelForDepthEstimation.from_pretrained(
27-
self.depth_model_name
28-
).to(device)
26+
self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
2927
self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
3028

3129
def __call__(self, img: torch.Tensor) -> torch.Tensor:
@@ -37,9 +35,7 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
3735
img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
3836
depth = self.depth_model(img.to(self.device)).predicted_depth
3937
depth = repeat(depth, "b h w -> b 3 h w")
40-
depth = torch.nn.functional.interpolate(
41-
depth, hw, mode="bicubic", antialias=True
42-
)
38+
depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
4339

4440
depth = depth / 127.5 - 1.0
4541
return depth
@@ -87,34 +83,24 @@ def __init__(
8783
super().__init__()
8884

8985
self.redux_dim = redux_dim
90-
self.device = (
91-
device if isinstance(device, torch.device) else torch.device(device)
92-
)
86+
self.device = device if isinstance(device, torch.device) else torch.device(device)
9387
self.dtype = dtype
9488

9589
with self.device:
9690
self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
97-
self.redux_down = nn.Linear(
98-
txt_in_features * 3, txt_in_features, dtype=dtype
99-
)
91+
self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
10092

10193
sd = load_sft(redux_path, device=str(device))
10294
missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
10395
print_load_warning(missing, unexpected)
10496

105-
self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(
106-
dtype=dtype
107-
)
97+
self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype)
10898
self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name)
10999

110100
def __call__(self, x: Image.Image) -> torch.Tensor:
111-
imgs = self.normalize.preprocess(
112-
images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True
113-
)
101+
imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
114102

115-
_encoded_x = self.siglip(
116-
**imgs.to(device=self.device, dtype=self.dtype)
117-
).last_hidden_state
103+
_encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state
118104

119105
projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x)))
120106

0 commit comments

Comments
 (0)