Skip to content

Commit 212cc50

Browse files
committed
Added support for service responses when calling or creating services.
1 parent 716ffd7 commit 212cc50

File tree

3 files changed

+40
-10
lines changed

3 files changed

+40
-10
lines changed

custom_components/pyscript/eval.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import yaml
1717

18+
from homeassistant.core import SupportsResponse
1819
from homeassistant.const import SERVICE_RELOAD
1920
from homeassistant.helpers.service import async_set_service_schema
2021

@@ -377,6 +378,7 @@ async def trigger_init(self, trig_ctx, func_name):
377378
"time_trigger": {"kwargs": {dict}},
378379
"task_unique": {"kill_me": {bool, int}},
379380
"time_active": {"hold_off": {int, float}},
381+
"service": {"supports_response": {str}},
380382
"state_trigger": {
381383
"kwargs": {dict},
382384
"state_hold": {int, float},
@@ -485,11 +487,14 @@ async def pyscript_service_handler(call):
485487
func_args.update(call.data)
486488

487489
async def do_service_call(func, ast_ctx, data):
488-
await func.call(ast_ctx, **data)
490+
retval = await func.call(ast_ctx, **data)
489491
if ast_ctx.get_exception_obj():
490492
ast_ctx.get_logger().error(ast_ctx.get_exception_long())
493+
return retval
491494

492-
Function.create_task(do_service_call(func, ast_ctx, func_args))
495+
task = Function.create_task(do_service_call(func, ast_ctx, func_args))
496+
await task
497+
return task.result()
493498

494499
return pyscript_service_handler
495500

@@ -500,7 +505,7 @@ async def do_service_call(func, ast_ctx, data):
500505
if name in (SERVICE_RELOAD, SERVICE_JUPYTER_KERNEL_START):
501506
raise SyntaxError(f"{exc_mesg}: @service conflicts with builtin service")
502507
Function.service_register(
503-
trig_ctx_name, domain, name, pyscript_service_factory(func_name, self)
508+
trig_ctx_name, domain, name, pyscript_service_factory(func_name, self), dec_kwargs.get("supports_response", SupportsResponse.NONE)
504509
)
505510
async_set_service_schema(Function.hass, domain, name, service_desc)
506511
self.trigger_service.add(srv_name)

custom_components/pyscript/function.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import traceback
66

7-
from homeassistant.core import Context
7+
from homeassistant.core import Context, SupportsResponse
88

99
from .const import LOGGER_PATH
1010

@@ -324,14 +324,22 @@ async def service_call(cls, domain, name, **kwargs):
324324
for keyword, typ, default in [
325325
("context", [Context], cls.task2context.get(curr_task, None)),
326326
("blocking", [bool], None),
327+
("return_response", [bool], None),
327328
("limit", [float, int], None),
328329
]:
329330
if keyword in kwargs and type(kwargs[keyword]) in typ:
330331
hass_args[keyword] = kwargs.pop(keyword)
331332
elif default:
332333
hass_args[keyword] = default
333334

334-
await cls.hass.services.async_call(domain, name, kwargs, **hass_args)
335+
if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
336+
hass_args["blocking"] = True
337+
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, name) == SupportsResponse.ONLY:
338+
hass_args["return_response"] = True
339+
if "blocking" not in hass_args:
340+
hass_args["blocking"] = True
341+
342+
return await cls.hass.services.async_call(domain, name, kwargs, **hass_args)
335343

336344
@classmethod
337345
async def service_completions(cls, root):
@@ -394,6 +402,7 @@ async def service_call(*args, **kwargs):
394402
for keyword, typ, default in [
395403
("context", [Context], cls.task2context.get(curr_task, None)),
396404
("blocking", [bool], None),
405+
("return_response", [bool], None),
397406
("limit", [float, int], None),
398407
]:
399408
if keyword in kwargs and type(kwargs[keyword]) in typ:
@@ -404,7 +413,14 @@ async def service_call(*args, **kwargs):
404413
if len(args) != 0:
405414
raise TypeError(f"service {domain}.{service} takes only keyword arguments")
406415

407-
await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
416+
if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
417+
hass_args["blocking"] = True
418+
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY:
419+
hass_args["return_response"] = True
420+
if "blocking" not in hass_args:
421+
hass_args["blocking"] = True
422+
423+
return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
408424

409425
return service_call
410426

@@ -450,7 +466,7 @@ def create_task(cls, coro, ast_ctx=None):
450466
return cls.hass.loop.create_task(cls.run_coro(coro, ast_ctx=ast_ctx))
451467

452468
@classmethod
453-
def service_register(cls, global_ctx_name, domain, service, callback):
469+
def service_register(cls, global_ctx_name, domain, service, callback, supports_response = SupportsResponse.NONE):
454470
"""Register a new service callback."""
455471
key = f"{domain}.{service}"
456472
if key not in cls.service_cnt:
@@ -462,7 +478,7 @@ def service_register(cls, global_ctx_name, domain, service, callback):
462478
f"{global_ctx_name}: can't register service {key}; already defined in {cls.service2global_ctx[key]}"
463479
)
464480
cls.service_cnt[key] += 1
465-
cls.hass.services.async_register(domain, service, callback)
481+
cls.hass.services.async_register(domain, service, callback, supports_response = supports_response)
466482

467483
@classmethod
468484
def service_remove(cls, global_ctx_name, domain, service):

custom_components/pyscript/state.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import logging
55

6-
from homeassistant.core import Context
6+
from homeassistant.core import Context, SupportsResponse
77
from homeassistant.helpers.restore_state import DATA_RESTORE_STATE
88
from homeassistant.helpers.service import async_get_all_descriptions
99

@@ -290,6 +290,7 @@ async def service_call(*args, **kwargs):
290290
for keyword, typ, default in [
291291
("context", [Context], Function.task2context.get(curr_task, None)),
292292
("blocking", [bool], None),
293+
("return_response", [bool], None),
293294
("limit", [float, int], None),
294295
]:
295296
if keyword in kwargs and type(kwargs[keyword]) in typ:
@@ -306,7 +307,15 @@ async def service_call(*args, **kwargs):
306307
kwargs[param_name] = args[0]
307308
elif len(args) != 0:
308309
raise TypeError(f"service {domain}.{service} takes no positional arguments")
309-
await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
310+
311+
if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
312+
hass_args["blocking"] = True
313+
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY:
314+
hass_args["return_response"] = True
315+
if "blocking" not in hass_args:
316+
hass_args["blocking"] = True
317+
318+
return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
310319

311320
return service_call
312321

0 commit comments

Comments
 (0)