Skip to content

Commit 1c14e75

Browse files
addres a few PR comments.
Signed-off-by: Albert van Houten <albert.van.houten@intel.com>
1 parent 6318155 commit 1c14e75

File tree

2 files changed

+19
-20
lines changed

2 files changed

+19
-20
lines changed

library/src/otx/backend/native/tools/tile_merge.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,13 @@ def _merge_entities(
103103
def merge(
104104
self,
105105
batch_tile_preds: list[OTXPredBatch],
106-
batch_tile_attrs: list[list[dict]],
106+
batch_tile_infos: list[list[dict]],
107107
) -> list[OTXPredItem]:
108108
"""Merge batch tile predictions to a list of full-size prediction data entities.
109109
110110
Args:
111111
batch_tile_preds (list): list of tile predictions.
112-
batch_tile_attrs (list): list of tile attributes.
112+
batch_tile_infos (list): list of tile attributes.
113113
"""
114114
raise NotImplementedError
115115

@@ -139,44 +139,44 @@ class DetectionTileMerge(TileMerge):
139139
def merge(
140140
self,
141141
batch_tile_preds: list[OTXPredBatch],
142-
batch_tile_attrs: list[list[TileInfo]],
142+
batch_tile_infos: list[list[TileInfo]],
143143
) -> list[OTXPredItem]:
144144
"""Merge batch tile predictions to a list of full-size prediction data entities.
145145
146146
Args:
147147
batch_tile_preds (list): detection tile predictions.
148-
batch_tile_attrs (list): detection tile attributes.
148+
batch_tile_infos (list): detection tile attributes.
149149
150150
"""
151151
entities_to_merge = defaultdict(list)
152152
img_ids = []
153153
explain_mode = self.explain_mode
154154

155-
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
155+
for tile_preds, tile_infos in zip(batch_tile_preds, batch_tile_infos, strict=True):
156156
if tile_preds.imgs_info is None or tile_preds.bboxes is None:
157157
msg = "imgs_info or bboxes is None"
158158
raise ValueError(msg)
159-
batch_size = len(tile_attrs)
159+
batch_size = len(tile_infos)
160160
for i in range(batch_size):
161161
if tile_preds.imgs_info[i] is None:
162162
msg = "imgs_info is None"
163163
raise ValueError(msg)
164164
tile_img_info = tile_preds.imgs_info[i]
165-
tile_attr = tile_attrs[i]
165+
tile_info = tile_infos[i]
166166
tile_s_map = tile_preds.saliency_map[i] if tile_preds.saliency_map is not None else None
167167
tile_f_vect = tile_preds.feature_vector[i] if tile_preds.feature_vector is not None else None
168168

169169
tile_bboxes = tile_preds.bboxes[i] if tile_preds.bboxes[i].numel() > 0 else None
170-
offset_x = tile_attr.x
171-
offset_y = tile_attr.y
170+
offset_x = tile_info.x
171+
offset_y = tile_info.y
172172
if tile_bboxes is not None:
173173
tile_bboxes[:, 0::2] += offset_x
174174
tile_bboxes[:, 1::2] += offset_y
175175

176-
tile_id = tile_attr.source_sample_idx
176+
tile_id = tile_info.source_sample_idx
177177
if tile_id not in img_ids:
178178
img_ids.append(tile_id)
179-
tile_img_info.padding = [tile_attr.x, tile_attr.y, tile_attr.width, tile_attr.height] # type: ignore[union-attr]
179+
tile_img_info.padding = [tile_info.x, tile_info.y, tile_info.width, tile_info.height] # type: ignore[union-attr]
180180

181181
det_pred_entity = OTXPredItem(
182182
image=torch.empty(3, *tile_img_info.ori_shape), # type: ignore[union-attr]
@@ -499,13 +499,13 @@ def __init__(
499499
def merge(
500500
self,
501501
batch_tile_preds: list[OTXPredBatch],
502-
batch_tile_attrs: list[list[TileInfo]],
502+
batch_tile_infos: list[list[TileInfo]],
503503
) -> list[OTXPredItem]:
504504
"""Merge batch tile predictions to a list of full-size prediction data entities.
505505
506506
Args:
507507
batch_tile_preds (list[SegBatchPredEntity]): segmentation tile predictions.
508-
batch_tile_attrs (list[list[dict]]): segmentation tile attributes.
508+
batch_tile_infos (list[list[dict]]): segmentation tile attributes.
509509
510510
Returns:
511511
list[TorchPredItem]: List of full-size prediction data entities after merging.
@@ -514,7 +514,7 @@ def merge(
514514
img_ids = []
515515
explain_mode = self.explain_mode
516516

517-
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
517+
for tile_preds, tile_infos in zip(batch_tile_preds, batch_tile_infos):
518518
batch_size = tile_preds.batch_size
519519
saliency_maps = tile_preds.saliency_map if explain_mode else [[] for _ in range(batch_size)]
520520
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(batch_size)]
@@ -528,8 +528,8 @@ def merge(
528528
msg = "The predicted masks are not provided."
529529
raise ValueError(msg)
530530

531-
for tile_attr, tile_img_info, tile_masks, tile_s_map, tile_f_vect in zip(
532-
tile_attrs,
531+
for tile_info, tile_img_info, tile_masks, tile_s_map, tile_f_vect in zip(
532+
tile_infos,
533533
tile_preds.imgs_info,
534534
tile_preds.masks,
535535
saliency_maps,
@@ -539,10 +539,10 @@ def merge(
539539
msg = f"Image information is not provided : {tile_preds.imgs_info}."
540540
raise ValueError(msg)
541541

542-
tile_id = tile_attr.source_sample_idx
542+
tile_id = tile_info.source_sample_idx
543543
if tile_id not in img_ids:
544544
img_ids.append(tile_id)
545-
tile_img_info.padding = (tile_attr.x, tile_attr.y, tile_attr.width, tile_attr.height)
545+
tile_img_info.padding = (tile_info.x, tile_info.y, tile_info.width, tile_info.height)
546546
seg_pred_entity = OTXPredItem(
547547
image=torch.empty((3, *tile_img_info.ori_shape)),
548548
img_info=tile_img_info,

library/src/otx/data/factory.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from otx.types.transformer_libs import TransformLibType
1414

1515
from .dataset.base import OTXDataset, Transforms
16-
from .dataset.base import OTXDataset as OTXDatasetNew
1716

1817
if TYPE_CHECKING:
1918
from datumaro.components.dataset import Dataset as DmDataset
@@ -52,7 +51,7 @@ def create(
5251
include_polygons: bool = False,
5352
# TODO(gdlg): Add support for ignore_index again
5453
ignore_index: int = 255, # noqa: ARG003
55-
) -> OTXDataset | OTXDatasetNew:
54+
) -> OTXDataset:
5655
"""Create OTXDataset."""
5756
transforms = TransformLibFactory.generate(cfg_subset)
5857
common_kwargs = {

0 commit comments

Comments
 (0)