4
4
import logging
5
5
import traceback
6
6
7
- from homeassistant .core import Context
7
+ from homeassistant .core import Context , SupportsResponse
8
8
9
9
from .const import LOGGER_PATH
10
10
@@ -324,14 +324,22 @@ async def service_call(cls, domain, name, **kwargs):
324
324
for keyword , typ , default in [
325
325
("context" , [Context ], cls .task2context .get (curr_task , None )),
326
326
("blocking" , [bool ], None ),
327
+ ("return_response" , [bool ], None ),
327
328
("limit" , [float , int ], None ),
328
329
]:
329
330
if keyword in kwargs and type (kwargs [keyword ]) in typ :
330
331
hass_args [keyword ] = kwargs .pop (keyword )
331
332
elif default :
332
333
hass_args [keyword ] = default
333
334
334
- await cls .hass .services .async_call (domain , name , kwargs , ** hass_args )
335
+ if "return_response" in hass_args and hass_args ["return_response" ] == True and "blocking" not in hass_args :
336
+ hass_args ["blocking" ] = True
337
+ elif "return_response" not in hass_args and cls .hass .services .supports_response (domain , name ) == SupportsResponse .ONLY :
338
+ hass_args ["return_response" ] = True
339
+ if "blocking" not in hass_args :
340
+ hass_args ["blocking" ] = True
341
+
342
+ return await cls .hass .services .async_call (domain , name , kwargs , ** hass_args )
335
343
336
344
@classmethod
337
345
async def service_completions (cls , root ):
@@ -394,6 +402,7 @@ async def service_call(*args, **kwargs):
394
402
for keyword , typ , default in [
395
403
("context" , [Context ], cls .task2context .get (curr_task , None )),
396
404
("blocking" , [bool ], None ),
405
+ ("return_response" , [bool ], None ),
397
406
("limit" , [float , int ], None ),
398
407
]:
399
408
if keyword in kwargs and type (kwargs [keyword ]) in typ :
@@ -404,7 +413,14 @@ async def service_call(*args, **kwargs):
404
413
if len (args ) != 0 :
405
414
raise TypeError (f"service { domain } .{ service } takes only keyword arguments" )
406
415
407
- await cls .hass .services .async_call (domain , service , kwargs , ** hass_args )
416
+ if "return_response" in hass_args and hass_args ["return_response" ] == True and "blocking" not in hass_args :
417
+ hass_args ["blocking" ] = True
418
+ elif "return_response" not in hass_args and cls .hass .services .supports_response (domain , service ) == SupportsResponse .ONLY :
419
+ hass_args ["return_response" ] = True
420
+ if "blocking" not in hass_args :
421
+ hass_args ["blocking" ] = True
422
+
423
+ return await cls .hass .services .async_call (domain , service , kwargs , ** hass_args )
408
424
409
425
return service_call
410
426
@@ -450,7 +466,7 @@ def create_task(cls, coro, ast_ctx=None):
450
466
return cls .hass .loop .create_task (cls .run_coro (coro , ast_ctx = ast_ctx ))
451
467
452
468
@classmethod
453
- def service_register (cls , global_ctx_name , domain , service , callback ):
469
+ def service_register (cls , global_ctx_name , domain , service , callback , supports_response = SupportsResponse . NONE ):
454
470
"""Register a new service callback."""
455
471
key = f"{ domain } .{ service } "
456
472
if key not in cls .service_cnt :
@@ -462,7 +478,7 @@ def service_register(cls, global_ctx_name, domain, service, callback):
462
478
f"{ global_ctx_name } : can't register service { key } ; already defined in { cls .service2global_ctx [key ]} "
463
479
)
464
480
cls .service_cnt [key ] += 1
465
- cls .hass .services .async_register (domain , service , callback )
481
+ cls .hass .services .async_register (domain , service , callback , supports_response = supports_response )
466
482
467
483
@classmethod
468
484
def service_remove (cls , global_ctx_name , domain , service ):
0 commit comments