Skip to content

Commit 4607f38

Browse files
committed
Merge remote-tracking branch 'origin/main' into check_tests
2 parents 94ced43 + 8300607 commit 4607f38

File tree

10 files changed

+26
-14
lines changed

10 files changed

+26
-14
lines changed

.github/workflows/pre-commit-workflow.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
- name: Set up Python
2222
uses: actions/setup-python@v3
2323
with:
24-
python-version: "3.8.0"
24+
python-version: "3.10.0"
2525
- name: Install test dependencies
2626
run: |
2727
apt-get get update && apt-get install cmake

.pre-commit-config.yaml

+14-5
Original file line numberDiff line numberDiff line change
@@ -3,49 +3,58 @@ default_stages:
33
repos:
44
# general hooks to verify or beautify code
55
- repo: https://github.com/pre-commit/pre-commit-hooks
6-
rev: v3.3.0
6+
rev: v5.0.0
77
hooks:
88
- id: check-added-large-files
99
args: [--maxkb=5000]
10+
stages: [commit]
1011
- id: trailing-whitespace
12+
stages: [commit]
1113
- id: check-json
14+
stages: [commit]
1215
- id: check-merge-conflict
16+
stages: [commit]
1317
- id: check-xml
18+
stages: [commit]
1419
- id: check-yaml
20+
stages: [commit]
1521
- id: detect-private-key
22+
stages: [commit]
1623
- id: mixed-line-ending
24+
stages: [commit]
1725
- id: pretty-format-json
1826
args: [--autofix]
1927
exclude: \.ipynb$
28+
stages: [commit]
2029

2130

2231
# autoformat code with black formatter
2332
- repo: https://github.com/psf/black
24-
rev: 22.3.0
33+
rev: 25.1.0
2534
hooks:
2635
- id: black
2736
args: [-l 120]
2837

2938

3039
# beautify and sort imports
3140
- repo: https://github.com/pycqa/isort
32-
rev: 5.12.0
41+
rev: 6.0.0
3342
hooks:
3443
- id: isort
3544
args: ["--profile", "black"]
3645

3746

3847
# check code style
3948
- repo: https://github.com/pycqa/flake8
40-
rev: 3.8.4
49+
rev: 7.1.1
4150
hooks:
4251
- id: flake8
4352
exclude: __init__.py
4453

4554

4655
# static type checking
4756
- repo: https://github.com/pre-commit/mirrors-mypy
48-
rev: v0.910
57+
rev: v1.14.1
4958
hooks:
5059
- id: mypy
5160
additional_dependencies: [types-requests==2.25.9]

ci/requirements_tests.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
pytest>=7.2
2-
pre-commit==2.15.0
2+
pre-commit==3.5.0
33
python-dotenv>=0.17.0

oml/miners/inbatch_nhard_tri.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ def _sample_from_distmat(
123123
hardest_negative = ids_smallest_distance[idx_anch_neg_reduced, idx_neg_sorted_by_dist]
124124
idx_anch_neg = all_ids_reduced[idx_anch_neg_reduced]
125125

126-
ids_a = []
127-
ids_p = []
128-
ids_n = []
126+
ids_a: List[int] = []
127+
ids_p: List[int] = []
128+
ids_n: List[int] = []
129129

130130
for idx_anch in torch.arange(len(labels), device=distmat.device)[torch.logical_not(ignore_anchor_mask)]:
131131
positives = hardest_positive[idx_anch_pos == idx_anch][self.positive_slice]

tests/test_oml/test_ddp/test_loader_patcher.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .utils import run_in_ddp
1818

1919

20+
@pytest.mark.skip(reason="Dead locks may appear when running in CI")
2021
@pytest.mark.long
2122
@pytest.mark.parametrize("n_labels_sampler", [2, 5])
2223
@pytest.mark.parametrize("n_instances_sampler", [2, 5])
@@ -155,6 +156,7 @@ def check_patching_balance_batch_sampler(
155156
assert len(set(outputs_from_epochs)) == len(outputs_from_epochs)
156157

157158

159+
@pytest.mark.skip(reason="Dead locks may appear when running in CI")
158160
@pytest.mark.long
159161
@pytest.mark.parametrize("shuffle", [True, False])
160162
@pytest.mark.parametrize("drop_last", [True, False])

tests/test_oml/test_miners/shared_checkers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def check_triplets_consistency(
1919

2020
assert num_sampled_tri == len(ids_pos) == len(ids_neg)
2121

22-
for (i_a, i_p, i_n) in zip(ids_anchor, ids_pos, ids_neg):
22+
for i_a, i_p, i_n in zip(ids_anchor, ids_pos, ids_neg):
2323
assert len({i_a, i_p, i_n}) == 3
2424
assert labels[i_a] == labels[i_p]
2525
assert labels[i_a] != labels[i_n]

tests/test_oml/test_miners/test_inbatch_all_tri.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_all_triplets_miner(features_and_labels: List[Tuple[torch.Tensor, List[i
3232

3333
@pytest.mark.long
3434
def test_compare_all_triplets_miner_with_naive_version(
35-
features_and_labels: List[Tuple[torch.Tensor, List[int]]]
35+
features_and_labels: List[Tuple[torch.Tensor, List[int]]],
3636
) -> None:
3737
max_tri = sys.maxsize
3838
miner = AllTripletsMiner(max_output_triplets=max_tri)

tests/test_oml/test_models/test_visualisations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_numpy_image() -> np.ndarray:
3636
def test_visualisation(draw_function: Any, image: Union[np.ndarray, Image.Image]) -> None:
3737
image_modified = draw_function(image)
3838

39-
assert type(image_modified) == type(image)
39+
assert isinstance(image_modified, type(image))
4040

4141
image_modified = np.array(image_modified)
4242

tests/test_runs/test_ddp_cases/test_experiments_equality.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_epochs_are_equal() -> None:
8484

8585
def is_equal_models(model1: nn.Module, model2: nn.Module) -> Tuple[bool, List[torch.Tensor]]:
8686
for module1, module2 in zip(model1.modules(), model2.modules()):
87-
assert type(module1) == type(module2)
87+
assert isinstance(module1, type(module2))
8888
if isinstance(module1, nn.Linear):
8989
if not torch.all(torch.isclose(module1.weight, module2.weight, atol=TORCH_EPS)):
9090
return False, [module1.weight, module2.weight]

tests/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@ def forward(self, x): # type: ignore
3636
x = x.float()
3737
return x
3838

39+
@property
3940
def feat_dim(self) -> int:
4041
return self.model.embedding_dim

0 commit comments

Comments
 (0)