Skip to content

Commit 6553a7d

Browse files
authored
Mark pyemd as optional since it does not support Python 3.12 (#1770)
pyemd is only used for data shift analysis and does not seem to be maintained anymore. Most of the code does not use pyemd so it is better to mark it as optional. <!-- Contributing guide: https://github.com/open-edge-platform/datumaro/blob/develop/CONTRIBUTING.md --> ### Summary <!-- Resolves #111 and #222. Depends on #1000 (for series of dependent commits). This PR introduces this capability to make the project better in this and that. - Added this feature - Removed that feature - Fixed the problem #1234 --> ### How to test <!-- Describe the testing procedure for reviewers, if changes are not fully covered by unit tests or manual testing can be complicated. --> ### Checklist <!-- Put an 'x' in all the boxes that apply --> - [ ] I have added unit tests to cover my changes.​ - [ ] I have added integration tests to cover my changes.​ - [ ] I have added the description of my changes into [CHANGELOG](https://github.com/open-edge-platform/datumaro/blob/develop/CHANGELOG.md).​ - [ ] I have updated the [documentation](https://github.com/open-edge-platform/datumaro/tree/develop/docs) accordingly ### License - [ ] I submit _my code changes_ under the same [MIT License](https://github.com/open-edge-platform/datumaro/blob/develop/LICENSE) that covers the project. Feel free to contact the maintainers if that's a concern. - [ ] I have updated the license header for each file (see an example below). ```python # Copyright (C) 2025 Intel Corporation # # SPDX-License-Identifier: MIT ```
2 parents 5e62076 + 12ad7df commit 6553a7d

File tree

5 files changed

+14
-5
lines changed

5 files changed

+14
-5
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2626
### Bug fixes
2727
- Fix assertion to compare hashkeys against expected value
2828
(<https://github.com/open-edge-platform/datumaro/pull/1641>)
29+
- Mark pyemd as optional since it does not support Python 3.12
30+
(<https://github.com/open-edge-platform/datumaro/pull/1770>)
2931

3032
## Q1 2025 Release 1.10.0
3133

requirements-core.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ tokenizers
4040
# Encryption
4141
cryptography>= 38.03
4242

43-
# Shift analyzer
44-
pyemd
45-
4643
# apache arrow
4744
pyarrow
4845

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def parse_requirements(filename=CORE_REQUIREMENTS_FILE):
8686
"tf": ["tensorflow"],
8787
"tfds": ["tensorflow-datasets<4.9.3", "absl-py>=0.12.0"],
8888
"torch": ["torch", "torchvision"],
89+
"pyemd": ["pyemd"],
8990
"default": DEFAULT_REQUIREMENTS,
9091
},
9192
ext_modules=ext_modules,

src/datumaro/components/shift_analyzer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@
1515
from datumaro.util import take_by
1616

1717
if TYPE_CHECKING:
18-
import pyemd
1918
from scipy import linalg, stats
2019

2120
from datumaro.plugins.openvino_plugin import shift_launcher
2221
else:
2322
from datumaro.util.import_util import lazy_import
2423

25-
pyemd = lazy_import("pyemd")
2624
linalg = lazy_import("scipy.linalg")
2725
stats = lazy_import("scipy.stats")
2826
shift_launcher = lazy_import("datumaro.plugins.openvino_plugin.shift_launcher")
@@ -265,5 +263,7 @@ def _earth_mover_distance(
265263
f_concat = np.concatenate([f_s, f_t], axis=0)
266264
distances = np.linalg.norm(f_concat[:, None] - f_concat[None, :], axis=2).astype(np.float64)
267265

266+
import pyemd
267+
268268
emd = pyemd.emd(w_1, w_2, distances)
269269
return np.exp(-gamma * emd).item()

tests/unit/test_shift_analyzer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from copy import deepcopy
22
from typing import List
3+
from unittest import skipIf
34

45
import numpy as np
56
import pytest
@@ -12,6 +13,13 @@
1213

1314
from ..requirements import Requirements, mark_requirement
1415

16+
try:
17+
import pyemd
18+
19+
import_failed = False
20+
except ImportError:
21+
import_failed = True
22+
1523

1624
@pytest.fixture
1725
def fxt_dataset_ideal():
@@ -57,6 +65,7 @@ def fxt_dataset_different():
5765
return [src_dataset, tgt_dataset]
5866

5967

68+
@skipIf(import_failed, "Failed to import pyemd")
6069
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
6170
@pytest.mark.parametrize(
6271
"fxt_datasets,method,expected",

0 commit comments

Comments
 (0)