Skip to content

Commit bd45610

Browse files
committed
Made #495 service repsonse backward compatible
1 parent 9522331 commit bd45610

File tree

5 files changed

+65
-30
lines changed

5 files changed

+65
-30
lines changed

custom_components/pyscript/const.py

+15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
"""Define pyscript-wide constants."""
22

3+
#
4+
# 2023.7 supports service response; handle older versions by defaulting enum
5+
# Should eventually deprecate this and just use SupportsResponse import
6+
#
7+
try:
8+
from homeassistant.core import SupportsResponse
9+
10+
SERVICE_RESPONSE_NONE = SupportsResponse.NONE
11+
SERVICE_RESPONSE_OPTIONAL = SupportsResponse.OPTIONAL
12+
SERVICE_RESPONSE_ONLY = SupportsResponse.ONLY
13+
except ImportError:
14+
SERVICE_RESPONSE_NONE = None
15+
SERVICE_RESPONSE_OPTIONAL = None
16+
SERVICE_RESPONSE_ONLY = None
17+
318
DOMAIN = "pyscript"
419

520
CONFIG_ENTRY = "config_entry"

custom_components/pyscript/eval.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import yaml
1717

18-
from homeassistant.core import SupportsResponse
1918
from homeassistant.const import SERVICE_RELOAD
2019
from homeassistant.helpers.service import async_set_service_schema
2120

@@ -26,6 +25,7 @@
2625
DOMAIN,
2726
LOGGER_PATH,
2827
SERVICE_JUPYTER_KERNEL_START,
28+
SERVICE_RESPONSE_NONE,
2929
)
3030
from .function import Function
3131
from .state import State
@@ -505,7 +505,11 @@ async def do_service_call(func, ast_ctx, data):
505505
if name in (SERVICE_RELOAD, SERVICE_JUPYTER_KERNEL_START):
506506
raise SyntaxError(f"{exc_mesg}: @service conflicts with builtin service")
507507
Function.service_register(
508-
trig_ctx_name, domain, name, pyscript_service_factory(func_name, self), dec_kwargs.get("supports_response", SupportsResponse.NONE)
508+
trig_ctx_name,
509+
domain,
510+
name,
511+
pyscript_service_factory(func_name, self),
512+
dec_kwargs.get("supports_response", SERVICE_RESPONSE_NONE),
509513
)
510514
async_set_service_schema(Function.hass, domain, name, service_desc)
511515
self.trigger_service.add(srv_name)

custom_components/pyscript/function.py

+36-20
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import logging
55
import traceback
66

7-
from homeassistant.core import Context, SupportsResponse
7+
from homeassistant.core import Context
88

9-
from .const import LOGGER_PATH
9+
from .const import LOGGER_PATH, SERVICE_RESPONSE_NONE, SERVICE_RESPONSE_ONLY
1010

1111
_LOGGER = logging.getLogger(LOGGER_PATH + ".function")
1212

@@ -332,14 +332,7 @@ async def service_call(cls, domain, name, **kwargs):
332332
elif default:
333333
hass_args[keyword] = default
334334

335-
if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
336-
hass_args["blocking"] = True
337-
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, name) == SupportsResponse.ONLY:
338-
hass_args["return_response"] = True
339-
if "blocking" not in hass_args:
340-
hass_args["blocking"] = True
341-
342-
return await cls.hass.services.async_call(domain, name, kwargs, **hass_args)
335+
return await cls.hass_services_async_call(domain, name, kwargs, **hass_args)
343336

344337
@classmethod
345338
async def service_completions(cls, root):
@@ -413,19 +406,35 @@ async def service_call(*args, **kwargs):
413406
if len(args) != 0:
414407
raise TypeError(f"service {domain}.{service} takes only keyword arguments")
415408

416-
if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
417-
hass_args["blocking"] = True
418-
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY:
419-
hass_args["return_response"] = True
420-
if "blocking" not in hass_args:
421-
hass_args["blocking"] = True
422-
423-
return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
409+
return await cls.hass_services_async_call(domain, service, kwargs, **hass_args)
424410

425411
return service_call
426412

427413
return service_call_factory(domain, service)
428414

415+
@classmethod
416+
async def hass_services_async_call(cls, domain, service, kwargs, **hass_args):
417+
"""Call a hass async service."""
418+
if SERVICE_RESPONSE_ONLY is None:
419+
# backwards compatibility < 2023.7
420+
await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
421+
else:
422+
# allow service responses >= 2023.7
423+
if (
424+
"return_response" in hass_args
425+
and hass_args["return_response"]
426+
and "blocking" not in hass_args
427+
):
428+
hass_args["blocking"] = True
429+
elif (
430+
"return_response" not in hass_args
431+
and cls.hass.services.supports_response(domain, service) == SERVICE_RESPONSE_ONLY
432+
):
433+
hass_args["return_response"] = True
434+
if "blocking" not in hass_args:
435+
hass_args["blocking"] = True
436+
return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
437+
429438
@classmethod
430439
async def run_coro(cls, coro, ast_ctx=None):
431440
"""Run coroutine task and update unique task on start and exit."""
@@ -466,7 +475,9 @@ def create_task(cls, coro, ast_ctx=None):
466475
return cls.hass.loop.create_task(cls.run_coro(coro, ast_ctx=ast_ctx))
467476

468477
@classmethod
469-
def service_register(cls, global_ctx_name, domain, service, callback, supports_response = SupportsResponse.NONE):
478+
def service_register(
479+
cls, global_ctx_name, domain, service, callback, supports_response=SERVICE_RESPONSE_NONE
480+
):
470481
"""Register a new service callback."""
471482
key = f"{domain}.{service}"
472483
if key not in cls.service_cnt:
@@ -478,7 +489,12 @@ def service_register(cls, global_ctx_name, domain, service, callback, supports_r
478489
f"{global_ctx_name}: can't register service {key}; already defined in {cls.service2global_ctx[key]}"
479490
)
480491
cls.service_cnt[key] += 1
481-
cls.hass.services.async_register(domain, service, callback, supports_response = supports_response)
492+
if SERVICE_RESPONSE_ONLY is None:
493+
# backwards compatibility < 2023.7
494+
cls.hass.services.async_register(domain, service, callback)
495+
else:
496+
# allow service responses >= 2023.7
497+
cls.hass.services.async_register(domain, service, callback, supports_response=supports_response)
482498

483499
@classmethod
484500
def service_remove(cls, global_ctx_name, domain, service):

custom_components/pyscript/state.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import logging
55

6-
from homeassistant.core import Context, SupportsResponse
6+
from homeassistant.core import Context
77
from homeassistant.helpers.restore_state import DATA_RESTORE_STATE
88
from homeassistant.helpers.service import async_get_all_descriptions
99

@@ -308,13 +308,7 @@ async def service_call(*args, **kwargs):
308308
elif len(args) != 0:
309309
raise TypeError(f"service {domain}.{service} takes no positional arguments")
310310

311-
if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
312-
hass_args["blocking"] = True
313-
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY:
314-
hass_args["return_response"] = True
315-
if "blocking" not in hass_args:
316-
hass_args["blocking"] = True
317-
311+
# return await Function.hass_services_async_call(domain, service, kwargs, **hass_args)
318312
return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
319313

320314
return service_call

tests/test_state.py

+6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66

7+
from custom_components.pyscript.function import Function
78
from custom_components.pyscript.state import State
89
from homeassistant.core import Context
910
from homeassistant.helpers.state import State as HassState
@@ -23,6 +24,7 @@ async def test_service_call(hass):
2324
hass.services, "async_call"
2425
) as call:
2526
State.init(hass)
27+
Function.init(hass)
2628
await State.get_service_params()
2729

2830
func = State.get("test.entity.test")
@@ -45,3 +47,7 @@ async def test_service_call(hass):
4547
{"other_service_data": "test", "entity_id": "test.entity"},
4648
)
4749
assert call.call_args[1] == {"context": Context(id="test"), "blocking": False}
50+
51+
# Stop all tasks to avoid conflicts with other tests
52+
await Function.waiter_stop()
53+
await Function.reaper_stop()

0 commit comments

Comments
 (0)