You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L79
# Copied from https://github.com/huggingface/transformers/blob/6966fa190172b48b2fb46fe4552a13b943e692cf/src/transformers/models/llava_next/modeling_llava_next.py#L479
173
+
defget_image_features(
174
+
self,
175
+
pixel_values: torch.FloatTensor,
176
+
image_sizes: torch.Tensor,
177
+
vision_feature_layer: Union[int, List[int]],
178
+
vision_feature_select_strategy: str,
179
+
):
180
+
"""
181
+
Obtains image last hidden states from the vision tower and applies multimodal projection.
182
+
183
+
Args:
184
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
185
+
The tensors corresponding to the input images.
186
+
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
187
+
Actual size of each image (H, W).
188
+
vision_feature_layer (`Union[int, List[int]]`):
189
+
The index of the layer to select the vision feature. If multiple indices are provided,
190
+
the vision feature of the corresponding indices will be concatenated to form the
191
+
vision features.
192
+
vision_feature_select_strategy (`str`):
193
+
The feature selection strategy used to select the vision feature from the vision backbone.
194
+
Can be one of `"default"` or `"full"`
195
+
Returns:
196
+
image_features (List[`torch.Tensor`]): List of image feature tensors, each of which contains all the visual features of all patches
197
+
and is of shape `(num_patches, image_length, embed_dim)`.
198
+
"""
199
+
# ! infer image_num_patches from image_sizes
200
+
image_num_patches= [
201
+
image_size_to_num_patches(
202
+
image_size=imsize,
203
+
grid_pinpoints=self.config.image_grid_pinpoints,
204
+
patch_size=self.config.vision_config.image_size,
205
+
)
206
+
forimsizeinimage_sizes
207
+
]
208
+
ifpixel_values.dim() ==5:
209
+
# stacked if input is (batch_size, num_patches, num_channels, height, width)
0 commit comments