1515 SiglipVisionModel ,
1616)
1717
18- from divisor .flux_modules .util import print_load_warning
18+ from divisor .flux_modules .loading import print_load_warning
1919
2020
2121class DepthImageEncoder :
2222 depth_model_name = "LiheYoung/depth-anything-large-hf"
2323
2424 def __init__ (self , device ):
2525 self .device = device
26- self .depth_model = AutoModelForDepthEstimation .from_pretrained (
27- self .depth_model_name
28- ).to (device )
26+ self .depth_model = AutoModelForDepthEstimation .from_pretrained (self .depth_model_name ).to (device )
2927 self .processor = AutoProcessor .from_pretrained (self .depth_model_name )
3028
3129 def __call__ (self , img : torch .Tensor ) -> torch .Tensor :
@@ -37,9 +35,7 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
3735 img = self .processor (img_byte , return_tensors = "pt" )["pixel_values" ]
3836 depth = self .depth_model (img .to (self .device )).predicted_depth
3937 depth = repeat (depth , "b h w -> b 3 h w" )
40- depth = torch .nn .functional .interpolate (
41- depth , hw , mode = "bicubic" , antialias = True
42- )
38+ depth = torch .nn .functional .interpolate (depth , hw , mode = "bicubic" , antialias = True )
4339
4440 depth = depth / 127.5 - 1.0
4541 return depth
@@ -87,34 +83,24 @@ def __init__(
8783 super ().__init__ ()
8884
8985 self .redux_dim = redux_dim
90- self .device = (
91- device if isinstance (device , torch .device ) else torch .device (device )
92- )
86+ self .device = device if isinstance (device , torch .device ) else torch .device (device )
9387 self .dtype = dtype
9488
9589 with self .device :
9690 self .redux_up = nn .Linear (redux_dim , txt_in_features * 3 , dtype = dtype )
97- self .redux_down = nn .Linear (
98- txt_in_features * 3 , txt_in_features , dtype = dtype
99- )
91+ self .redux_down = nn .Linear (txt_in_features * 3 , txt_in_features , dtype = dtype )
10092
10193 sd = load_sft (redux_path , device = str (device ))
10294 missing , unexpected = self .load_state_dict (sd , strict = False , assign = True )
10395 print_load_warning (missing , unexpected )
10496
105- self .siglip = SiglipVisionModel .from_pretrained (self .siglip_model_name ).to (
106- dtype = dtype
107- )
97+ self .siglip = SiglipVisionModel .from_pretrained (self .siglip_model_name ).to (dtype = dtype )
10898 self .normalize = SiglipImageProcessor .from_pretrained (self .siglip_model_name )
10999
110100 def __call__ (self , x : Image .Image ) -> torch .Tensor :
111- imgs = self .normalize .preprocess (
112- images = [x ], do_resize = True , return_tensors = "pt" , do_convert_rgb = True
113- )
101+ imgs = self .normalize .preprocess (images = [x ], do_resize = True , return_tensors = "pt" , do_convert_rgb = True )
114102
115- _encoded_x = self .siglip (
116- ** imgs .to (device = self .device , dtype = self .dtype )
117- ).last_hidden_state
103+ _encoded_x = self .siglip (** imgs .to (device = self .device , dtype = self .dtype )).last_hidden_state
118104
119105 projected_x = self .redux_down (nn .functional .silu (self .redux_up (_encoded_x )))
120106
0 commit comments