-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathtest_extension_plugins.py
More file actions
132 lines (98 loc) · 3.65 KB
/
test_extension_plugins.py
File metadata and controls
132 lines (98 loc) · 3.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import pytest
from metaflow.extension_support import plugins as plugins_module
from metaflow.extension_support.plugins import (
get_plugin,
get_plugin_name,
get_trampoline_cli_names,
merge_lists,
)
class _NamedItem:
    """Minimal stand-in for a plugin class, keyed by ``name`` in merge_lists.

    ``origin`` is a free-form tag that lets a test tell apart two items that
    share the same ``name`` (e.g. one from the base list, one from overrides).
    """

    def __init__(self, name, origin=None):
        self.name = name
        self.origin = origin

    def __repr__(self):
        # Readable assertion-failure output instead of <object at 0x...>.
        return "_NamedItem(%r, origin=%r)" % (self.name, self.origin)


class TestMergeLists:
    """Tests for ``merge_lists(base, overrides, attr)``.

    merge_lists is called for its side effect on ``base`` (the tests read
    ``base`` after the call, not a return value), merging entries by the
    value of the ``attr`` attribute.
    """

    def test_overrides_win_over_base(self):
        base = [
            _NamedItem("a", "base"),
            _NamedItem("b", "base"),
            _NamedItem("c", "base"),
        ]
        # "b" deliberately collides with a base entry; "d" is new.
        overrides = [_NamedItem("b", "override"), _NamedItem("d", "override")]
        merge_lists(base, overrides, "name")
        # Every distinct name survives exactly once (a list, not a set, so a
        # leftover duplicate "b" would be caught)...
        assert sorted(item.name for item in base) == ["a", "b", "c", "d"]
        # ...and for the colliding name the override replaced the base entry,
        # while non-colliding base entries are kept as-is.
        by_name = {item.name: item for item in base}
        assert by_name["b"].origin == "override"
        assert by_name["a"].origin == "base"
        assert by_name["c"].origin == "base"

    def test_empty_overrides_keeps_base(self):
        base = [_NamedItem("a"), _NamedItem("b")]
        merge_lists(base, [], "name")
        assert [item.name for item in base] == ["a", "b"]

    def test_empty_base_uses_overrides(self):
        base = []
        merge_lists(base, [_NamedItem("x"), _NamedItem("y")], "name")
        assert [item.name for item in base] == ["x", "y"]

    def test_both_empty(self):
        base = []
        overrides = []
        merge_lists(base, overrides, "name")
        assert base == []

    def test_no_duplicates_in_result(self):
        base = [_NamedItem("a", "base"), _NamedItem("b", "base")]
        overrides = [_NamedItem("a", "override"), _NamedItem("b", "override")]
        merge_lists(base, overrides, "name")
        # Compare a sorted list, not a set: a set would silently collapse the
        # very duplicates this test is meant to detect.
        assert sorted(item.name for item in base) == ["a", "b"]
        # Every name collided, so every surviving entry must be an override.
        assert all(item.origin == "override" for item in base)
class TestGetPluginName:
    """Tests for ``get_plugin_name(category, plugin)``.

    Each category reads a different attribute off the plugin to obtain its
    name; the sidecar category yields ``None`` and an unknown category
    raises ``KeyError``.
    """

    def test_step_decorator_name_extraction(self):
        # Step decorators expose their name via a ``name`` attribute.
        class MockDecorator:
            name = "my_decorator"

        extracted = get_plugin_name("step_decorator", MockDecorator())
        assert extracted == "my_decorator"

    def test_environment_type_extraction(self):
        # Environments expose their name via a ``TYPE`` attribute.
        class MockEnv:
            TYPE = "conda"

        extracted = get_plugin_name("environment", MockEnv())
        assert extracted == "conda"

    def test_sidecar_returns_none(self):
        class MockSidecar:
            pass

        extracted = get_plugin_name("sidecar", MockSidecar())
        assert extracted is None

    def test_cli_single_command(self):
        # A CLI plugin with exactly one command is named after that command.
        class MockCLI:
            commands = {"run"}

        extracted = get_plugin_name("cli", MockCLI())
        assert extracted == "run"

    def test_cli_too_many_commands(self):
        # With several commands there is no single name; the extractor
        # returns a descriptive marker string instead.
        class MockCLI:
            commands = {"run", "step", "start"}

        message = get_plugin_name("cli", MockCLI())
        assert "too many commands" in message

    def test_unknown_category_raises_key_error(self):
        class MockPlugin:
            name = "test"

        plugin = MockPlugin()
        with pytest.raises(KeyError):
            get_plugin_name("nonexistent_category", plugin)
class TestGetPlugin:
    """Tests for ``get_plugin(category, class_path, name)``.

    Covers a dotted path that cannot be imported (expected ``ValueError``)
    and a real in-tree plugin class that loads successfully.
    """

    def test_get_plugin_raises_value_error_for_invalid_path(self):
        bogus_path = "nonexistent.module.path.FakeClass"
        with pytest.raises(ValueError, match="Cannot locate"):
            get_plugin("step_decorator", bogus_path, "fake")

    def test_get_plugin_successfully_loads_real_class(self):
        retry_path = "metaflow.plugins.retry_decorator.RetryDecorator"
        loaded = get_plugin("step_decorator", retry_path, "retry")
        # The returned object is the class itself, carrying its plugin name.
        assert loaded.name == "retry"
class TestGetTrampolineCliNames:
    """Tests for ``get_trampoline_cli_names()``."""

    def test_returns_frozenset(self):
        assert isinstance(get_trampoline_cli_names(), frozenset)

    def test_contains_expected_entries(self):
        # At least one of these names must be present; presumably they are
        # trampoline CLIs shipped with core metaflow — confirm if this changes.
        names = get_trampoline_cli_names()
        assert any(candidate in names for candidate in ("batch", "kubernetes"))