Skip to content

Commit b817d1b

Browse files
authored
Fix wrong indices setting in HLabelInfo (#4044)
* Fix wrong indices setting in label_info * Add unit-test & update for releases
1 parent d6458e6 commit b817d1b

File tree

6 files changed

+43
-3
lines changed

6 files changed

+43
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ All notable changes to this project will be documented in this file.
8888
(<https://github.com/openvinotoolkit/training_extensions/pull/4016>)
8989
- Fix multilabel_accuracy of MixedHLabelAccuracy
9090
(<https://github.com/openvinotoolkit/training_extensions/pull/4042>)
91+
- Fix wrong indices setting in HLabelInfo
92+
(<https://github.com/openvinotoolkit/training_extensions/pull/4044>)
9193

9294
## \[v2.1.0\]
9395

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ In addition to the examples above, please refer to the documentation for tutoria
212212
- Fix num_trials calculation on dataset length less than num_class
213213
- Fix out_features in HierarchicalCBAMClsHead
214214
- Fix multilabel_accuracy of MixedHLabelAccuracy
215+
- Fix wrong indices setting in HLabelInfo
215216

216217
### Known issues
217218

docs/source/guide/release_notes/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Bug fixes
5454
- Fix num_trials calculation on dataset length less than num_class
5555
- Fix out_features in HierarchicalCBAMClsHead
5656
- Fix multilabel_accuracy of MixedHLabelAccuracy
57+
- Fix wrong indices setting in HLabelInfo
5758

5859
v2.1.0 (2024.07)
5960
----------------

src/otx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Copyright (C) 2024 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

6-
__version__ = "2.2.0rc8"
6+
__version__ = "2.2.0rc9"
77

88
import os
99
from pathlib import Path

src/otx/core/types/label.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def convert_labels_if_needed(
264264
single_label_group_info["class_to_idx"],
265265
)
266266

267+
label_to_idx = {lbl: i for i, lbl in enumerate(merged_class_to_idx.keys())}
268+
267269
return HLabelInfo(
268270
label_names=label_names,
269271
label_groups=all_groups,
@@ -273,7 +275,7 @@ def convert_labels_if_needed(
273275
num_single_label_classes=exclusive_group_info["num_single_label_classes"],
274276
class_to_group_idx=merged_class_to_idx,
275277
all_groups=all_groups,
276-
label_to_idx=dm_label_categories._indices, # noqa: SLF001
278+
label_to_idx=label_to_idx,
277279
label_tree_edges=get_label_tree_edges(dm_label_categories.items),
278280
empty_multiclass_head_indices=[], # consider the label removing case
279281
)

tests/unit/core/types/test_label.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
3+
from __future__ import annotations
34

4-
from otx.core.types.label import NullLabelInfo, SegLabelInfo
5+
from datumaro import LabelCategories
6+
from datumaro.components.annotation import GroupType
7+
from otx.core.types.label import HLabelInfo, NullLabelInfo, SegLabelInfo
58

69

710
def test_as_json(fxt_label_info):
@@ -18,3 +21,34 @@ def test_seg_label_info():
1821
)
1922
assert SegLabelInfo.from_num_classes(1) == SegLabelInfo(["background", "label_0"], [["background", "label_0"]])
2023
assert SegLabelInfo.from_num_classes(0) == NullLabelInfo()
24+
25+
26+
# Unit test
27+
def test_hlabel_info():
28+
labels = [
29+
LabelCategories.Category(name="car", parent="vehicle"),
30+
LabelCategories.Category(name="truck", parent="vehicle"),
31+
LabelCategories.Category(name="plush toy", parent="plush toy"),
32+
LabelCategories.Category(name="No class"),
33+
]
34+
label_groups = [
35+
LabelCategories.LabelGroup(
36+
name="Detection labels___vehicle",
37+
labels=["car", "truck"],
38+
group_type=GroupType.EXCLUSIVE,
39+
),
40+
LabelCategories.LabelGroup(
41+
name="Detection labels___plush toy",
42+
labels=["plush toy"],
43+
group_type=GroupType.EXCLUSIVE,
44+
),
45+
LabelCategories.LabelGroup(name="No class", labels=["No class"], group_type=GroupType.RESTRICTED),
46+
]
47+
dm_label_categories = LabelCategories(items=labels, label_groups=label_groups)
48+
49+
hlabel_info = HLabelInfo.from_dm_label_groups(dm_label_categories)
50+
51+
# Check if class_to_group_idx and label_to_idx have the same keys
52+
assert list(hlabel_info.class_to_group_idx.keys()) == list(
53+
hlabel_info.label_to_idx.keys(),
54+
), "class_to_group_idx and label_to_idx keys do not match"

0 commit comments

Comments
 (0)