Skip to content

Commit c91330c

Browse files
authored
Merge pull request #21 from fleXible-public/feat/pythonic_code
refactor: pythonic code style, tests for Function class
2 parents 4998f2a + a3b031c commit c91330c

File tree

2 files changed

+118
-47
lines changed

2 files changed

+118
-47
lines changed

custom_components/pyscript/function.py

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -143,49 +143,39 @@ async def service_call(cls, domain, name, **kwargs):
143143
async def service_completions(cls, root):
144144
"""Return possible completions of HASS services."""
145145
words = set()
146-
services = await async_get_all_descriptions(cls.hass)
146+
services = cls.hass.services.async_services()
147147
num_period = root.count(".")
148148
if num_period == 1:
149-
domain, srv_root = root.split(".")
149+
domain, svc_root = root.split(".")
150150
if domain in services:
151-
for srv in services[domain].keys():
152-
if srv.lower().startswith(srv_root):
153-
words.add(f"{domain}.{srv}")
151+
words |= {f"{domain}.{svc}" for svc in services[domain] if svc.lower().startswith(svc_root)}
154152
elif num_period == 0:
155-
for domain in services.keys():
156-
if domain.lower().startswith(root):
157-
words.add(domain)
153+
words |= {domain for domain in services if domain.lower().startswith(root)}
154+
158155
return words
159156

160157
@classmethod
161158
async def func_completions(cls, root):
162159
"""Return possible completions of functions."""
163-
words = set()
164-
funcs = cls.functions.copy()
165-
funcs.update(cls.ast_functions)
166-
for name in funcs.keys():
167-
if name.lower().startswith(root):
168-
words.add(name)
160+
funcs = {**cls.functions, **cls.ast_functions}
161+
words = {name for name in funcs if name.lower().startswith(root)}
162+
169163
return words
170164

171165
@classmethod
172166
def register(cls, funcs):
173167
"""Register functions to be available for calling."""
174-
for name, func in funcs.items():
175-
cls.functions[name] = func
168+
cls.functions.update(funcs)
176169

177170
@classmethod
178171
def register_ast(cls, funcs):
179172
"""Register functions that need ast context to be available for calling."""
180-
for name, func in funcs.items():
181-
cls.ast_functions[name] = func
173+
cls.ast_functions.update(funcs)
182174

183175
@classmethod
184176
def install_ast_funcs(cls, ast_ctx):
185177
"""Install ast functions into the local symbol table."""
186-
sym_table = {}
187-
for name, func in cls.ast_functions.items():
188-
sym_table[name] = func(ast_ctx)
178+
sym_table = {name: func(ast_ctx) for name, func in cls.ast_functions.items()}
189179
ast_ctx.set_local_sym_table(sym_table)
190180

191181
@classmethod
@@ -194,17 +184,20 @@ def get(cls, name):
194184
func = cls.functions.get(name, None)
195185
if func:
196186
return func
197-
parts = name.split(".", 1)
198-
if len(parts) != 2:
187+
188+
name_parts = name.split(".")
189+
if len(name_parts) != 2:
199190
return None
200-
domain = parts[0]
201-
service = parts[1]
202-
if not cls.hass.services.has_service(domain, service):
191+
192+
domain, service = name_parts
193+
if not cls.service_has_service(domain, service):
203194
return None
204195

205196
async def service_call(*args, **kwargs):
206197
await cls.hass.services.async_call(domain, service, kwargs)
207198

199+
# service_call = functools.partial(cls.service_call, domain, service)
200+
208201
return service_call
209202

210203
@classmethod
@@ -218,17 +211,14 @@ async def run_coro(cls, coro):
218211
try:
219212
await coro
220213
except asyncio.CancelledError:
221-
if task in cls.unique_task2name:
222-
cls.unique_name2task.pop(cls.unique_task2name[task], None)
223-
cls.unique_task2name.pop(task, None)
224-
cls.our_tasks.discard(task)
225214
raise
226215
except Exception: # pylint: disable=broad-except
227216
_LOGGER.error("run_coro: %s", traceback.format_exc(-1))
228-
if task in cls.unique_task2name:
229-
cls.unique_name2task.pop(cls.unique_task2name[task], None)
230-
cls.unique_task2name.pop(task, None)
231-
cls.our_tasks.discard(task)
217+
finally:
218+
if task in cls.unique_task2name:
219+
del cls.unique_name2task[cls.unique_task2name[task]]
220+
del cls.unique_task2name[task]
221+
cls.our_tasks.discard(task)
232222

233223
@classmethod
234224
def create_task(cls, coro):
Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,102 @@
11
"""Test the pyscript component."""
2-
from ast import literal_eval
32
import asyncio
4-
from datetime import datetime as dt
53
import pathlib
64
import time
5+
from ast import literal_eval
6+
from datetime import datetime as dt
77

8-
from custom_components.pyscript.const import DOMAIN
9-
import custom_components.pyscript.trigger as trigger
10-
8+
import pytest
119
from homeassistant import loader
1210
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_STATE_CHANGED
1311
from homeassistant.setup import async_setup_component
12+
from pytest_homeassistant.async_mock import MagicMock, Mock, mock_open, patch
13+
14+
import custom_components.pyscript.trigger as trigger
15+
from custom_components.pyscript.const import DOMAIN
16+
from custom_components.pyscript.function import Function
17+
18+
19+
@pytest.fixture()
20+
def ast_functions():
21+
return {
22+
"domain_ast.func_name": lambda ast_ctx: ast_ctx.func(),
23+
"domain_ast.other_func": lambda ast_ctx: ast_ctx.func(),
24+
}
25+
26+
27+
@pytest.fixture
28+
def functions():
29+
mock_func = Mock()
30+
31+
return {
32+
"domain.func_1": mock_func,
33+
"domain.func_2": mock_func,
34+
"helpers.get_today": mock_func,
35+
"helpers.entity_id": mock_func,
36+
}
37+
1438

15-
from pytest_homeassistant.async_mock import mock_open, patch
39+
@pytest.fixture
40+
def services():
41+
return {
42+
"domain": {"turn_on": None, "turn_off": None, "toggle": None},
43+
"helpers": {"set_state": None, "restart": None},
44+
}
45+
46+
47+
def test_install_ast_funcs(ast_functions):
48+
ast_ctx = MagicMock()
49+
ast_ctx.func.return_value = "ok"
50+
51+
with patch.object(Function, "ast_functions", ast_functions):
52+
Function.install_ast_funcs(ast_ctx)
53+
assert len(ast_ctx.method_calls) == 3
54+
55+
56+
@pytest.mark.parametrize(
57+
"root,expected",
58+
[
59+
("helpers", {"helpers.entity_id", "helpers.get_today"}),
60+
(
61+
"domain",
62+
{
63+
"domain.func_2",
64+
"domain_ast.func_name",
65+
"domain_ast.other_func",
66+
"domain.func_1",
67+
},
68+
),
69+
("domain_", {"domain_ast.func_name", "domain_ast.other_func"}),
70+
("domain_ast.func", {"domain_ast.func_name"}),
71+
("no match", set()),
72+
],
73+
ids=lambda x: x if not isinstance(x, (set,)) else f"set({len(x)})",
74+
)
75+
async def test_func_completions(ast_functions, functions, root, expected):
76+
with patch.object(Function, "ast_functions", ast_functions), patch.object(
77+
Function, "functions", functions
78+
):
79+
words = await Function.func_completions(root)
80+
assert words == expected
81+
82+
83+
@pytest.mark.parametrize(
84+
"root,expected",
85+
[
86+
("do", {"domain"}),
87+
("domain.t", {"domain.toggle", "domain.turn_on", "domain.turn_off"}),
88+
("domain.turn", {"domain.turn_on", "domain.turn_off"}),
89+
("helpers.set", {"helpers.set_state"}),
90+
("no match", set()),
91+
],
92+
ids=lambda x: x if not isinstance(x, (set,)) else f"set({len(x)})",
93+
)
94+
async def test_service_completions(root, expected, hass, services):
95+
with patch.object(
96+
hass.services, "async_services", return_value=services
97+
), patch.object(Function, "hass", hass):
98+
words = await Function.service_completions(root)
99+
assert words == expected
16100

17101

18102
async def setup_script(hass, notify_q, now, source):
@@ -34,14 +118,10 @@ async def setup_script(hass, notify_q, now, source):
34118

35119
with patch(
36120
"homeassistant.loader.async_get_integration", return_value=integration,
37-
), patch(
38-
"custom_components.pyscript.os.path.isdir", return_value=True
39-
), patch(
121+
), patch("custom_components.pyscript.os.path.isdir", return_value=True), patch(
40122
"custom_components.pyscript.glob.iglob", return_value=scripts
41123
), patch(
42-
"custom_components.pyscript.open",
43-
mock_open(read_data=source),
44-
create=True,
124+
"custom_components.pyscript.open", mock_open(read_data=source), create=True,
45125
), patch(
46126
"custom_components.pyscript.trigger.dt_now", return_value=now
47127
):
@@ -157,7 +237,8 @@ def func4(trigger_type=None, event_type=None, **kwargs):
157237
seq_num += 1
158238
res = task.wait_until(state_trigger="pyscript.f4var2 == '10'", timeout=10)
159239
log.info(f"func4 trigger_type = {res}")
160-
pyscript.done = [seq_num, res, pyscript.setVar1, pyscript.setVar1.attr1, state.get("pyscript.setVar1.attr2"), pyscript.setVar2, state.get("pyscript.setVar3")]
240+
pyscript.done = [seq_num, res, pyscript.setVar1, pyscript.setVar1.attr1, state.get("pyscript.setVar1.attr2"),
241+
pyscript.setVar2, state.get("pyscript.setVar3")]
161242
162243
seq_num += 1
163244
#

0 commit comments

Comments
 (0)