diff --git a/torchchat/cli/convert_hf_checkpoint.py b/torchchat/cli/convert_hf_checkpoint.py index f90b59c25..7001eb1e1 100644 --- a/torchchat/cli/convert_hf_checkpoint.py +++ b/torchchat/cli/convert_hf_checkpoint.py @@ -7,10 +7,13 @@ import os import re import sys +import glob from pathlib import Path -from typing import Optional +from typing import Any, Dict, Optional import torch +import safetensors.torch +import shutil from torchchat.model import TransformerArgs @@ -21,9 +24,176 @@ from torchchat.model import ModelArgs +def remap_llava_checkpoint(llava_ckpt): + def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]: + translated_state_dict = {} + hf_weight_prefix = "vision_model." + name_mapping = { + f"{hf_weight_prefix}embeddings.class_embedding": "encoder.cls_token_embedding.weight", + f"{hf_weight_prefix}embeddings.position_embedding.weight": "encoder.token_pos_embedding.positional_embedding", + f"{hf_weight_prefix}embeddings.patch_embedding.weight": "encoder.conv.weight", + f"{hf_weight_prefix}pre_layrnorm.weight": "encoder.ln_pre.weight", + f"{hf_weight_prefix}pre_layrnorm.bias": "encoder.ln_pre.bias", + f"{hf_weight_prefix}post_layernorm.weight": "encoder.ln_post.weight", + f"{hf_weight_prefix}post_layernorm.bias": "encoder.ln_post.bias", + } + patterns = [ + ( + rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.(k|q|v)_proj\.(weight|bias)", + lambda match: f"encoder.layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}", + ), + ( + rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.(weight|bias)", + lambda match: f"encoder.layers.{match.group(1)}.attn.output_proj.{match.group(2)}", + ), + ( + rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.mlp\.fc(1|2)\.(weight|bias)", + lambda match: f"encoder.layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}", + ), + ( + rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm1\.(weight|bias)", + lambda match: f"encoder.layers.{match.group(1)}.sa_norm.{match.group(2)}", + ), + ( + rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm2\.(weight|bias)", + lambda match: f"encoder.layers.{match.group(1)}.mlp_norm.{match.group(2)}", + ), + ] + for pattern, replacement in patterns: + for key in list(hf_state_dict.keys()): + if re.match(pattern, key): + new_key = re.sub(pattern, replacement, key) + name_mapping[key] = new_key + temp_state_dict = {} + for k, v in hf_state_dict.items(): + new_k = name_mapping.get(k, k) + if "in_proj_weight" in new_k or "in_proj_bias" in new_k: + if new_k not in temp_state_dict: + temp_state_dict[new_k] = {"q": None, "k": None, "v": None} + if "q_proj" in k: + temp_state_dict[new_k]["q"] = v + elif "k_proj" in k: + temp_state_dict[new_k]["k"] = v + elif "v_proj" in k: + temp_state_dict[new_k]["v"] = v + else: + temp_state_dict[new_k] = v + for k, v in temp_state_dict.items(): + if isinstance(v, dict): + translated_state_dict[k] = torch.cat([v["q"], v["k"], v["v"]], dim=0) + else: + translated_state_dict[k] = v + return translated_state_dict + + def _translate_state_dict_for_text_model(hf_state_dict) -> Dict[str, Any]: + key_map = { + r"model.layers.([0-9]+).self_attn.q_proj.": r"decoder.layers.\1.attention.wq.", + r"model.layers.([0-9]+).self_attn.k_proj.": r"decoder.layers.\1.attention.wk.", + r"model.layers.([0-9]+).self_attn.v_proj.": r"decoder.layers.\1.attention.wv.", + r"model.layers.([0-9]+).self_attn.o_proj.": r"decoder.layers.\1.attention.wo.", + r"model.layers.([0-9]+).input_layernorm.": r"decoder.layers.\1.attention_norm.", + r"model.layers.([0-9]+).mlp.gate_proj.": r"decoder.layers.\1.feed_forward.w1.", + r"model.layers.([0-9]+).mlp.down_proj.": r"decoder.layers.\1.feed_forward.w2.", + r"model.layers.([0-9]+).mlp.up_proj.": r"decoder.layers.\1.feed_forward.w3.", + r"model.layers.([0-9]+).post_attention_layernorm.": r"decoder.layers.\1.ffn_norm.", + r"model.norm.": r"decoder.norm.", + # r"model.embed_tokens.": r"tok_embeddings.", # load separately + r"lm_head.": r"decoder.output.", + } + new_state_dict = {} + def get_new_key(old_key: str) -> str: + for old_pattern, replacement in key_map.items(): + if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key: + return new_key + return old_key + for old_key in hf_state_dict.keys(): + new_key = get_new_key(old_key) + new_state_dict[new_key] = hf_state_dict[old_key] + return new_state_dict + + def _translate_state_dict_for_mm_projector_model(hf_state_dict) -> Dict[str, Any]: + new_state_dict = {} + for old_key in hf_state_dict.keys(): + new_key = "mm_projector." + old_key + new_state_dict[new_key] = hf_state_dict[old_key] + return new_state_dict + + def split_checkpoint(llava_ckpt): + language_model_ckpt = {} + multi_modal_ckpt = {} + vision_tower_ckpt = {} + for key, value in llava_ckpt.items(): + if key.startswith("language_model"): + language_model_ckpt[key[len("language_model") + 1:]] = value + elif key.startswith("multi_modal_projector"): + multi_modal_ckpt[key[len("multi_modal_projector") + 1:]] = value + elif key.startswith("vision_tower"): + vision_tower_ckpt[key[len("vision_tower") + 1:]] = value + return language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt + language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt = split_checkpoint(llava_ckpt) + remapped_state_dict = { + "tok_embeddings.weight": language_model_ckpt.pop("model.embed_tokens.weight"), + } + remapped_state_dict.update(_translate_state_dict_for_text_model(language_model_ckpt)) + remapped_state_dict.update(_translate_state_dict_for_vision_model(vision_tower_ckpt)) + remapped_state_dict.update(_translate_state_dict_for_mm_projector_model(multi_modal_ckpt)) + return remapped_state_dict + + +@torch.inference_mode +def convert_llava_checkpoint( + *, + model_dir: Optional[Path] = None, +) -> None: + + """ + Process safetensor files from a specific directory structure and save the remapped model. + + Args: + model_dir (str): Base directory containing the model subdirectories. + """ + + def _get_llava_files_with_pattern(pattern): + pattern = os.path.join(model_dir, f"models--llava-hf--llava-1.5-7b-hf/snapshots/*/{pattern}") + return glob.glob(pattern) + + # get all safetensor files in the model directory + safetensor_files = _get_llava_files_with_pattern("*.safetensors") + + if not safetensor_files: + raise ValueError("No safetensor files found.") + + merged_weights = {} + + # Merge safetensor files into a whole + for file in safetensor_files: + # Load weights from the current file + part_weights = safetensors.torch.load_file(file) + + # Iterate over each weight in the current file + for key, value in part_weights.items(): + if key in merged_weights: + # If the key already exists, concatenate tensors + merged_weights[key] = torch.cat((merged_weights[key], value), dim=0) + else: + # If the key does not exist, add it to the dictionary + merged_weights[key] = value + + # Remap the checkpoint and save it as pth + remapped_weights = remap_llava_checkpoint(merged_weights) + model_path = model_dir / "model.pth" + torch.save(remapped_weights, model_path) + + # copy tokenizer + tokenizer_files = _get_llava_files_with_pattern("tokenizer.model") + assert len(tokenizer_files) == 1, "Should get only one tokenizer file, but got {}".format(tokenizer_files) + + tokenizer_path = model_dir / "tokenizer.model" + shutil.copy(tokenizer_files[0], tokenizer_path) + @torch.inference_mode() -def convert_hf_checkpoint( +def convert_text_only_hf_checkpoint( *, model_dir: Optional[Path] = None, model_name: Optional[str] = None, @@ -132,6 +302,19 @@ def permute(w, n_heads): os.remove(file) +@torch.inference_mode() +def convert_hf_checkpoint( + *, + model_dir: Optional[Path] = None, + model_name: Optional[str] = None, + remove_bin_files: bool = False, +): + if "llava" in model_name: + convert_llava_checkpoint(model_dir=model_dir) + else: + convert_text_only_hf_checkpoint(model_dir=model_dir, model_name=model_name, remove_bin_files=remove_bin_files) + + if __name__ == "__main__": import argparse diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py index 4a8f43515..8d7ab79c3 100644 --- a/torchchat/cli/download.py +++ b/torchchat/cli/download.py @@ -33,8 +33,9 @@ def _download_hf_snapshot( local_dir=artifact_dir, local_dir_use_symlinks=False, token=hf_token, - ignore_patterns="*safetensors*", + ignore_patterns=None if "llava" in model_config.name else "*safetensors*", ) + except HTTPError as e: if e.response.status_code == 401: # Missing HuggingFace CLI login. print( diff --git a/torchchat/generate.py b/torchchat/generate.py index 9e60f9494..be1cab606 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -36,6 +36,7 @@ from torchchat.model import Model, ModelType from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info +from torchchat.utils.preprocessors import llava_image_preprocess # torchtune model definition dependencies from torchtune.data import Message @@ -357,8 +358,13 @@ def prefill( if batch is not None: # TODO: Verify sequential prefill works with multimodal models - logits = model(**batch)[:, -1] - return tune_sample(logits, 0, 500) + logits = model(**batch) + if model.config.model_type == ModelType.Llava: + context_len, logits = logits[0], logits[1][:, -1] + return context_len, tune_sample(logits, 0, 500) + else: + logits = logits[:, -1] + return tune_sample(logits, 0, 500) elif sequential_prefill: for i in range(width): x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1) @@ -622,6 +628,13 @@ def generate( sequential_prefill=sequential_prefill, **sampling_kwargs, ) + + # For llava with image input, we need to extract next pos id from prefill result + if batch and self.model.config.model_type == ModelType.Llava: + context_len, next_token = next_token + else: + context_len, next_token = T, next_token + if is_speculative: self.prefill( draft_model, @@ -636,7 +649,7 @@ def generate( # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens(). callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2) - input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int) + input_pos = torch.tensor([start_pos + context_len], device=device, dtype=torch.int) accept_counts = [0] * ( speculate_k + 1 ) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long @@ -726,27 +739,56 @@ def chat( if generator_args.image_prompts is not None: print("Image prompts", generator_args.image_prompts) - # Support for just the first image prompt for now images = [Image.open(generator_args.image_prompts[0])] - messages = [ - Message( - role="user", - content=[ - {"type": "image", "content": images[0]}, - {"type": "text", "content": generator_args.prompt}, - ], - eot=True, - ), - Message(role="assistant", content=""), - ] - transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path)) - data = transform({"messages": messages}, inference=True) - batch = padded_collate([data], self.builder_args.device) - batch.pop("mask") - encoded = batch["tokens"] + assert len(images) == 1, "Only one image prompt is supported for now" + + #TODO: updated encoded variable for multi-modality models to include image tokens. + if self.model.config.model_type == ModelType.Flamingo: + messages = [ + Message( + role="user", + content=[ + {"type": "image", "content": images[0]}, + {"type": "text", "content": generator_args.prompt}, + ], + eot=True, + ), + Message(role="assistant", content=""), + ] + transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path)) + data = transform({"messages": messages}, inference=True) + batch = padded_collate([data], self.builder_args.device) + batch.pop("mask") + encoded = batch["tokens"] + elif self.model.config.model_type == ModelType.Llava: + #TODO: double check the tokenizer. + def find_subtensor(tensor, target): + target_len = len(target) + for i in range(len(tensor) - target_len + 1): + if torch.all(tensor[i:i+target_len] == target): + return i + return -1 + + input_ids = self.encode_tokens(generator_args.prompt, bos=True, device=self.builder_args.device) + image_token_indices = self.encode_tokens("", device=self.builder_args.device)[1:] + index = find_subtensor(input_ids, image_token_indices) + + if index == -1: + raise ValueError("Image token not found in prompt") + + batch = { + "tokens": input_ids[:index].unsqueeze(0), + "encoder_input": llava_image_preprocess(images[0], device=self.builder_args.device, dtype=self.builder_args.precision), + "post_tokens": input_ids[index + len(image_token_indices) :].unsqueeze(0), + } + + # can not get actual encoded image feature before model inference; pseudo one + pseudo_vision_encoded = torch.zeros(1, 624).to(device=self.builder_args.device, dtype=self.builder_args.precision) + encoded = torch.cat([batch["tokens"].view(1, -1), pseudo_vision_encoded, batch["post_tokens"].view(1, -1)], dim=-1).view(-1) + else: encoded = self.encode_tokens( generator_args.prompt, bos=True, device=self.builder_args.device diff --git a/torchchat/model.py b/torchchat/model.py index 3300ebee9..e14b0d2b8 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -14,7 +14,7 @@ import torchvision -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from collections.abc import Hashable import torch @@ -56,7 +56,6 @@ def identity(**kwargs): return list(kwargs.values())[0] - class MultiModalProjector(nn.Module): def __init__(self, in_channels: int, out_channels: int, act: nn.Module): super().__init__() @@ -126,7 +125,10 @@ def forward( dtype=torch.int, ) - return self.decoder(decoder_input, input_pos=input_pos) + return decoder_input.shape[1], self.decoder(decoder_input, input_pos=input_pos) + else: + return self.decoder(decoder_input, input_pos=input_pos) + def setup_caches(self, batch_size, max_seq_len) -> None: self.decoder.setup_caches(batch_size, max_seq_len) @@ -262,6 +264,7 @@ class TransformerArgs: use_tiktoken: bool = False max_seq_length: int = 8192 rope_scaling: Optional[Dict[str, Any]] = None + use_hf_rope: bool = False # For pipeline parallel n_stages: int = 1 stage_idx: int = 0 @@ -413,7 +416,6 @@ def __init__( # print(f"dtype on entry {dtype}") if not dtype: dtype = get_precision() - # print(f"dtype on get_prec {dtype}") cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) @@ -553,13 +555,16 @@ def reset_caches(self): class LlavaModel(Model): + def __init__(self, config: ModelArgs) -> None: + super().__init__(config) + self.text_transformer_args = self.model.decoder.config + def forward( self, tokens: Tensor, - *, + input_pos: Optional[Tensor] = None, encoder_input: Optional[Dict[str, Tensor]] = None, post_tokens: Optional[Tensor] = None, - input_pos: Optional[Tensor] = None, ) -> Tensor: return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos) @@ -605,6 +610,13 @@ def __init__(self, config: TransformerArgs) -> None: self.max_batch_size = -1 self.max_seq_length = -1 + # For supporting sequence parallel (default is off, thus value of 1) + self.seq_parallel_degree = 1 + if config.use_hf_rope: + self.precompute_freqs_cis = hf_precompute_freqs_cis + else: + self.precompute_freqs_cis = precompute_freqs_cis + def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1): if ( @@ -623,7 +635,7 @@ def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1): max_batch_size, max_seq_length, cache_lanes=cache_lanes ) - freqs_cis = precompute_freqs_cis( + freqs_cis = self.precompute_freqs_cis( self.config.dim // self.config.n_heads, self.config.block_size * 2, self.config.rope_base, @@ -657,8 +669,10 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int assert self.freqs_cis is not None, "Caches must be initialized first" mask = self.causal_mask[None, None, input_pos] freqs_cis = self.freqs_cis[input_pos] + if self.tok_embeddings: x = self.tok_embeddings(x) + for _, layer in self.layers.items(): x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane) @@ -715,6 +729,10 @@ def __init__(self, config: TransformerArgs): self.n_local_heads = config.n_local_heads self.dim = config.dim self._register_load_state_dict_pre_hook(self.load_hook) + if config.use_hf_rope: + self.apply_rotary_emb = hf_apply_rotary_emb + else: + self.apply_rotary_emb = apply_rotary_emb def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1): n_local_heads = self.n_local_heads @@ -798,8 +816,8 @@ def forward( # -1 = self.n_local_heads v = v.view(bsz, seqlen, -1, self.head_dim) - q = apply_rotary_emb(q, freqs_cis) - k = apply_rotary_emb(k, freqs_cis) + q = self.apply_rotary_emb(q, freqs_cis) + k = self.apply_rotary_emb(k, freqs_cis) q, k, v = (x.transpose(1, 2) for x in (q, k, v)) @@ -919,6 +937,58 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: return x_out2.type_as(x) + +# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77 +def hf_precompute_freqs_cis(dim: int, end: int, theta: float, dtype=None, **kwargs): + if not dtype: + dtype = get_precision() + + freqs = 1.0 / ( + theta + ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim) + ) + # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`. + t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as( + freqs # pyre-ignore + ) + freqs = torch.outer(t, freqs).float() # pyre-ignore + emb = torch.cat((freqs, freqs), dim=-1) + freqs_cos = torch.cos(emb) + freqs_sin = torch.sin(emb) + return torch.stack((freqs_cos, freqs_sin), dim=-1).to(dtype=dtype) + +# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L135 +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def hf_apply_rotary_emb(x, freq_cis, unsqueeze_dim=1, **kwargs): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = freq_cis[..., 0].unsqueeze(unsqueeze_dim) + sin = freq_cis[..., 1].unsqueeze(unsqueeze_dim) + x_out = (x * cos) + (rotate_half(x) * sin) + return x_out.type_as(x) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ExecuTorch model components # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/torchchat/model_config/model_config.py b/torchchat/model_config/model_config.py index 584a87a74..079f31629 100644 --- a/torchchat/model_config/model_config.py +++ b/torchchat/model_config/model_config.py @@ -86,6 +86,6 @@ def resolve_model_config(model: str) -> ModelConfig: model = model_aliases[model] if model not in model_configs: - raise ValueError(f"Unknown model '{model}'.") + raise ValueError(f"Unknown model '{model}'. Supported models: {model_configs.keys()}") return model_configs[model] diff --git a/torchchat/model_config/models.json b/torchchat/model_config/models.json index ca8c5acdf..f437d43ca 100644 --- a/torchchat/model_config/models.json +++ b/torchchat/model_config/models.json @@ -1,4 +1,10 @@ { + "llava-hf/llava-1.5-7b-hf": { + "aliases": ["llava-1.5"], + "distribution_channel": "HuggingFaceSnapshot", + "distribution_path": "llava-hf/llava-1.5-7b-hf", + "transformer_params_key": "llava-1.5" + }, "meta-llama/Llama-2-7b-hf": { "aliases": ["llama2-base", "llama2-7b"], "distribution_channel": "HuggingFaceSnapshot", diff --git a/torchchat/model_params/llava-1.5.json b/torchchat/model_params/llava-1.5.json index 992cc2c69..c84889452 100644 --- a/torchchat/model_params/llava-1.5.json +++ b/torchchat/model_params/llava-1.5.json @@ -1,6 +1,5 @@ { "model_type": "llava", - "use_tiktoken": true, "encoder": { "tile_size": 336, "patch_size": 14, @@ -20,6 +19,7 @@ "n_heads": 32, "dim": 4096, "vocab_size": 32064, - "max_seq_length": 768 + "max_seq_length": 768, + "use_hf_rope": true } } diff --git a/torchchat/utils/preprocessors.py b/torchchat/utils/preprocessors.py new file mode 100644 index 000000000..abca2a7ea --- /dev/null +++ b/torchchat/utils/preprocessors.py @@ -0,0 +1,80 @@ +import torch +import torchvision as tv +from torchvision import transforms as tvT +from PIL import Image +import os + +from typing import List + + +def llava_image_preprocess( + image: Image, + *, + target_h: int = 336, + target_w: int = 336, + rescale_factor: float = 0.00392156862745098, + image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073], + image_std: List[float] = [0.26862954, 0.26130258, 0.27577711], + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.bfloat16, + ) -> torch.Tensor: + """ + Preprocess an image by resizing it to fit a target height and width, + padding with median RGB value to make a square, scaling, and normalizing. + + Args: + img_address (str): Address of the local image file will be forwarded to the model. + target_h (int): Target height. + target_w (int): Target width. + rescale_factor (float): Rescaling factor. + image_mean (list): Mean values for normalization. + image_std (list): Standard deviation values for normalization. + + Returns: + torch.Tensor: Preprocessed image tensor. + + Raises: + FileNotFoundError: If the image file does not exist. + ValueError: If the target height or width is not positive. + """ + + # Check if the target height and width are positive + if target_h <= 0 or target_w <= 0: + raise ValueError("Target height and width must be positive") + + # Convert the image to a tensor + img = tvT.functional.pil_to_tensor(image) + + # Calculate the height and width ratios + ratio_h = img.shape[1] / target_h + ratio_w = img.shape[2] / target_w + + # Resize the image to fit in a target_h x target_w canvas + ratio = max(ratio_h, ratio_w) + output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio)) + img = tvT.Resize(size=output_size)(img) + + # Pad the image with median RGB value to make a square + l_pad = (target_w - img.shape[2]) // 2 + t_pad = (target_h - img.shape[1]) // 2 + r_pad = -((target_w - img.shape[2]) // -2) + b_pad = -((target_h - img.shape[1]) // -2) + + torch._check(l_pad >= 0) + torch._check(t_pad >= 0) + torch._check(r_pad >= 0) + torch._check(b_pad >= 0) + + # Pad the image + resized = torch.nn.functional.pad( + img, + (l_pad, r_pad, t_pad, b_pad), + ) + + # Scale the image + scaled = resized * rescale_factor + + # Normalize the image + normed = tvT.Normalize(image_mean, image_std)(scaled) + + return normed.unsqueeze(0).to(device).to(dtype)