diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 59ceacc..2d34021 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,12 +3,12 @@ name: CI # yamllint disable-line rule:truthy on: push: - pull_request: ~ env: - CACHE_VERSION: 1 - DEFAULT_PYTHON: 3.8 - PRE_COMMIT_HOME: ~/.cache/pre-commit + CODE_FOLDER: zigpy_zboss + CACHE_VERSION: 3 + DEFAULT_PYTHON: 3.10.8 + PRE_COMMIT_CACHE_PATH: ~/.cache/pre-commit jobs: # Separate job to pre-populate the base dependency cache @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9, "3.10", "3.11"] + python-version: ["3.10.8", "3.11.0", "3.12"] steps: - name: Check out code from GitHub uses: actions/checkout@v2 @@ -44,6 +44,7 @@ jobs: python -m venv venv . venv/bin/activate pip install -U pip setuptools pre-commit + pip install -r requirements_test.txt pip install -e . pre-commit: @@ -76,7 +77,7 @@ jobs: id: cache-precommit uses: actions/cache@v2 with: - path: ${{ env.PRE_COMMIT_HOME }} + path: ${{ env.PRE_COMMIT_CACHE_PATH }} key: | ${{ env.CACHE_VERSION}}-${{ runner.os }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} restore-keys: | @@ -86,6 +87,15 @@ jobs: run: | . venv/bin/activate pre-commit install-hooks + - name: Cache pre-commit environment + uses: actions/cache/save@v3 + with: + path: ${{ env.PRE_COMMIT_CACHE_PATH }} + key: ${{ steps.cache-precommit.outputs.cache-primary-key }} + - name: Lint and static analysis + run: | + . venv/bin/activate + pre-commit run --show-diff-on-failure --color=always --all-files lint-flake8: name: Check flake8 @@ -117,7 +127,7 @@ jobs: id: cache-precommit uses: actions/cache@v2 with: - path: ${{ env.PRE_COMMIT_HOME }} + path: ${{ env.PRE_COMMIT_CACHE_PATH }} key: | ${{ env.CACHE_VERSION}}-${{ runner.os }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} - name: Fail job if cache restore failed @@ -132,3 +142,54 @@ jobs: run: | . venv/bin/activate pre-commit run --hook-stage manual flake8 --all-files + + pytest: + runs-on: ubuntu-latest + needs: prepare-base + strategy: + matrix: + python-version: ["3.10.8", "3.11.0", "3.12"] + name: >- + Run tests Python ${{ matrix.python-version }} + steps: + - name: Check out code from GitHub + uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + id: python + with: + python-version: ${{ matrix.python-version }} + - name: Restore base Python virtual environment + id: cache-venv + uses: actions/cache@v2 + with: + path: venv + key: >- + ${{ env.CACHE_VERSION}}-${{ runner.os }}-base-venv-${{ + steps.python.outputs.python-version }}-${{ + hashFiles('pyproject.toml') }} + - name: Fail job if Python cache restore failed + if: steps.cache-venv.outputs.cache-hit != 'true' + run: | + echo "Failed to restore Python virtual environment from cache" + exit 1 + - name: Register Python problem matcher + run: | + echo "::add-matcher::.github/workflows/matchers/python.json" + - name: Install Pytest Annotation plugin + run: | + . venv/bin/activate + # Ideally this should be part of our dependencies + # However this plugin is fairly new and doesn't run correctly + # on a non-GitHub environment. + pip install pytest-github-actions-annotate-failures + - name: Run pytest + run: | + . venv/bin/activate + pytest \ + -qq \ + --timeout=20 \ + --durations=10 \ + -o console_output_style=count \ + -p no:sugar \ + tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4546342..1829a4a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,14 +1,15 @@ repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: debug-statements + - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 5.0.4 hooks: - id: flake8 - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 hooks: - - id: debug-statements - - id: no-commit-to-branch - args: - - --branch=dev - - --branch=main - - --branch=rc \ No newline at end of file + - id: isort diff --git a/requirements_test.txt b/requirements_test.txt new file mode 100644 index 0000000..85c2944 --- /dev/null +++ b/requirements_test.txt @@ -0,0 +1,6 @@ +pytest>=7.3.1 +pytest-asyncio>=0.21.0 +pytest-timeout>=2.1.0 +pytest-mock>=3.10.0 +pytest-cov>=4.1.0 +flake8==5.0.4 \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..b4c17f5 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for zigpy-zboss.""" diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 0000000..ea431d6 --- /dev/null +++ b/tests/api/__init__.py @@ -0,0 +1 @@ +"""Tests for api.""" diff --git a/tests/api/test_connect.py b/tests/api/test_connect.py new file mode 100644 index 0000000..dcb981f --- /dev/null +++ b/tests/api/test_connect.py @@ -0,0 +1,73 @@ +"""Test cases for zigpy-zboss API connect/close methods.""" +import pytest + +from zigpy_zboss.api import ZBOSS + +from ..conftest import BaseServerZBOSS, config_for_port_path + + +@pytest.mark.asyncio +async def test_connect_no_test(make_zboss_server): + """Test that ZBOSS.connect() can connect.""" + zboss_server = make_zboss_server(server_cls=BaseServerZBOSS) + zboss = ZBOSS(config_for_port_path(zboss_server.port_path)) + + await zboss.connect() + + # Nothing will be sent + assert zboss_server._uart.data_received.call_count == 0 + + zboss.close() + + +@pytest.mark.asyncio +async def test_api_close(connected_zboss, mocker): + """Test that ZBOSS.close() properly cleans up the object.""" + zboss, zboss_server = connected_zboss + uart = zboss._uart + mocker.spy(uart, "close") + + # add some dummy listeners, should be cleared on close + zboss._listeners = { + 'listener1': [mocker.Mock()], 'listener2': [mocker.Mock()] + } + + zboss.close() + + # Make sure our UART was actually closed + assert zboss._uart is None + assert zboss._app is None + assert uart.close.call_count == 1 + + # ZBOSS.close should not throw any errors if called multiple times + zboss.close() + zboss.close() + + def dict_minus(d, minus): + return {k: v for k, v in d.items() if k not in minus} + + ignored_keys = [ + "_blocking_request_lock", + "_reset_uart_reconnect", + "_disconnected_event", + "nvram", + "version" + ] + + # Closing ZBOSS should reset it completely to that of a fresh object + # We have to ignore our mocked method and the lock + zboss2 = ZBOSS(zboss._config) + assert ( + zboss2._blocking_request_lock.locked() + == zboss._blocking_request_lock.locked() + ) + assert dict_minus(zboss.__dict__, ignored_keys) == dict_minus( + zboss2.__dict__, ignored_keys + ) + + zboss2.close() + zboss2.close() + + assert dict_minus(zboss.__dict__, ignored_keys) == dict_minus( + zboss2.__dict__, ignored_keys + ) diff --git a/tests/api/test_listeners.py b/tests/api/test_listeners.py new file mode 100644 index 0000000..f36e3d8 --- /dev/null +++ b/tests/api/test_listeners.py @@ -0,0 +1,187 @@ +"""Test listeners.""" +import asyncio +from unittest.mock import call + +import pytest + +import zigpy_zboss.commands as c +import zigpy_zboss.types as t +from zigpy_zboss.api import IndicationListener, OneShotResponseListener + + +@pytest.mark.asyncio +async def test_resolve(event_loop, mocker): + """Test listener resolution.""" + callback = mocker.Mock() + callback_listener = IndicationListener( + [c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + )], callback + ) + + future = event_loop.create_future() + one_shot_listener = OneShotResponseListener([c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + )], future) + + match = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + no_match = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3, + ) + + assert callback_listener.resolve(match) + assert not callback_listener.resolve(no_match) + assert callback_listener.resolve(match) + assert not callback_listener.resolve(no_match) + + assert one_shot_listener.resolve(match) + assert not one_shot_listener.resolve(no_match) + + callback.assert_has_calls([call(match), call(match)]) + assert callback.call_count == 2 + + assert (await future) == match + + # Cancelling a callback will have no effect + assert not callback_listener.cancel() + + # Cancelling a one-shot listener does not throw any errors + assert one_shot_listener.cancel() + assert one_shot_listener.cancel() + assert one_shot_listener.cancel() + + +@pytest.mark.asyncio +async def test_cancel(event_loop): + """Test cancelling one-shot listener.""" + # Cancelling a one-shot listener prevents it from being fired + future = event_loop.create_future() + one_shot_listener = OneShotResponseListener([c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + )], future) + one_shot_listener.cancel() + + match = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + assert not one_shot_listener.resolve(match) + + with pytest.raises(asyncio.CancelledError): + await future + + +@pytest.mark.asyncio +async def test_multi_cancel(event_loop, mocker): + """Test cancelling indication listener.""" + callback = mocker.Mock() + callback_listener = IndicationListener( + [c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + )], callback + ) + + future = event_loop.create_future() + one_shot_listener = OneShotResponseListener([c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + )], future) + + match = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + no_match = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3, + ) + + assert callback_listener.resolve(match) + assert not callback_listener.resolve(no_match) + + assert one_shot_listener.resolve(match) + assert not one_shot_listener.resolve(no_match) + + callback.assert_called_once_with(match) + assert (await future) == match + + +@pytest.mark.asyncio +async def test_api_cancel_listeners(connected_zboss, mocker): + """Test cancel listeners from api.""" + zboss, zboss_server = connected_zboss + + callback = mocker.Mock() + + zboss.register_indication_listener( + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ), callback + ) + future = zboss.wait_for_responses( + [ + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ), + c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3, + ), + ] + ) + + assert not future.done() + zboss.close() + + with pytest.raises(asyncio.CancelledError): + await future + + # add_done_callback won't be executed immediately + await asyncio.sleep(0.1) + + # only one shot listerner is cleared + # we do not remove indication listeners + # because + assert len(zboss._listeners) == 0 diff --git a/tests/api/test_request.py b/tests/api/test_request.py new file mode 100644 index 0000000..53a6344 --- /dev/null +++ b/tests/api/test_request.py @@ -0,0 +1,194 @@ +"""Test api requests.""" +import asyncio +import logging + +import async_timeout +import pytest + +import zigpy_zboss.commands as c +import zigpy_zboss.types as t +from zigpy_zboss.frames import (ZBNCP_LL_BODY_SIZE_MAX, Frame, HLPacket, + LLHeader) + + +@pytest.mark.asyncio +async def test_cleanup_timeout_internal(connected_zboss): + """Test internal cleanup timeout.""" + zboss, zboss_server = connected_zboss + + assert not any(zboss._listeners.values()) + + with pytest.raises(asyncio.TimeoutError): + await zboss.request(c.NcpConfig.GetModuleVersion.Req(TSN=1), 0.1) + + # We should be cleaned up + assert not any(zboss._listeners.values()) + + +@pytest.mark.asyncio +async def test_cleanup_timeout_external(connected_zboss): + """Test external cleanup timeout.""" + zboss, zboss_server = connected_zboss + + assert not any(zboss._listeners.values()) + + # This request will timeout because we didn't send anything back + with pytest.raises(asyncio.TimeoutError): + async with async_timeout.timeout(0.1): + await zboss.request(c.NcpConfig.GetModuleVersion.Req(TSN=1), 10) + + # We should be cleaned up + assert not any(zboss._listeners.values()) + + +@pytest.mark.asyncio +async def test_zboss_request_kwargs(connected_zboss, event_loop): + """Test zboss request.""" + zboss, zboss_server = connected_zboss + + # Invalid format + with pytest.raises(KeyError): + await zboss.request(c.NcpConfig.GetModuleVersion.Req(TSNT=1), 10) + + # Valid format, invalid name + with pytest.raises(KeyError): + await zboss.request(c.NcpConfig.GetModuleVersion.Req(TsN=1), 10) + + # Valid format, valid name + ping_rsp = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3 + ) + + async def send_ping_response(): + await zboss_server.send(ping_rsp) + + event_loop.call_soon(asyncio.create_task, send_ping_response()) + + assert ( + await zboss.request(c.NcpConfig.GetModuleVersion.Req(TSN=1), 2) + ) == ping_rsp + + # You cannot send anything but requests + with pytest.raises(ValueError): + await zboss.request(c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + )) + + # You cannot send indications + with pytest.raises(ValueError): + await zboss.request( + c.NWK.NwkLeaveInd.Ind(partial=True) + ) + + +@pytest.mark.asyncio +async def test_zboss_req_rsp(connected_zboss, event_loop): + """Test zboss request/response.""" + zboss, zboss_server = connected_zboss + + # Each SREQ must have a corresponding SRSP, so this will fail + with pytest.raises(asyncio.TimeoutError): + async with async_timeout.timeout(0.5): + await zboss.request(c.NcpConfig.GetModuleVersion.Req(TSN=1), 10) + + # This will work + ping_rsp = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3 + ) + + async def send_ping_response(): + await zboss_server.send(ping_rsp) + + event_loop.call_soon(asyncio.create_task, send_ping_response()) + + await zboss.request(c.NcpConfig.GetModuleVersion.Req(TSN=1), 10) + + +@pytest.mark.asyncio +async def test_zboss_unknown_frame(connected_zboss, caplog): + """Test zboss unknown frame.""" + zboss, _ = connected_zboss + hl_header = t.HLCommonHeader( + version=0x0121, type=0xFFFF, id=0x123421 + ) + hl_packet = HLPacket(header=hl_header, data=t.Bytes()) + ll_header = LLHeader(flags=0xC0, size=0x0A) + frame = Frame(ll_header=ll_header, hl_packet=hl_packet) + + caplog.set_level(logging.DEBUG) + zboss.frame_received(frame) + + # Unknown frames are logged in their entirety but an error is not thrown + assert repr(frame) in caplog.text + + +@pytest.mark.asyncio +async def test_send_failure_when_disconnected(connected_zboss): + """Test send failure when disconnected.""" + zboss, _ = connected_zboss + zboss._uart = None + + with pytest.raises(RuntimeError) as e: + await zboss.request(c.NcpConfig.GetModuleVersion.Req(TSN=1), 10) + + assert "Coordinator is disconnected" in str(e.value) + zboss.close() + + +@pytest.mark.asyncio +async def test_frame_merge(connected_zboss, mocker): + """Test frame fragmentation.""" + zboss, zboss_server = connected_zboss + + large_data = b"a" * (ZBNCP_LL_BODY_SIZE_MAX * 2 + 50) + command = c.NcpConfig.ReadNVRAM.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + NVRAMVersion=t.uint16_t(0x0000), + DatasetId=t.DatasetId(0x0000), + Dataset=t.NVRAMDataset(large_data), + DatasetVersion=t.uint16_t(0x0000) + ) + frame = command.to_frame() + + callback = mocker.Mock() + + zboss.register_indication_listener( + c.NcpConfig.ReadNVRAM.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + NVRAMVersion=t.uint16_t(0x0000), + DatasetId=t.DatasetId(0x0000), + Dataset=t.NVRAMDataset(large_data), + DatasetVersion=t.uint16_t(0x0000) + ), callback + ) + + # Perform fragmentation + fragments = frame.handle_tx_fragmentation() + assert len(fragments) > 1 + + # Receiving first and middle fragments + for fragment in fragments[:-1]: + assert not zboss.frame_received(fragment) + + # receiving the last fragment + assert zboss.frame_received(fragments[-1]) + + # Check the state of _rx_fragments after merging + assert zboss._rx_fragments == [] diff --git a/tests/api/test_response.py b/tests/api/test_response.py new file mode 100644 index 0000000..c810f53 --- /dev/null +++ b/tests/api/test_response.py @@ -0,0 +1,568 @@ +"""Test response.""" +import asyncio + +import async_timeout +import pytest + +import zigpy_zboss.commands as c +import zigpy_zboss.types as t +from zigpy_zboss.utils import deduplicate_commands + + +@pytest.mark.asyncio +async def test_responses(connected_zboss): + """Test responses.""" + zboss, zboss_server = connected_zboss + + assert not any(zboss._listeners.values()) + + future = zboss.wait_for_response( + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + )) + + assert any(zboss._listeners.values()) + + response = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + await zboss_server.send(response) + + assert (await future) == response + + # Our listener will have been cleaned up after a step + await asyncio.sleep(0) + assert not any(zboss._listeners.values()) + + +@pytest.mark.asyncio +async def test_responses_multiple(connected_zboss): + """Test multiple responses.""" + zboss, _ = connected_zboss + + assert not any(zboss._listeners.values()) + + future1 = zboss.wait_for_response(c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + )) + future2 = zboss.wait_for_response(c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + )) + future3 = zboss.wait_for_response(c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + )) + + response = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + zboss.frame_received(response.to_frame()) + + await future1 + await asyncio.sleep(0) + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert not future2.done() + assert not future3.done() + + assert any(zboss._listeners.values()) + + +@pytest.mark.asyncio +async def test_response_timeouts(connected_zboss): + """Test future response timeouts.""" + zboss, _ = connected_zboss + + response = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + + async def send_soon(delay): + await asyncio.sleep(delay) + zboss.frame_received(response.to_frame()) + + asyncio.create_task(send_soon(0.1)) + + async with async_timeout.timeout(0.5): + assert (await zboss.wait_for_response(c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + ))) == response + + # The response was successfully received so we + # should have no outstanding listeners + await asyncio.sleep(0) + assert not any(zboss._listeners.values()) + + asyncio.create_task(send_soon(0.6)) + + with pytest.raises(asyncio.TimeoutError): + async with async_timeout.timeout(0.5): + assert ( + await zboss.wait_for_response( + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + )) + ) == response + + # Our future still completed, albeit unsuccessfully. + # We should have no leaked listeners here. + assert not any(zboss._listeners.values()) + + +@pytest.mark.asyncio +async def test_response_matching_partial(connected_zboss): + """Test partial response matching.""" + zboss, _ = connected_zboss + + future = zboss.wait_for_response( + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(2), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + ) + ) + + response1 = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + response2 = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(2), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + response3 = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=11, + StatusCat=t.StatusCategory(2), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + + zboss.frame_received(response1.to_frame()) + zboss.frame_received(response2.to_frame()) + zboss.frame_received(response3.to_frame()) + + assert future.done() + assert (await future) == response2 + + +@pytest.mark.asyncio +async def test_response_matching_exact(connected_zboss): + """Test exact response matching.""" + zboss, _ = connected_zboss + + response1 = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + response2 = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(2) + ) + response3 = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=11, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + + future = zboss.wait_for_response(response2) + + zboss.frame_received(response1.to_frame()) + zboss.frame_received(response2.to_frame()) + zboss.frame_received(response3.to_frame()) + + # Future should be immediately resolved + assert future.done() + assert (await future) == response2 + + +@pytest.mark.asyncio +async def test_response_not_matching_out_of_order(connected_zboss): + """Test not matching response.""" + zboss, _ = connected_zboss + + response = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + zboss.frame_received(response.to_frame()) + + future = zboss.wait_for_response(response) + + # This future will never resolve because we were not + # expecting a response and discarded it + assert not future.done() + + +@pytest.mark.asyncio +async def test_wait_responses_empty(connected_zboss): + """Test wait empty response.""" + zboss, _ = connected_zboss + + # You shouldn't be able to wait for an empty list of responses + with pytest.raises(ValueError): + await zboss.wait_for_responses([]) + + +@pytest.mark.asyncio +async def test_response_callback_simple(connected_zboss, event_loop, mocker): + """Test simple response callback.""" + zboss, _ = connected_zboss + + sync_callback = mocker.Mock() + + good_response = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + bad_response = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.ERROR, + DeviceRole=t.DeviceRole(1) + ) + + zboss.register_indication_listener(good_response, sync_callback) + + zboss.frame_received(bad_response.to_frame()) + assert sync_callback.call_count == 0 + + zboss.frame_received(good_response.to_frame()) + sync_callback.assert_called_once_with(good_response) + + +@pytest.mark.asyncio +async def test_response_callbacks(connected_zboss, event_loop, mocker): + """Test response callbacks.""" + zboss, _ = connected_zboss + + sync_callback = mocker.Mock() + bad_sync_callback = mocker.Mock( + side_effect=RuntimeError + ) # Exceptions should not interfere with other callbacks + + async_callback_responses = [] + + # XXX: I can't get AsyncMock().call_count to work, even though + # the callback is definitely being called + async def async_callback(response): + await asyncio.sleep(0) + async_callback_responses.append(response) + + good_response1 = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + good_response2 = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(2) + ) + good_response3 = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3 + ) + bad_response1 = c.ZDO.MgtLeave.Rsp(TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK) + bad_response2 = c.NcpConfig.GetModuleVersion.Req(TSN=1) + + responses = [ + # Duplicating matching responses shouldn't do anything + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + ), + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + ), + # Matching against different response types should also work + c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3 + ), + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ), + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ), + c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=4 + ), + ] + + assert set(deduplicate_commands(responses)) == { + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + ), + c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3 + ), + c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=4 + ), + } + + # We shouldn't see any effects from receiving a frame early + zboss.frame_received(good_response1.to_frame()) + + for callback in [bad_sync_callback, async_callback, sync_callback]: + zboss.register_indication_listeners(responses, callback) + + zboss.frame_received(good_response1.to_frame()) + zboss.frame_received(bad_response1.to_frame()) + zboss.frame_received(good_response2.to_frame()) + zboss.frame_received(bad_response2.to_frame()) + zboss.frame_received(good_response3.to_frame()) + + await asyncio.sleep(0) + + assert sync_callback.call_count == 3 + assert bad_sync_callback.call_count == 3 + + await asyncio.sleep(0.1) + # assert async_callback.call_count == 3 # XXX: this always returns zero + assert len(async_callback_responses) == 3 + + +@pytest.mark.asyncio +async def test_wait_for_responses(connected_zboss, event_loop): + """Test wait for responses.""" + zboss, _ = connected_zboss + + response1 = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + response2 = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(2) + ) + response3 = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3 + ) + response4 = c.ZDO.MgtLeave.Rsp(TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK) + response5 = c.NcpConfig.GetModuleVersion.Req(TSN=1) + + # We shouldn't see any effects from receiving a frame early + zboss.frame_received(response1.to_frame()) + + # Will match the first response1 and detach + future1 = zboss.wait_for_responses( + [c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + ), c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + )] + ) + + # Will match the first response3 and detach + future2 = zboss.wait_for_responses( + [ + c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3 + ), + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(10) + ), + ] + ) + + # Will not match anything + future3 = zboss.wait_for_responses([c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=4 + )]) + + # Will match response1 the second time around + future4 = zboss.wait_for_responses( + [ + # Matching against different response types should also work + c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3 + ), + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ), + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ), + c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=4 + ), + ] + ) + + zboss.frame_received(response1.to_frame()) + zboss.frame_received(response2.to_frame()) + zboss.frame_received(response3.to_frame()) + zboss.frame_received(response4.to_frame()) + zboss.frame_received(response5.to_frame()) + + assert future1.done() + assert future2.done() + assert not future3.done() + assert not future4.done() + + await asyncio.sleep(0) + + zboss.frame_received(response1.to_frame()) + zboss.frame_received(response2.to_frame()) + zboss.frame_received(response3.to_frame()) + zboss.frame_received(response4.to_frame()) + zboss.frame_received(response5.to_frame()) + + assert future1.done() + assert future2.done() + assert not future3.done() + assert future4.done() + + assert (await future1) == response1 + assert (await future2) == response3 + assert (await future4) == response1 + + await asyncio.sleep(0) + + zboss.frame_received(c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=4 + ).to_frame()) + assert future3.done() + assert (await future3) == c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=4 + ) diff --git a/tests/application/__init__.py b/tests/application/__init__.py new file mode 100644 index 0000000..944cf09 --- /dev/null +++ b/tests/application/__init__.py @@ -0,0 +1 @@ +"""Tests for application.""" diff --git a/tests/application/test_connect.py b/tests/application/test_connect.py new file mode 100644 index 0000000..e4bdffe --- /dev/null +++ b/tests/application/test_connect.py @@ -0,0 +1,188 @@ +"""Test application connect.""" +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest + +import zigpy_zboss.commands as c +import zigpy_zboss.config as conf +import zigpy_zboss.types as t +from zigpy_zboss.uart import connect as uart_connect +from zigpy_zboss.zigbee.application import ControllerApplication + +from ..conftest import BaseServerZBOSS, BaseZbossDevice + + +@pytest.mark.asyncio +async def test_no_double_connect(make_zboss_server, mocker): + """Test no multiple connection.""" + zboss_server = make_zboss_server(server_cls=BaseServerZBOSS) + + app = mocker.Mock() + await uart_connect( + conf.SCHEMA_DEVICE( + {conf.CONF_DEVICE_PATH: zboss_server.serial_port} + ), app + ) + + with pytest.raises(RuntimeError): + await uart_connect( + conf.SCHEMA_DEVICE( + {conf.CONF_DEVICE_PATH: zboss_server.serial_port}), app + ) + + +@pytest.mark.asyncio +async def test_leak_detection(make_zboss_server, mocker): + """Test leak detection.""" + zboss_server = make_zboss_server(server_cls=BaseServerZBOSS) + + def count_connected(): + return sum(t._is_connected for t in zboss_server._transports) + + # Opening and closing one connection will keep the count at zero + assert count_connected() == 0 + app = mocker.Mock() + protocol1 = await uart_connect( + conf.SCHEMA_DEVICE({conf.CONF_DEVICE_PATH: zboss_server.serial_port}), + app + ) + assert count_connected() == 1 + protocol1.close() + assert count_connected() == 0 + + # Once more for good measure + protocol2 = await uart_connect( + conf.SCHEMA_DEVICE({conf.CONF_DEVICE_PATH: zboss_server.serial_port}), + app + ) + assert count_connected() == 1 + protocol2.close() + assert count_connected() == 0 + + +@pytest.mark.asyncio +async def test_probe_unsuccessful_slow(make_zboss_server, mocker): + """Test unsuccessful probe.""" + zboss_server = make_zboss_server( + server_cls=BaseServerZBOSS, shorten_delays=False + ) + + # Don't respond to anything + zboss_server._listeners.clear() + + mocker.patch("zigpy_zboss.zigbee.application.PROBE_TIMEOUT", new=0.1) + + assert not ( + await ControllerApplication.probe( + conf.SCHEMA_DEVICE( + {conf.CONF_DEVICE_PATH: zboss_server.serial_port} + ) + ) + ) + + assert not any([t._is_connected for t in zboss_server._transports]) + + +@pytest.mark.asyncio +async def test_probe_successful(make_zboss_server, event_loop): + """Test successful probe.""" + zboss_server = make_zboss_server( + server_cls=BaseServerZBOSS, shorten_delays=False + ) + + # This will work + ping_rsp = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1), + ) + + async def send_ping_response(): + await zboss_server.send(ping_rsp) + + event_loop.call_soon(asyncio.create_task, send_ping_response()) + + assert await ControllerApplication.probe( + conf.SCHEMA_DEVICE({conf.CONF_DEVICE_PATH: zboss_server.serial_port}) + ) + assert not any([t._is_connected for t in zboss_server._transports]) + + +@pytest.mark.asyncio +async def test_probe_multiple(make_application): + """Test multiple probe.""" + # Make sure that our listeners don't get cleaned up after each probe + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + app.close = lambda: None + + config = conf.SCHEMA_DEVICE( + {conf.CONF_DEVICE_PATH: zboss_server.serial_port} + ) + + assert await app.probe(config) + assert await app.probe(config) + assert await app.probe(config) + assert await app.probe(config) + + assert not any([t._is_connected for t in zboss_server._transports]) + + +@pytest.mark.asyncio +async def test_shutdown_from_app(mocker, make_application, event_loop): + """Test shutdown from application.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + await app.startup(auto_form=False) + + # It gets deleted but we save a reference to it + transport = app._api._uart._transport + mocker.spy(transport, "close") + + # Close the connection application-side + await app.shutdown() + + # And the serial connection should have been closed + assert transport.close.call_count >= 1 + + +@pytest.mark.asyncio +async def test_clean_shutdown(make_application): + """Test clean shutdown.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + # This should not throw + await app.shutdown() + + assert app._api is None + + +@pytest.mark.asyncio +async def test_multiple_shutdown(make_application): + """Test multiple shutdown.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + await app.shutdown() + await app.shutdown() + await app.shutdown() + + +@pytest.mark.asyncio +@patch( + "zigpy_zboss.zigbee.application.ControllerApplication._watchdog_period", + new=0.1 +) +async def test_watchdog(make_application): + """Test the watchdog.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + app._watchdog_feed = AsyncMock(wraps=app._watchdog_feed) + + await app.startup(auto_form=False) + await asyncio.sleep(0.6) + assert len(app._watchdog_feed.mock_calls) >= 5 + + await app.shutdown() diff --git a/tests/application/test_join.py b/tests/application/test_join.py new file mode 100644 index 0000000..26d0020 --- /dev/null +++ b/tests/application/test_join.py @@ -0,0 +1,134 @@ +"""Test application device joining.""" +import asyncio + +import pytest +import zigpy.device +import zigpy.types +import zigpy.util + +import zigpy_zboss.commands as c +import zigpy_zboss.types as t + +from ..conftest import BaseZbossDevice + + +@pytest.mark.asyncio +async def test_permit_join(mocker, make_application): + """Test permit join.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + permit_join_coordinator = zboss_server.reply_once_to( + request=c.ZDO.PermitJoin.Req( + TSN=123, + DestNWK=t.NWK(0x0000), + PermitDuration=t.uint8_t(10), + TCSignificance=t.uint8_t(0x01), + ), + responses=[ + c.ZDO.PermitJoin.Rsp( + TSN=123, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + ), + ], + ) + + await app.startup(auto_form=False) + await app.permit(time_s=10) + + await asyncio.sleep(0.1) + + assert permit_join_coordinator.done() + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_join_coordinator(make_application): + """Test coordinator join.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + # Handle us opening joins on the coordinator + permit_join_coordinator = zboss_server.reply_once_to( + request=c.ZDO.PermitJoin.Req( + TSN=123, + DestNWK=t.NWK(0x0000), + PermitDuration=t.uint8_t(60), + TCSignificance=t.uint8_t(0x01), + partial=True + ), + responses=[ + c.ZDO.PermitJoin.Rsp( + TSN=123, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + ), + ], + ) + + await app.startup(auto_form=False) + await app.permit(node=app.state.node_info.ieee) + + await permit_join_coordinator + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_join_device(make_application): + """Test device join.""" + ieee = t.EUI64.convert("EC:1B:BD:FF:FE:54:4F:40") + nwk = 0x1234 + + app, zboss_server = make_application(server_cls=BaseZbossDevice) + app.add_initialized_device(ieee=ieee, nwk=nwk) + + permit_join = zboss_server.reply_once_to( + request=c.ZDO.PermitJoin.Req( + TSN=123, + DestNWK=t.NWK(zigpy.types.t.BroadcastAddress.RX_ON_WHEN_IDLE), + PermitDuration=t.uint8_t(60), + TCSignificance=t.uint8_t(0), + ), + responses=[ + c.ZDO.PermitJoin.Rsp( + TSN=123, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + ) + ], + ) + + await app.startup(auto_form=False) + await app.permit(node=ieee) + + await permit_join + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_on_zdo_device_join(make_application, mocker): + """Test ZDO device join indication listener.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + mocker.patch.object(app, "handle_join", wraps=app.handle_join) + + nwk = 0x1234 + ieee = t.EUI64.convert("11:22:33:44:55:66:77:88") + + await zboss_server.send(c.ZDO.DevAnnceInd.Ind( + NWK=nwk, + IEEE=ieee, + MacCap=t.uint8_t(0x01) + ) + ) + + await asyncio.sleep(0.1) + + app.handle_join.assert_called_once_with( + nwk=nwk, ieee=ieee, parent_nwk=None + ) + + await app.shutdown() diff --git a/tests/application/test_requests.py b/tests/application/test_requests.py new file mode 100644 index 0000000..eb80a52 --- /dev/null +++ b/tests/application/test_requests.py @@ -0,0 +1,559 @@ +"""Test application requests.""" +import asyncio +from unittest.mock import AsyncMock as CoroutineMock + +import pytest +import zigpy.endpoint +import zigpy.profiles +import zigpy.types as zigpy_t + +import zigpy_zboss.commands as c +import zigpy_zboss.config as conf +import zigpy_zboss.types as t + +from ..conftest import BaseZbossDevice + + +@pytest.mark.asyncio +async def test_zigpy_request(make_application): + """Test zigpy request.""" + app, zboss_server = make_application(BaseZbossDevice) + await app.startup(auto_form=False) + + device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xAABB) + + ep = device.add_endpoint(1) + ep.status = zigpy.endpoint.Status.ZDO_INIT + ep.profile_id = 260 + ep.add_input_cluster(6) + + # Construct the payload with the correct FrameControl byte + # FrameControl bits: 0001 0000 -> 0x10 for Server_to_Client + frame_control_byte = 0x18 + tsn = 0x01 + command_id = 0x01 + + payload = [frame_control_byte, tsn, command_id] + payload_length = len(payload) + # Respond to a light turn on request + zboss_server.reply_once_to( + request=c.APS.DataReq.Req( + TSN=1, ParamLength=21, DataLength=3, + DstAddr=t.EUI64.convert("00:00:00:00:00:00:aa:bb"), + ProfileID=260, ClusterId=6, DstEndpoint=1, SrcEndpoint=1, Radius=0, + DstAddrMode=zigpy_t.AddrMode.NWK, + TxOptions=c.aps.TransmitOptions.NONE, + UseAlias=t.Bool.false, AliasSrcAddr=0x0000, AliasSeqNbr=0, + Payload=[1, 1, 1]), + responses=[c.APS.DataReq.Rsp( + TSN=1, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + DstAddr=t.EUI64.convert("00:00:00:00:00:00:aa:bb"), + DstEndpoint=1, + SrcEndpoint=1, + TxTime=1, + DstAddrMode=zigpy_t.AddrMode.NWK + ), + c.APS.DataIndication.Ind( + ParamLength=21, + PayloadLength=payload_length, + FrameFC=t.APSFrameFC(0x01), + SrcAddr=t.NWK(0xAABB), + DstAddr=t.NWK(0x1234), + GrpAddr=t.NWK(0x5678), + DstEndpoint=1, + SrcEndpoint=1, + ClusterId=6, + ProfileId=260, + PacketCounter=10, + SrcMACAddr=t.NWK(0xAABB), + DstMACAddr=t.NWK(0x1234), + LQI=255, + RSSI=-70, + KeySrcAndAttr=t.ApsAttributes(0x01), + Payload=t.Payload(payload) + )], + ) + + # Turn on the light + await device.endpoints[1].on_off.on() + + await app.shutdown() + + +# @pytest.mark.parametrize("device", FORMED_DEVICES) +# async def test_zigpy_request_failure(device, make_application, mocker): +# app, zboss_server = make_application(device) +# await app.startup(auto_form=False) +# +# TSN = 1 +# +# device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xAABB) +# +# ep = device.add_endpoint(1) +# ep.profile_id = 260 +# ep.add_input_cluster(6) +# +# # Fail to respond to a light turn on request +# zboss_server.reply_to( +# request=c.AF.DataRequestExt.Req( +# DstAddrModeAddress=t.AddrModeAddress( +# mode=t.AddrMode.NWK, address=device.nwk +# ), +# DstEndpoint=1, +# SrcEndpoint=1, +# ClusterId=6, +# TSN=TSN, +# Data=bytes([0x01, TSN, 0x01]), +# partial=True, +# ), +# responses=[ +# c.AF.DataRequestExt.Rsp(Status=t.Status.SUCCESS), +# c.AF.DataConfirm.Callback( +# Status=t.Status.FAILURE, +# Endpoint=1, +# TSN=TSN, +# ), +# ], +# ) +# +# mocker.spy(app, "send_packet") +# +# # Fail to turn on the light +# with pytest.raises(InvalidCommandResponse): +# await device.endpoints[1].on_off.on() +# +# assert app.send_packet.call_count == 1 +# await app.shutdown() +# +# + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "addr", + [ + zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.IEEE, address=t.EUI64(range(8))), + zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, address=t.NWK(0xAABB)), + ], +) +async def test_request_addr_mode(addr, make_application, mocker): + """Test address mode request.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + await app.startup(auto_form=False) + + device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xAABB) + + mocker.patch.object(app, "send_packet", new=CoroutineMock()) + + await app.request( + device, + use_ieee=(addr.addr_mode == zigpy.types.AddrMode.IEEE), + profile=1, + cluster=2, + src_ep=3, + dst_ep=4, + sequence=5, + data=b"6", + ) + + assert app.send_packet.call_count == 1 + assert app.send_packet.mock_calls[0].args[0].dst == addr + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_mrequest(make_application, mocker): + """Test multicast request.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + mocker.patch.object(app, "send_packet", new=CoroutineMock()) + group = app.groups.add_group(0x1234, "test group") + + await group.endpoint.on_off.on() + + assert app.send_packet.call_count == 1 + assert ( + app.send_packet.mock_calls[0].args[0].dst == + zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Group, address=0x1234) + ) + assert app.send_packet.mock_calls[0].args[0].data.serialize() == \ + b"\x01\x01\x01" + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_mrequest_doesnt_block(make_application, event_loop): + """Test non blocking multicast request.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + zboss_server.reply_once_to( + request=c.APS.DataReq.Req( + TSN=1, ParamLength=21, DataLength=3, + DstAddr=t.EUI64.convert("00:00:00:00:00:00:12:34"), + ProfileID=260, ClusterId=6, DstEndpoint=0, SrcEndpoint=1, Radius=0, + DstAddrMode=zigpy_t.AddrMode.Group, + TxOptions=c.aps.TransmitOptions.NONE, + UseAlias=t.Bool.false, AliasSrcAddr=0x0000, AliasSeqNbr=0, + Payload=[1, 1, 1]), + responses=[ + c.APS.DataReq.Rsp( + TSN=1, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DstAddr=t.EUI64.convert("00:00:00:00:00:00:aa:bb"), + DstEndpoint=1, + SrcEndpoint=1, + TxTime=1, + DstAddrMode=zigpy_t.AddrMode.Group, + ), + ], + ) + + data_confirm_rsp = c.APS.DataIndication.Ind( + ParamLength=21, PayloadLength=None, FrameFC=None, + SrcAddr=None, DstAddr=None, GrpAddr=None, DstEndpoint=1, + SrcEndpoint=1, ClusterId=6, ProfileId=260, + PacketCounter=None, SrcMACAddr=None, DstMACAddr=None, + LQI=None, RSSI=None, KeySrcAndAttr=None, Payload=None, partial=True + ) + + request_sent = event_loop.create_future() + + async def on_request_sent(): + await zboss_server.send(data_confirm_rsp) + + request_sent.add_done_callback( + lambda _: event_loop.create_task(on_request_sent()) + ) + + await app.startup(auto_form=False) + + group = app.groups.add_group(0x1234, "test group") + await group.endpoint.on_off.on() + request_sent.set_result(True) + + await asyncio.sleep(0.01) + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_broadcast(make_application, mocker): + """Test broadcast request.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup() + zboss_server.reply_once_to( + request=c.APS.DataReq.Req(TSN=1, ParamLength=21, DataLength=3, + DstAddr=t.EUI64.convert( + "00:00:00:00:00:00:ff:fd"), + ProfileID=260, ClusterId=3, DstEndpoint=255, + SrcEndpoint=1, Radius=3, + DstAddrMode=zigpy_t.AddrMode.Group, + TxOptions=c.aps.TransmitOptions.NONE, + UseAlias=t.Bool.false, + AliasSrcAddr=0x0000, AliasSeqNbr=0, + Payload=[63, 63, 63]), + responses=[ + c.APS.DataReq.Rsp( + TSN=1, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DstAddr=t.EUI64.convert("00:00:00:00:00:00:ff:fd"), + DstEndpoint=255, + SrcEndpoint=1, + TxTime=1, + DstAddrMode=zigpy_t.AddrMode.Group, + ), + ], + ) + + await app.broadcast( + profile=260, # ZHA + cluster=0x0003, # Identify + src_ep=1, + dst_ep=0xFF, # Any endpoint + grpid=0, + radius=3, + sequence=1, + data=b"???", + broadcast_address=zigpy_t.BroadcastAddress.RX_ON_WHEN_IDLE, + ) + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_request_concurrency(make_application, mocker): + """Test request concurency.""" + app, zboss_server = make_application( + server_cls=BaseZbossDevice, + client_config={conf.CONF_MAX_CONCURRENT_REQUESTS: 2}, + ) + + await app.startup() + + device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xAABB) + + ep = device.add_endpoint(1) + ep.status = zigpy.endpoint.Status.ZDO_INIT + ep.profile_id = 260 + ep.add_input_cluster(6) + + # Keep track of how many requests we receive at once + in_flight_requests = 0 + did_lock = False + + def make_response(req): + async def callback(req): + nonlocal in_flight_requests + nonlocal did_lock + + if app._concurrent_requests_semaphore.locked(): + did_lock = True + + in_flight_requests += 1 + assert in_flight_requests <= 2 + + await asyncio.sleep(0.1) + await zboss_server.send(c.APS.DataReq.Rsp( + TSN=req.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DstAddr=req.DstAddr, + DstEndpoint=req.DstEndpoint, + SrcEndpoint=req.SrcEndpoint, + TxTime=1, + DstAddrMode=req.DstAddrMode, + )) + await asyncio.sleep(0) + + in_flight_requests -= 1 + assert in_flight_requests >= 0 + + asyncio.create_task(callback(req)) + + zboss_server.reply_to( + request=c.APS.DataReq.Req( + partial=True), responses=[make_response] + + ) + + # We create a whole bunch at once + await asyncio.gather( + *[ + app.request( + device, + profile=260, + cluster=1, + src_ep=1, + dst_ep=1, + sequence=seq, + data=b"\x00", + ) + for seq in range(10) + ] + ) + + assert in_flight_requests == 0 + assert did_lock + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_request_cancellation_shielding( + make_application, mocker, event_loop): + """Test request cancellation shielding.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + await app.startup(auto_form=False) + + device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xAABB) + + ep = device.add_endpoint(1) + ep.status = zigpy.endpoint.Status.ZDO_INIT + ep.profile_id = 260 + ep.add_input_cluster(6) + + frame_control_byte = 0x18 + tsn = 0x01 + command_id = 0x01 + + payload = [frame_control_byte, tsn, command_id] + payload_length = len(payload) + + # The data confirm timeout must be shorter than the ARSP timeout + mocker.spy(app._api, "_unhandled_command") + + delayed_reply_sent = event_loop.create_future() + + def delayed_reply(req): + async def inner(): + await asyncio.sleep(0.5) + await zboss_server.send( + c.APS.DataIndication.Ind( + ParamLength=21, + PayloadLength=payload_length, + FrameFC=t.APSFrameFC(0x01), + SrcAddr=t.NWK(0xAABB), + DstAddr=t.NWK(0x1234), + GrpAddr=t.NWK(0x5678), + DstEndpoint=1, + SrcEndpoint=1, + ClusterId=6, + ProfileId=260, + PacketCounter=10, + SrcMACAddr=t.NWK(0xAABB), + DstMACAddr=t.NWK(0x1234), + LQI=255, + RSSI=-70, + KeySrcAndAttr=t.ApsAttributes(0x01), + Payload=t.Payload(payload) + ) + ) + delayed_reply_sent.set_result(True) + + asyncio.create_task(inner()) + + data_req = zboss_server.reply_once_to( + c.APS.DataReq.Req( + TSN=1, ParamLength=21, DataLength=3, + DstAddr=t.EUI64.convert("00:00:00:00:00:00:aa:bb"), + ProfileID=260, ClusterId=6, DstEndpoint=1, SrcEndpoint=1, Radius=0, + DstAddrMode=zigpy_t.AddrMode.NWK, + TxOptions=c.aps.TransmitOptions.NONE, + UseAlias=t.Bool.false, AliasSrcAddr=0x0000, AliasSeqNbr=0, + Payload=[1, 1, 1]), + responses=[ + c.APS.DataReq.Rsp( + TSN=1, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + DstAddr=t.EUI64.convert("00:00:00:00:00:00:aa:bb"), + DstEndpoint=1, + SrcEndpoint=1, + TxTime=1, + DstAddrMode=zigpy_t.AddrMode.NWK + ), + delayed_reply, + ], + ) + + with pytest.raises(asyncio.TimeoutError): + # Turn on the light + await device.request( + 260, + 6, + 1, + 1, + 1, + b'\x01\x01\x01', + expect_reply=True, + timeout=0.1, + ) + + await data_req + await delayed_reply_sent + + assert app._api._unhandled_command.call_count == 0 + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_send_security_and_packet_source_route(make_application, mocker): + """Test sending security and packet source route.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + packet = zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, address=app.state.node_info.nwk + ), + src_ep=0x9A, + dst=zigpy.types.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, address=0xEEFF + ), + dst_ep=0xBC, + tsn=0xDE, + profile_id=0x1234, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"test data"), + extended_timeout=False, + tx_options=( + zigpy_t.TransmitOptions.ACK | + zigpy_t.TransmitOptions.APS_Encryption + ), + source_route=[0xAABB, 0xCCDD], + ) + + data_req = zboss_server.reply_once_to( + request=c.APS.DataReq.Req( + TSN=222, ParamLength=21, DataLength=9, + DstAddr=t.EUI64.convert("00:00:00:00:00:00:ee:ff"), + ProfileID=4660, ClusterId=6, DstEndpoint=188, SrcEndpoint=154, + Radius=0, DstAddrMode=zigpy_t.AddrMode.NWK, + TxOptions=( + c.aps.TransmitOptions.SECURITY_ENABLED | + c.aps.TransmitOptions.ACK_ENABLED + ), + UseAlias=t.Bool.false, AliasSrcAddr=0x0000, AliasSeqNbr=0, + Payload=[116, 101, 115, 116, 32, 100, 97, 116, 97]), + responses=[ + c.APS.DataReq.Rsp( + TSN=1, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + DstAddr=t.EUI64.convert("00:00:00:00:00:00:aa:bb"), + DstEndpoint=1, + SrcEndpoint=1, + TxTime=1, + DstAddrMode=zigpy_t.AddrMode.NWK + ), + ], + ) + + await app.send_packet(packet) + req = await data_req + assert ( + c.aps.TransmitOptions.SECURITY_ENABLED + in c.aps.TransmitOptions(req.TxOptions) + ) + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_send_packet_failure_disconnected(make_application, mocker): + """Test sending packet failure at disconnect.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + app._api = None + + packet = zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, address=0x0000), + src_ep=0x9A, + dst=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, address=0xEEFF), + dst_ep=0xBC, + tsn=0xDE, + profile_id=0x1234, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"test data"), + ) + + with pytest.raises(zigpy.exceptions.DeliveryError) as excinfo: + await app.send_packet(packet) + + assert "Coordinator is disconnected" in str(excinfo.value) + + await app.shutdown() diff --git a/tests/application/test_startup.py b/tests/application/test_startup.py new file mode 100644 index 0000000..451611e --- /dev/null +++ b/tests/application/test_startup.py @@ -0,0 +1,365 @@ +"""Test application startup.""" +from unittest.mock import AsyncMock as CoroutineMock + +import pytest +from zigpy.exceptions import NetworkNotFormed + +import zigpy_zboss.commands as c +import zigpy_zboss.types as t +from zigpy_zboss.api import ZBOSS + +from ..conftest import BaseZbossDevice, BaseZbossGenericDevice + + +@pytest.mark.asyncio +async def test_info(make_application, caplog): + """Test network information.""" + app, zboss_server = make_application( + server_cls=BaseZbossGenericDevice, active_sequence=True + ) + + pan_id = 0x5679 + ext_pan_id = t.EUI64.convert("00:11:22:33:44:55:66:77") + channel = 11 + channel_mask = 0x07fff800 + parent_address = t.NWK(0x5679) + coordinator_version = 1 + # Simulate responses for each request in load_network_info + zboss_server.reply_once_to( + request=c.NcpConfig.GetZigbeeRole.Req(TSN=1), + responses=[c.NcpConfig.GetZigbeeRole.Rsp( + TSN=1, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1))] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetZigbeeRole.Req(TSN=1), + responses=[c.NcpConfig.GetZigbeeRole.Rsp( + TSN=1, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1))] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetJoinStatus.Req(TSN=2), + responses=[c.NcpConfig.GetJoinStatus.Rsp( + TSN=2, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + Joined=1)] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetShortAddr.Req(TSN=3), + responses=[c.NcpConfig.GetShortAddr.Rsp( + TSN=3, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + NWKAddr=t.NWK(0xAABB))] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetLocalIEEE.Req(TSN=4, MacInterfaceNum=0), + responses=[c.NcpConfig.GetLocalIEEE.Rsp( + TSN=4, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + MacInterfaceNum=0, + IEEE=t.EUI64(range(8)))] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetZigbeeRole.Req(TSN=5), + responses=[c.NcpConfig.GetZigbeeRole.Rsp( + TSN=5, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(2))] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetModuleVersion.Req(TSN=6), + responses=[c.NcpConfig.GetModuleVersion.Rsp( + TSN=6, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, # Example status code + FWVersion=1, # Example firmware version + StackVersion=2, # Example stack version + ProtocolVersion=3 # Example protocol version + )] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetExtendedPANID.Req(TSN=7), + responses=[c.NcpConfig.GetExtendedPANID.Rsp( + TSN=7, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + ExtendedPANID=ext_pan_id)] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetShortPANID.Req(TSN=8), + responses=[c.NcpConfig.GetShortPANID.Rsp( + TSN=8, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + PANID=t.PanId(pan_id))] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetCurrentChannel.Req(TSN=9), + responses=[c.NcpConfig.GetCurrentChannel.Rsp( + TSN=9, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + Channel=channel, Page=0)] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetChannelMask.Req(TSN=10), + responses=[c.NcpConfig.GetChannelMask.Rsp( + TSN=10, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + ChannelList=[t.ChannelEntry(page=1, channel_mask=channel_mask)])] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetTrustCenterAddr.Req( + TSN=13 + ), + responses=[c.NcpConfig.GetTrustCenterAddr.Rsp( + TSN=13, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + TCIEEE=t.EUI64.convert("00:11:22:33:44:55:66:77") + # Example Trust Center IEEE address + )] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetRxOnWhenIdle.Req(TSN=14), + responses=[c.NcpConfig.GetRxOnWhenIdle.Rsp( + TSN=14, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + RxOnWhenIdle=1)] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetEDTimeout.Req(TSN=15), + responses=[c.NcpConfig.GetEDTimeout.Rsp( + TSN=15, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + Timeout=t.TimeoutIndex(0x00))] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetMaxChildren.Req(TSN=16), + responses=[c.NcpConfig.GetMaxChildren.Rsp( + TSN=16, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + ChildrenNbr=10)] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetAuthenticationStatus.Req(TSN=17), + responses=[c.NcpConfig.GetAuthenticationStatus.Rsp( + TSN=17, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + Authenticated=True)] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetParentAddr.Req(TSN=18), + responses=[c.NcpConfig.GetParentAddr.Rsp( + TSN=18, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + NWKParentAddr=parent_address)] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetCoordinatorVersion.Req(TSN=19), + responses=[c.NcpConfig.GetCoordinatorVersion.Rsp( + TSN=19, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + CoordinatorVersion=coordinator_version)] + ) + + zboss_server.reply_once_to( + request=c.ZDO.PermitJoin.Req( + TSN=21, + DestNWK=t.NWK(0x0000), + PermitDuration=t.uint8_t(0), + TCSignificance=t.uint8_t(0x01), + ), + responses=[c.ZDO.PermitJoin.Rsp( + TSN=21, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + )] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.NCPModuleReset.Req( + TSN=22, Option=t.ResetOptions(0) + ), + responses=[c.NcpConfig.NCPModuleReset.Rsp( + TSN=22, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK + )] + ) + + await app.startup(auto_form=False) + + assert app.state.network_info.pan_id == 0x5679 + assert app.state.network_info.extended_pan_id == t.EUI64( + ext_pan_id.serialize()[::-1]) + assert app.state.network_info.channel == channel + assert app.state.network_info.channel_mask == channel_mask + assert app.state.network_info.network_key.seq == 1 + zboss_stack_specific = app.state.network_info.stack_specific["zboss"] + assert zboss_stack_specific["parent_nwk"] == parent_address + assert zboss_stack_specific["authenticated"] == 1 + assert zboss_stack_specific["coordinator_version"] == coordinator_version + + # Anything to make sure it's set + assert app._device.node_desc.maximum_outgoing_transfer_size == 82 + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_endpoints(make_application): + """Test endpoints.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + endpoints = [] + zboss_server.register_indication_listener( + c.ZDO.PermitJoin.Req(partial=True), endpoints.append + ) + + await app.startup(auto_form=False) + + # We currently just register one endpoint + assert len(endpoints) == 1 + assert 1 in app._device.endpoints + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_not_configured(make_application): + """Test device not configured.""" + app, zboss_server = make_application( + server_cls=BaseZbossGenericDevice, active_sequence=True + ) + + # Simulate responses for each request in load_network_info + zboss_server.reply_once_to( + request=c.NcpConfig.GetZigbeeRole.Req(TSN=1), + responses=[c.NcpConfig.GetZigbeeRole.Rsp( + TSN=1, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1))] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetZigbeeRole.Req(TSN=1), + responses=[c.NcpConfig.GetZigbeeRole.Rsp( + TSN=1, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1))] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.GetJoinStatus.Req(TSN=2), + responses=[c.NcpConfig.GetJoinStatus.Rsp( + TSN=2, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK, + Joined=0)] + ) + + zboss_server.reply_once_to( + request=c.NcpConfig.NCPModuleReset.Req( + TSN=3, Option=t.ResetOptions(0) + ), + responses=[c.NcpConfig.NCPModuleReset.Rsp( + TSN=3, + StatusCat=t.StatusCategory(4), + StatusCode=t.StatusCodeGeneric.OK + )] + ) + + # We cannot start the application if Z-Stack + # is not configured and without auto_form + with pytest.raises(NetworkNotFormed): + await app.startup(auto_form=False) + + +@pytest.mark.asyncio +async def test_reset(make_application, mocker): + """Test application reset.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + # `_reset` should be called at least once + # to put the radio into a consistent state + mocker.spy(ZBOSS, "reset") + assert ZBOSS.reset.call_count == 0 + + await app.startup() + await app.shutdown() + + assert ZBOSS.reset.call_count >= 1 + + +@pytest.mark.asyncio +async def test_auto_form_unnecessary(make_application, mocker): + """Test unnecessary auto form.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + mocker.patch.object(app, "form_network", new=CoroutineMock()) + + await app.startup(auto_form=True) + + assert app.form_network.call_count == 0 + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_auto_form_necessary(make_application, mocker): + """Test necessary auto form.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + assert app.state.network_info.channel == 0 + assert app.state.network_info.channel_mask == t.Channels.NO_CHANNELS + + await app.startup(auto_form=True) + + assert app.state.network_info.channel != 0 + assert app.state.network_info.channel_mask != t.Channels.NO_CHANNELS + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_concurrency_auto_config(make_application): + """Test auto config concurrency.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.connect() + await app.start_network() + + assert app._concurrent_requests_semaphore.max_value == 8 diff --git a/tests/application/test_zdo_requests.py b/tests/application/test_zdo_requests.py new file mode 100644 index 0000000..fec1c09 --- /dev/null +++ b/tests/application/test_zdo_requests.py @@ -0,0 +1,87 @@ +"""Test application ZDO request.""" +import asyncio + +import pytest +import zigpy.types as z_types +import zigpy.zdo.types as zdo_t + +import zigpy_zboss.commands as c +import zigpy_zboss.types as t + +from ..conftest import BaseZbossDevice + + +@pytest.mark.asyncio +async def test_mgmt_nwk_update_req(make_application, mocker): + """Test ZDO_MGMT_NWK_UPDATE_REQ request.""" + mocker.patch( + "zigpy.application.CHANNEL_CHANGE_SETTINGS_RELOAD_DELAY_S", 0.1 + ) + + app, zboss_server = make_application(server_cls=BaseZbossDevice) + + new_channel = 11 + old_channel = 1 + + async def update_channel(req): + # Wait a bit before updating + await asyncio.sleep(0.1) + zboss_server.new_channel = new_channel + + yield + + zboss_server.reply_once_to( + request=c.APS.DataReq.Req( + TSN=123, + ParamLength=21, + DataLength=3, + ProfileID=260, + ClusterId=zdo_t.ZDOCmd.Mgmt_NWK_Update_req, + DstEndpoint=0, + partial=True + ), + responses=[c.APS.DataReq.Rsp( + TSN=123, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DstAddr=t.EUI64.convert("00:00:00:00:00:00:aa:bb"), + DstEndpoint=1, + SrcEndpoint=1, + TxTime=1, + DstAddrMode=z_types.AddrMode.Group, + )], + ) + nwk_update_req = zboss_server.reply_once_to( + request=c.ZDO.MgmtNwkUpdate.Req( + TSN=123, + DstNWK=t.NWK(0x0000), + ScanChannelMask=t.Channels.from_channel_list([new_channel]), + ScanDuration=zdo_t.NwkUpdate.CHANNEL_CHANGE_REQ, + ScanCount=0, + MgrAddr=0x0000, + ), + responses=[ + c.ZDO.MgmtNwkUpdate.Rsp( + TSN=123, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + ScannedChannels=t.Channels.from_channel_list([new_channel]), + TotalTransmissions=1, + TransmissionFailures=0, + EnergyValues=c.zdo.EnergyValues(t.LVList([1])), + ), + update_channel, + ], + ) + + await app.startup(auto_form=False) + + assert app.state.network_info.channel == old_channel + + await app.move_network_to_channel(new_channel=new_channel) + + await nwk_update_req + + assert app.state.network_info.channel == new_channel + + await app.shutdown() diff --git a/tests/application/test_zigpy_callbacks.py b/tests/application/test_zigpy_callbacks.py new file mode 100644 index 0000000..2d8fe5e --- /dev/null +++ b/tests/application/test_zigpy_callbacks.py @@ -0,0 +1,324 @@ +"""Test zigpy callbacks.""" +import asyncio + +import pytest +import zigpy.types as zigpy_t +import zigpy.zdo.types as zdo_t + +import zigpy_zboss.commands as c +import zigpy_zboss.types as t + +from ..conftest import BaseZbossDevice, serialize_zdo_command + + +@pytest.mark.asyncio +async def test_on_zdo_device_announce_nwk_change(make_application, mocker): + """Test device announce network address change.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + mocker.spy(app, "handle_join") + mocker.patch.object(app, "handle_message") + + device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xFA9E) + new_nwk = device.nwk + 1 + + payload = bytearray(serialize_zdo_command( + command_id=zdo_t.ZDOCmd.Device_annce, + NWKAddr=new_nwk, + IEEEAddr=device.ieee, + Capability=t.MACCapability.DeviceType, + Status=t.DeviceUpdateStatus.tc_rejoin, + )) + payload_length = len(payload) + + # Assume its NWK changed and we're just finding out + await zboss_server.send( + c.APS.DataIndication.Ind( + ParamLength=21, PayloadLength=payload_length, + FrameFC=t.APSFrameFC(0x01), + SrcAddr=t.NWK(0x0001), DstAddr=t.NWK(0x0000), + GrpAddr=t.NWK(0x0000), DstEndpoint=1, + SrcEndpoint=1, ClusterId=zdo_t.ZDOCmd.Device_annce, ProfileId=260, + PacketCounter=10, SrcMACAddr=t.NWK(0x0000), + DstMACAddr=t.NWK(0x0000), + LQI=255, RSSI=-70, KeySrcAndAttr=t.ApsAttributes(0x01), + Payload=t.Payload(payload) + ) + ) + + await zboss_server.send( + c.ZDO.DevAnnceInd.Ind( + NWK=new_nwk, + IEEE=device.ieee, + MacCap=1, + ) + ) + + await asyncio.sleep(0.1) + + app.handle_join.assert_called_once_with( + nwk=new_nwk, ieee=device.ieee, parent_nwk=None + ) + + # The device's NWK has been updated + assert device.nwk == new_nwk + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_on_zdo_device_leave_callback(make_application, mocker): + """Test ZDO device leave indication.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + mocker.patch.object(app, "handle_leave") + + nwk = 0xAABB + ieee = t.EUI64(range(8)) + + await zboss_server.send( + c.NWK.NwkLeaveInd.Ind( + IEEE=ieee, Rejoin=0 + ) + ) + app.handle_leave.assert_called_once_with(nwk=nwk, ieee=ieee) + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_on_af_message_callback(make_application, mocker): + """Test AF message indication.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + mocker.patch.object(app, "packet_received") + device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xAABB) + + af_message = c.APS.DataIndication.Ind( + ParamLength=21, PayloadLength=len(b"test"), + FrameFC=t.APSFrameFC(0x01), + SrcAddr=device.nwk, DstAddr=t.NWK(0x0000), + GrpAddr=t.NWK(0x0000), DstEndpoint=1, + SrcEndpoint=4, ClusterId=2, ProfileId=260, + PacketCounter=10, SrcMACAddr=t.NWK(0x0000), + DstMACAddr=t.NWK(0x0000), + LQI=19, RSSI=0, KeySrcAndAttr=t.ApsAttributes(0x01), + Payload=t.Payload(b"test") + ) + + # Normal message + await zboss_server.send(af_message) + await asyncio.sleep(0.1) + + assert app.packet_received.call_count == 1 + _call = app.packet_received.call_args[0][0] + assert _call.src == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=device.nwk, + ) + assert _call.src_ep == 4 + assert _call.dst == zigpy_t.AddrModeAddress( + zigpy_t.AddrMode.NWK, app.state.node_info.nwk + ) + assert _call.dst_ep == 1 + assert _call.cluster_id == 2 + assert _call.data.serialize() == b"test" + assert _call.lqi == 19 + assert _call.rssi == 0 + assert _call.profile_id == 260 + + app.packet_received.reset_mock() + + zll_message = c.APS.DataIndication.Ind( + ParamLength=21, PayloadLength=len(b"test"), + FrameFC=t.APSFrameFC(0x01), + SrcAddr=device.nwk, DstAddr=t.NWK(0x0000), + GrpAddr=t.NWK(0x0000), DstEndpoint=2, + SrcEndpoint=4, ClusterId=2, ProfileId=260, + PacketCounter=10, SrcMACAddr=t.NWK(0x0000), + DstMACAddr=t.NWK(0x0000), + LQI=19, RSSI=0, KeySrcAndAttr=t.ApsAttributes(0x01), + Payload=t.Payload(b"test") + ) + + # ZLL message + await zboss_server.send(zll_message) + await asyncio.sleep(0.1) + + assert app.packet_received.call_count == 1 + _call = app.packet_received.call_args[0][0] + assert _call.src == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, address=device.nwk + ) + assert _call.src_ep == 4 + assert _call.dst == zigpy_t.AddrModeAddress( + zigpy_t.AddrMode.NWK, app.state.node_info.nwk + ) + assert _call.dst_ep == 2 + assert _call.cluster_id == 2 + assert _call.data.serialize() == b"test" + assert _call.lqi == 19 + assert _call.rssi == 0 + assert _call.profile_id == 260 + + app.packet_received.reset_mock() + + unknown_message = c.APS.DataIndication.Ind( + ParamLength=21, PayloadLength=len(b"test"), + FrameFC=t.APSFrameFC(0x01), + SrcAddr=device.nwk, DstAddr=t.NWK(0x0000), + GrpAddr=t.NWK(0x0000), DstEndpoint=3, + SrcEndpoint=4, ClusterId=2, ProfileId=260, + PacketCounter=10, SrcMACAddr=t.NWK(0x0000), + DstMACAddr=t.NWK(0x0000), + LQI=19, RSSI=0, KeySrcAndAttr=t.ApsAttributes(0x01), + Payload=t.Payload(b"test") + ) + + # Message on an unknown endpoint (is this possible?) + await zboss_server.send(unknown_message) + await asyncio.sleep(0.1) + + assert app.packet_received.call_count == 1 + _call = app.packet_received.call_args[0][0] + assert _call.src == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, address=device.nwk + ) + assert _call.src_ep == 4 + assert _call.dst == zigpy_t.AddrModeAddress( + zigpy_t.AddrMode.NWK, app.state.node_info.nwk + ) + assert _call.dst_ep == 3 + assert _call.cluster_id == 2 + assert _call.data.serialize() == b"test" + assert _call.lqi == 19 + assert _call.rssi == 0 + assert _call.profile_id == 260 + + app.packet_received.reset_mock() + + +@pytest.mark.asyncio +async def test_receive_zdo_broadcast(make_application, mocker): + """Test receive ZDO broadcast.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + mocker.patch.object(app, "packet_received") + + zdo_callback = c.APS.DataIndication.Ind( + ParamLength=21, PayloadLength=len(b"bogus"), + FrameFC=t.APSFrameFC.Broadcast, + SrcAddr=t.NWK(0x35D9), DstAddr=t.NWK(0x0000), + GrpAddr=t.NWK(0x0000), DstEndpoint=0, + SrcEndpoint=0, ClusterId=19, ProfileId=260, + PacketCounter=10, SrcMACAddr=t.NWK(0x0000), + DstMACAddr=t.NWK(0xFFFF), + LQI=19, RSSI=0, KeySrcAndAttr=t.ApsAttributes(0x01), + Payload=t.Payload(b"bogus") + ) + await zboss_server.send(zdo_callback) + await asyncio.sleep(0.1) + + assert app.packet_received.call_count == 1 + packet = app.packet_received.mock_calls[0].args[0] + assert packet.src == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, address=0x35D9 + ) + assert packet.src_ep == 0x00 + assert packet.dst == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.Broadcast, + address=zigpy_t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ) + assert packet.dst_ep == 0x00 + assert packet.cluster_id == zdo_callback.ClusterId + assert packet.data.serialize() == zdo_callback.Payload.serialize() + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_receive_af_broadcast(make_application, mocker): + """Test receive AF broadcast.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + mocker.patch.object(app, "packet_received") + + payload = b"\x11\xA6\x00\x74\xB5\x7C\x00\x02\x5F" + + af_callback = c.APS.DataIndication.Ind( + ParamLength=21, PayloadLength=len(payload), + FrameFC=t.APSFrameFC.Broadcast, + SrcAddr=t.NWK(0x1234), DstAddr=t.NWK(0x0000), + GrpAddr=t.NWK(0x0000), DstEndpoint=2, + SrcEndpoint=254, ClusterId=4096, ProfileId=260, + PacketCounter=10, SrcMACAddr=t.NWK(0x0000), + DstMACAddr=t.NWK(0xFFFF), + LQI=19, RSSI=0, KeySrcAndAttr=t.ApsAttributes(0x01), + Payload=t.Payload(payload) + ) + await zboss_server.send(af_callback) + await asyncio.sleep(0.1) + + assert app.packet_received.call_count == 1 + packet = app.packet_received.mock_calls[0].args[0] + assert packet.src == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=0x1234, + ) + assert packet.src_ep == af_callback.SrcEndpoint + assert packet.dst == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.Broadcast, + address=zigpy_t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ) + assert packet.dst_ep == af_callback.DstEndpoint + assert packet.cluster_id == af_callback.ClusterId + assert packet.lqi == af_callback.LQI + assert packet.data.serialize() == af_callback.Payload.serialize() + + await app.shutdown() + + +@pytest.mark.asyncio +async def test_receive_af_group(make_application, mocker): + """Test receive AF group.""" + app, zboss_server = make_application(server_cls=BaseZbossDevice) + await app.startup(auto_form=False) + + mocker.patch.object(app, "packet_received") + + payload = b"\x11\xA6\x00\x74\xB5\x7C\x00\x02\x5F" + + af_callback = c.APS.DataIndication.Ind( + ParamLength=21, PayloadLength=len(payload), + FrameFC=t.APSFrameFC.Group, + SrcAddr=t.NWK(0x1234), DstAddr=t.NWK(0x0000), + GrpAddr=t.NWK(0x1234), DstEndpoint=0, + SrcEndpoint=254, ClusterId=4096, ProfileId=260, + PacketCounter=10, SrcMACAddr=t.NWK(0x0000), + DstMACAddr=t.NWK(0xFFFF), + LQI=19, RSSI=0, KeySrcAndAttr=t.ApsAttributes(0x01), + Payload=t.Payload(payload) + ) + await zboss_server.send(af_callback) + await asyncio.sleep(0.1) + + assert app.packet_received.call_count == 1 + packet = app.packet_received.mock_calls[0].args[0] + assert packet.src == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=0x1234, + ) + assert packet.src_ep == af_callback.SrcEndpoint + assert packet.dst == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.Group, address=0x1234 + ) + assert packet.cluster_id == af_callback.ClusterId + assert packet.lqi == af_callback.LQI + assert packet.data.serialize() == af_callback.Payload.serialize() + + await app.shutdown() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8171b59 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,1057 @@ +"""Shared fixtures and utilities for testing zigpy-zboss.""" +import asyncio +import gc +import inspect +import logging +import sys +import typing +from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch + +import pytest +import zigpy +from zigpy.zdo import types as zdo_t + +import zigpy_zboss.commands as c +import zigpy_zboss.config as conf +import zigpy_zboss.types as t +from zigpy_zboss.api import ZBOSS +from zigpy_zboss.uart import ZbossNcpProtocol +from zigpy_zboss.zigbee.application import ControllerApplication + +LOGGER = logging.getLogger(__name__) + +FAKE_SERIAL_PORT = "/dev/ttyFAKE0" + + +# Globally handle async tests and error on unawaited coroutines +def pytest_collection_modifyitems(session, config, items): + """Modify collection items.""" + for item in items: + item.add_marker( + pytest.mark.filterwarnings( + "error::pytest.PytestUnraisableExceptionWarning" + ) + ) + item.add_marker(pytest.mark.filterwarnings("error::RuntimeWarning")) + + +@pytest.hookimpl(trylast=True) +def pytest_fixture_post_finalizer(fixturedef, request) -> None: + """Post fixture teardown.""" + if fixturedef.argname != "event_loop": + return + + policy = asyncio.get_event_loop_policy() + try: + loop = policy.get_event_loop() + except RuntimeError: + loop = None + if loop is not None: + # Cleanup code based on the implementation of asyncio.run() + try: + if not loop.is_closed(): + asyncio.runners._cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + if sys.version_info >= (3, 9): + loop.run_until_complete(loop.shutdown_default_executor()) + finally: + loop.close() + new_loop = policy.new_event_loop() # Replace existing event loop + # Ensure subsequent calls to get_event_loop() succeed + policy.set_event_loop(new_loop) + + +@pytest.fixture +def event_loop( + request: pytest.FixtureRequest, +) -> typing.Iterator[asyncio.AbstractEventLoop]: + """Create an instance of the default event loop for each test case.""" + yield asyncio.get_event_loop_policy().new_event_loop() + # Call the garbage collector to trigger ResourceWarning's as soon + # as possible (these are triggered in various __del__ methods). + # Without this, resources opened in one test can fail other tests + # when the warning is generated. + gc.collect() + # Event loop cleanup handled by pytest_fixture_post_finalizer + + +class ForwardingSerialTransport: + """Serial transport that hooks directly into a protocol.""" + + def __init__(self, protocol): + """Initailize.""" + self.protocol = protocol + self._is_connected = False + self.other = None + + self.serial = Mock() + self.serial.name = FAKE_SERIAL_PORT + self.serial.baudrate = 45678 + type(self.serial).dtr = self._mock_dtr_prop = PropertyMock( + return_value=None + ) + type(self.serial).rts = self._mock_rts_prop = PropertyMock( + return_value=None + ) + + def _connect(self): + assert not self._is_connected + self._is_connected = True + self.other.protocol.connection_made(self) + + def write(self, data): + """Write.""" + assert self._is_connected + self.protocol.data_received(data) + + def close( + self, *, error=ValueError("Connection was closed") # noqa: B008 + ): + """Close.""" + LOGGER.debug("Closing %s", self) + + if not self._is_connected: + return + + self._is_connected = False + + # Our own protocol gets gracefully closed + self.other.close(error=None) + + # The protocol we're forwarding to gets the error + self.protocol.connection_lost(error) + + def __repr__(self): + """Representation.""" + return f"<{type(self).__name__} to {self.protocol}>" + + +def config_for_port_path(path): + """Port path configuration.""" + return conf.CONFIG_SCHEMA( + { + conf.CONF_DEVICE: {conf.CONF_DEVICE_PATH: path}, + conf.CONF_DEVICE_BAUDRATE: 115200, + conf.CONF_DEVICE_FLOW_CONTROL: None + } + ) + + +@pytest.fixture +def make_zboss_server(mocker): + """Instantiate a zboss server.""" + transports = [] + + def inner(server_cls, config=None, shorten_delays=True): + if config is None: + config = config_for_port_path(FAKE_SERIAL_PORT) + + if shorten_delays: + mocker.patch( + "zigpy_zboss.api.AFTER_BOOTLOADER_SKIP_BYTE_DELAY", 0.001 + ) + mocker.patch("zigpy_zboss.api.BOOTLOADER_PIN_TOGGLE_DELAY", 0.001) + + server = server_cls(config) + server._transports = transports + + server.port_path = FAKE_SERIAL_PORT + server._uart = None + + def passthrough_serial_conn( + loop, protocol_factory, url, *args, **kwargs + ): + LOGGER.info("Intercepting serial connection to %s", url) + + assert url == FAKE_SERIAL_PORT + + # No double connections! + if any([t._is_connected for t in transports]): + raise RuntimeError( + "Cannot open two connections to the same serial port" + ) + if server._uart is None: + server._uart = ZbossNcpProtocol( + config[conf.CONF_DEVICE], server + ) + mocker.spy(server._uart, "data_received") + + client_protocol = protocol_factory() + + # Client writes go to the server + client_transport = ForwardingSerialTransport(server._uart) + transports.append(client_transport) + + # Server writes go to the client + server_transport = ForwardingSerialTransport(client_protocol) + + # Notify them of one another + server_transport.other = client_transport + client_transport.other = server_transport + + # And finally connect both simultaneously + server_transport._connect() + client_transport._connect() + + fut = loop.create_future() + fut.set_result((client_transport, client_protocol)) + + return fut + + mocker.patch( + "zigpy.serial.pyserial_asyncio.create_serial_connection", + new=passthrough_serial_conn + ) + + # So we don't have to import it every time + server.serial_port = FAKE_SERIAL_PORT + + return server + + yield inner + + +@pytest.fixture +def make_connected_zboss(make_zboss_server, mocker): + """Make a connection fixture.""" + async def inner(server_cls): + config = conf.CONFIG_SCHEMA( + { + conf.CONF_DEVICE: {conf.CONF_DEVICE_PATH: FAKE_SERIAL_PORT}, + } + ) + + zboss = ZBOSS(config) + zboss_server = make_zboss_server(server_cls=server_cls) + + await zboss.connect() + + zboss.nvram.align_structs = server_cls.align_structs + zboss.version = server_cls.version + + return zboss, zboss_server + + return inner + + +@pytest.fixture +def connected_zboss(event_loop, make_connected_zboss): + """Zboss connected fixture.""" + zboss, zboss_server = event_loop.run_until_complete( + make_connected_zboss(BaseServerZBOSS) + ) + yield zboss, zboss_server + zboss.close() + + +def reply_to(request): + """Reply to decorator.""" + def inner(function): + if not hasattr(function, "_reply_to"): + function._reply_to = [] + + function._reply_to.append(request) + + return function + + return inner + + +def serialize_zdo_command(command_id, **kwargs): + """ZDO command serialization.""" + field_names, field_types = zdo_t.CLUSTERS[command_id] + + return t.Bytes(zigpy.types.serialize(kwargs.values(), field_types)) + + +def deserialize_zdo_command(command_id, data): + """ZDO command deserialization.""" + field_names, field_types = zdo_t.CLUSTERS[command_id] + args, data = zigpy.types.deserialize(data, field_types) + + return dict(zip(field_names, args)) + + +class BaseServerZBOSS(ZBOSS): + """Base ZBOSS server.""" + + align_structs = False + version = None + + async def _flatten_responses(self, request, responses): + if responses is None: + return + elif isinstance(responses, t.CommandBase): + yield responses + elif inspect.iscoroutinefunction(responses): + async for rsp in responses(request): + yield rsp + elif inspect.isasyncgen(responses): + async for rsp in responses: + yield rsp + elif callable(responses): + async for rsp in self._flatten_responses( + request, responses(request) + ): + yield rsp + else: + for response in responses: + async for rsp in self._flatten_responses(request, response): + yield rsp + + async def _send_responses(self, request, responses): + async for response in self._flatten_responses(request, responses): + await asyncio.sleep(0.001) + LOGGER.debug( + "Replying to %s with %s", request, response + ) + await self.send(response) + + def reply_once_to(self, request, responses, *, override=False): + """Reply once to.""" + if override: + self._listeners[request.header].clear() + + request_future = self.wait_for_response(request) + + async def replier(): + request = await request_future + await self._send_responses(request, responses) + + return request + + return asyncio.create_task(replier()) + + def reply_to(self, request, responses, *, override=False): + """Reply to.""" + if override: + self._listeners[request.header].clear() + + async def callback(request): + callback.call_count += 1 + await self._send_responses(request, responses) + + callback.call_count = 0 + + self.register_indication_listener( + request, lambda r: asyncio.create_task(callback(r)) + ) + + return callback + + async def send(self, response): + """Send.""" + if response is not None and self._uart is not None: + await self._uart.send(response.to_frame(align=self.align_structs)) + + def close(self): + """Close.""" + # We don't clear listeners on shutdown + with patch.object(self, "_listeners", {}): + return super().close() + + +def simple_deepcopy(d): + """Get a deepcopy.""" + if not hasattr(d, "copy"): + return d + + if isinstance(d, (list, tuple)): + return type(d)(map(simple_deepcopy, d)) + elif isinstance(d, dict): + return type(d)( + {simple_deepcopy(k): simple_deepcopy(v) for k, v in d.items()} + ) + else: + return d.copy() + + +def merge_dicts(a, b): + """Merge dicts.""" + c = simple_deepcopy(a) + + for key, value in b.items(): + if isinstance(value, dict): + c[key] = merge_dicts(c.get(key, {}), value) + else: + c[key] = value + + return c + + +@pytest.fixture +def make_application(make_zboss_server): + """Application fixture.""" + def inner( + server_cls, + client_config=None, + server_config=None, + active_sequence=False, + **kwargs, + ): + default = config_for_port_path(FAKE_SERIAL_PORT) + + client_config = merge_dicts(default, client_config or {}) + server_config = merge_dicts(default, server_config or {}) + + app = ControllerApplication(client_config) + + def add_initialized_device(self, *args, **kwargs): + device = self.add_device(*args, **kwargs) + device.status = zigpy.device.Status.ENDPOINTS_INIT + device.model = "Model" + device.manufacturer = "Manufacturer" + + device.node_desc = zdo_t.NodeDescriptor( + logical_type=zdo_t.LogicalType.Router, + complex_descriptor_available=0, + user_descriptor_available=0, + reserved=0, + aps_flags=0, + frequency_band=zdo_t.NodeDescriptor.FrequencyBand.Freq2400MHz, + mac_capability_flags=142, + manufacturer_code=4476, + maximum_buffer_size=82, + maximum_incoming_transfer_size=82, + server_mask=11264, + maximum_outgoing_transfer_size=82, + descriptor_capability_field=0, + ) + + ep = device.add_endpoint(1) + ep.status = zigpy.endpoint.Status.ZDO_INIT + + return device + + async def start_network(self): + dev = self.add_initialized_device( + ieee=t.EUI64(range(8)), nwk=0xAABB + ) + dev.model = "Coordinator Model" + dev.manufacturer = "Coordinator Manufacturer" + + dev.zdo.Mgmt_NWK_Update_req = AsyncMock( + return_value=[ + zdo_t.Status.SUCCESS, + t.Channels.ALL_CHANNELS, + 0, + 0, + [80] * 16, + ] + ) + + async def energy_scan(self, channels, duration_exp, count): + return {self.state.network_info.channel: 0x1234} + + app.add_initialized_device = add_initialized_device.__get__(app) + app.start_network = start_network.__get__(app) + app.energy_scan = energy_scan.__get__(app) + + app.device_initialized = Mock(wraps=app.device_initialized) + app.listener_event = Mock(wraps=app.listener_event) + if not active_sequence: + app.get_sequence = MagicMock( + wraps=app.get_sequence, return_value=123 + ) + app.send_packet = AsyncMock(wraps=app.send_packet) + app.write_network_info = AsyncMock(wraps=app.write_network_info) + + server = make_zboss_server( + server_cls=server_cls, config=server_config, **kwargs + ) + + return app, server + + return inner + + +def zdo_request_matcher( + dst_addr, command_id: t.uint16_t, **kwargs): + """Request matcher.""" + zdo_kwargs = {k: v for k, v in kwargs.items() if k.startswith("zdo_")} + + kwargs = {k: v for k, v in kwargs.items() if not k.startswith("zdo_")} + kwargs.setdefault("DstEndpoint", 0x00) + kwargs.setdefault("SrcEndpoint", 0x00) + kwargs.setdefault("Radius", None) + + return c.APS.DataReq.Req( + DstAddr=t.EUI64.convert("00124b0001ab89cd"), + ClusterId=command_id, + Payload=t.Payload( + bytes([kwargs["TSN"]]) + + serialize_zdo_command(command_id, **zdo_kwargs) + ), + **kwargs, + partial=True, + ) + + +class BaseZbossDevice(BaseServerZBOSS): + """Base ZBOSS Device.""" + + def __init__(self, *args, **kwargs): + """Initialize.""" + super().__init__(*args, **kwargs) + self.active_endpoints = [] + self._nvram = {} + self._orig_nvram = {} + self.new_channel = 0 + self.device_state = 0x00 + self.zdo_callbacks = set() + for name in dir(self): + func = getattr(self, name) + for req in getattr(func, "_reply_to", []): + self.reply_to(request=req, responses=[func]) + + def connection_lost(self, exc): + """Lost connection.""" + self.active_endpoints.clear() + return super().connection_lost(exc) + + @reply_to(c.NcpConfig.GetJoinStatus.Req(partial=True)) + def get_join_status(self, request): + """Handle get join status.""" + return c.NcpConfig.GetJoinStatus.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + Joined=0x01 # Assume device is joined for this example + ) + + @reply_to(c.NcpConfig.NCPModuleReset.Req(partial=True)) + def get_ncp_reset(self, request): + """Handle NCP reset.""" + return c.NcpConfig.NCPModuleReset.Rsp( + TSN=0xFF, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK + ) + + @reply_to(c.NcpConfig.GetShortAddr.Req(partial=True)) + def get_short_addr(self, request): + """Handle get short address.""" + return c.NcpConfig.GetShortAddr.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + NWKAddr=t.NWK(0x1234) # Example NWK address + ) + + @reply_to(c.APS.DataReq.Req(partial=True, DstEndpoint=0)) + def on_zdo_request(self, req): + """Handle APS Data request.""" + return c.APS.DataReq.Rsp( + TSN=req.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DstAddr=req.DstAddr, + DstEndpoint=req.DstEndpoint, + SrcEndpoint=req.SrcEndpoint, + TxTime=1, + DstAddrMode=req.DstAddrMode, + ) + + @reply_to(c.NcpConfig.GetLocalIEEE.Req(partial=True)) + def get_local_ieee(self, request): + """Handle get local IEEE.""" + return c.NcpConfig.GetLocalIEEE.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + MacInterfaceNum=request.MacInterfaceNum, + IEEE=t.EUI64([0, 1, 2, 3, 4, 5, 6, 7]) # Example IEEE address + ) + + @reply_to(c.NcpConfig.GetZigbeeRole.Req(partial=True)) + def get_zigbee_role(self, request): + """Handle get zigbee role.""" + return c.NcpConfig.GetZigbeeRole.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) # Example role + ) + + @reply_to(c.NcpConfig.GetExtendedPANID.Req(partial=True)) + def get_extended_panid(self, request): + """Handle get extended PANID.""" + return c.NcpConfig.GetExtendedPANID.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + ExtendedPANID=t.EUI64.convert("00124b0001ab89cd") # Example PAN ID + ) + + @reply_to(c.ZDO.PermitJoin.Req(partial=True)) + def get_permit_join(self, request): + """Handle get permit join.""" + return c.ZDO.PermitJoin.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + ) + + @reply_to(c.NcpConfig.GetShortPANID.Req(partial=True)) + def get_short_panid(self, request): + """Handle get short PANID.""" + return c.NcpConfig.GetShortPANID.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + PANID=t.PanId(0x5678) # Example short PAN ID + ) + + @reply_to(c.NcpConfig.GetCurrentChannel.Req(partial=True)) + def get_current_channel(self, request): + """Handle get current channel.""" + if self.new_channel != 0: + channel = self.new_channel + else: + channel = 1 + + return c.NcpConfig.GetCurrentChannel.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + Page=0, + Channel=t.Channels(channel) + ) + + @reply_to(c.NcpConfig.GetChannelMask.Req(partial=True)) + def get_channel_mask(self, request): + """Handle get channel mask.""" + return c.NcpConfig.GetChannelMask.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + ChannelList=t.ChannelEntryList( + [t.ChannelEntry(page=1, channel_mask=0x07fff800)]) + ) # Example mask + + @reply_to(c.NcpConfig.ReadNVRAM.Req(partial=True)) + def read_nvram(self, request): + """Handle NVRAM read.""" + status_code = t.StatusCodeGeneric.ERROR + if request.DatasetId == t.DatasetId.ZB_NVRAM_COMMON_DATA: + status_code = t.StatusCodeGeneric.OK + dataset = t.DSCommonData( + byte_count=100, + bitfield=1, + depth=1, + nwk_manager_addr=0x0000, + panid=0x1234, + network_addr=0x5678, + channel_mask=t.Channels(14), + aps_extended_panid=t.EUI64.convert("00:11:22:33:44:55:66:77"), + nwk_extended_panid=t.EUI64.convert("00:11:22:33:44:55:66:77"), + parent_addr=t.EUI64.convert("00:11:22:33:44:55:66:77"), + tc_addr=t.EUI64.convert("00:11:22:33:44:55:66:77"), + nwk_key=t.KeyData(b'\x01' * 16), + nwk_key_seq=0, + tc_standard_key=t.KeyData(b'\x02' * 16), + channel=15, + page=0, + mac_interface_table=t.MacInterfaceTable( + bitfield_0=0, + bitfield_1=1, + link_pwr_data_rate=250, + channel_in_use=11, + supported_channels=t.Channels(15) + ), + reserved=0 + ) + nvram_version = 3 + dataset_version = 1 + elif request.DatasetId == t.DatasetId.ZB_IB_COUNTERS: + status_code = t.StatusCodeGeneric.OK + dataset = t.DSIbCounters( + byte_count=8, + nib_counter=100, # Example counter value + aib_counter=50 # Example counter value + ) + nvram_version = 1 + dataset_version = 1 + elif request.DatasetId == t.DatasetId.ZB_NVRAM_ADDR_MAP: + status_code = t.StatusCodeGeneric.OK + dataset = t.DSNwkAddrMap( + header=t.NwkAddrMapHeader( + byte_count=100, + entry_count=2, + _align=0 + ), + items=[ + t.NwkAddrMapRecord( + ieee_addr=t.EUI64.convert("00:11:22:33:44:55:66:77"), + nwk_addr=0x1234, + index=1, + redirect_type=0, + redirect_ref=0, + _align=0 + ), + t.NwkAddrMapRecord( + ieee_addr=t.EUI64.convert("00:11:22:33:44:55:66:78"), + nwk_addr=0x5678, + index=2, + redirect_type=0, + redirect_ref=0, + _align=0 + ) + ] + ) + nvram_version = 2 + dataset_version = 1 + elif request.DatasetId == t.DatasetId.ZB_NVRAM_APS_SECURE_DATA: + status_code = t.StatusCodeGeneric.OK + dataset = t.DSApsSecureKeys( + header=10, + items=[ + t.ApsSecureEntry( + ieee_addr=t.EUI64.convert("00:11:22:33:44:55:66:77"), + key=t.KeyData(b'\x03' * 16), + _unknown_1=0 + ), + t.ApsSecureEntry( + ieee_addr=t.EUI64.convert("00:11:22:33:44:55:66:78"), + key=t.KeyData(b'\x04' * 16), + _unknown_1=0 + ) + ] + ) + nvram_version = 1 + dataset_version = 1 + else: + dataset = t.NVRAMDataset(b'') + nvram_version = 1 + dataset_version = 1 + + return c.NcpConfig.ReadNVRAM.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=status_code, + NVRAMVersion=nvram_version, + DatasetId=t.DatasetId(request.DatasetId), + DatasetVersion=dataset_version, + Dataset=t.NVRAMDataset(dataset.serialize()) + ) + + @reply_to(c.NcpConfig.GetTrustCenterAddr.Req(partial=True)) + def get_trust_center_addr(self, request): + """Handle get trust center address.""" + return c.NcpConfig.GetTrustCenterAddr.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + TCIEEE=t.EUI64.convert("00:11:22:33:44:55:66:77") + # Example Trust Center IEEE address + ) + + @reply_to(c.NcpConfig.GetRxOnWhenIdle.Req(partial=True)) + def get_rx_on_when_idle(self, request): + """Handle get RX on when idle.""" + return c.NcpConfig.GetRxOnWhenIdle.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + RxOnWhenIdle=1 # Example RxOnWhenIdle value + ) + + @reply_to(c.NWK.StartWithoutFormation.Req(partial=True)) + def start_without_formation(self, request): + """Handle start without formation.""" + return c.NWK.StartWithoutFormation.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK # Example status code + ) + + @reply_to(c.NcpConfig.GetModuleVersion.Req(partial=True)) + def get_module_version(self, request): + """Handle get module version.""" + return c.NcpConfig.GetModuleVersion.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, # Example status code + FWVersion=1, # Example firmware version + StackVersion=2, # Example stack version + ProtocolVersion=3 # Example protocol version + ) + + @reply_to(c.AF.SetSimpleDesc.Req(partial=True)) + def set_simple_desc(self, request): + """Handle set simple descriptor.""" + return c.AF.SetSimpleDesc.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK # Example status code + ) + + @reply_to(c.NcpConfig.GetEDTimeout.Req(partial=True)) + def get_ed_timeout(self, request): + """Handle get EndDevice timeout.""" + return c.NcpConfig.GetEDTimeout.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + Timeout=t.TimeoutIndex(0x01) # Example timeout value + ) + + @reply_to(c.NcpConfig.GetMaxChildren.Req(partial=True)) + def get_max_children(self, request): + """Handle get max children.""" + return c.NcpConfig.GetMaxChildren.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + ChildrenNbr=5 # Example max children + ) + + @reply_to(c.NcpConfig.GetAuthenticationStatus.Req(partial=True)) + def get_authentication_status(self, request): + """Handle get authentication status.""" + return c.NcpConfig.GetAuthenticationStatus.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + Authenticated=1 # Example authenticated value + ) + + @reply_to(c.NcpConfig.GetParentAddr.Req(partial=True)) + def get_parent_addr(self, request): + """Handle get parent address.""" + return c.NcpConfig.GetParentAddr.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + NWKParentAddr=t.NWK(0x1234) # Example parent NWK address + ) + + @reply_to(c.NcpConfig.GetCoordinatorVersion.Req(partial=True)) + def get_coordinator_version(self, request): + """Handle get coordinator version.""" + return c.NcpConfig.GetCoordinatorVersion.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + CoordinatorVersion=1 # Example coordinator version + ) + + def on_zdo_node_desc_req(self, req, NWKAddrOfInterest): + """Handle node description request.""" + if NWKAddrOfInterest != 0x0000: + return + + responses = [ + c.ZDO.NodeDescRsp.Callback( + Src=0x0000, + Status=t.ZDOStatus.SUCCESS, + NWK=0x0000, + NodeDescriptor=c.zdo.NullableNodeDescriptor( + byte1=0, + byte2=64, + mac_capability_flags=143, + manufacturer_code=0, + maximum_buffer_size=80, + maximum_incoming_transfer_size=160, + server_mask=1, # this differs + maximum_outgoing_transfer_size=160, + descriptor_capability_field=0, + ), + ), + ] + + if zdo_t.ZDOCmd.Node_Desc_rsp in self.zdo_callbacks: + responses.append( + c.ZDO.NodeDescReq.Callback( + Src=0x0000, + IsBroadcast=t.Bool.false, + ClusterId=zdo_t.ZDOCmd.Node_Desc_rsp, + SecurityUse=0, + TSN=req.TSN, + MacDst=0x0000, + Data=serialize_zdo_command( + command_id=zdo_t.ZDOCmd.Node_Desc_rsp, + Status=t.ZDOStatus.SUCCESS, + NWKAddrOfInterest=0x0000, + NodeDescriptor=zdo_t.NodeDescriptor( + **responses[0].NodeDescriptor.as_dict() + ), + ), + ) + ) + + return responses + + +class BaseZbossGenericDevice(BaseServerZBOSS): + """Base ZBOSS generic device.""" + + def __init__(self, *args, **kwargs): + """Init method.""" + super().__init__(*args, **kwargs) + self.active_endpoints = [] + self._nvram = {} + self._orig_nvram = {} + self.device_state = 0x00 + self.zdo_callbacks = set() + for name in dir(self): + func = getattr(self, name) + for req in getattr(func, "_reply_to", []): + self.reply_to(request=req, responses=[func]) + + def connection_lost(self, exc): + """Lost connection.""" + self.active_endpoints.clear() + return super().connection_lost(exc) + + @reply_to(c.NcpConfig.ReadNVRAM.Req(partial=True)) + def read_nvram(self, request): + """Handle NVRAM read.""" + status_code = t.StatusCodeGeneric.ERROR + if request.DatasetId == t.DatasetId.ZB_NVRAM_COMMON_DATA: + status_code = t.StatusCodeGeneric.OK + dataset = t.DSCommonData( + byte_count=100, + bitfield=1, + depth=1, + nwk_manager_addr=0x0000, + panid=0x1234, + network_addr=0x5678, + channel_mask=t.Channels(14), + aps_extended_panid=t.EUI64.convert("00:11:22:33:44:55:66:77"), + nwk_extended_panid=t.EUI64.convert("00:11:22:33:44:55:66:77"), + parent_addr=t.EUI64.convert("00:11:22:33:44:55:66:77"), + tc_addr=t.EUI64.convert("00:11:22:33:44:55:66:77"), + nwk_key=t.KeyData(b'\x01' * 16), + nwk_key_seq=0, + tc_standard_key=t.KeyData(b'\x02' * 16), + channel=15, + page=0, + mac_interface_table=t.MacInterfaceTable( + bitfield_0=0, + bitfield_1=1, + link_pwr_data_rate=250, + channel_in_use=11, + supported_channels=t.Channels(15) + ), + reserved=0 + ) + nvram_version = 3 + dataset_version = 1 + elif request.DatasetId == t.DatasetId.ZB_IB_COUNTERS: + status_code = t.StatusCodeGeneric.OK + dataset = t.DSIbCounters( + byte_count=8, + nib_counter=100, # Example counter value + aib_counter=50 # Example counter value + ) + nvram_version = 1 + dataset_version = 1 + elif request.DatasetId == t.DatasetId.ZB_NVRAM_ADDR_MAP: + status_code = t.StatusCodeGeneric.OK + dataset = t.DSNwkAddrMap( + header=t.NwkAddrMapHeader( + byte_count=100, + entry_count=2, + _align=0 + ), + items=[ + t.NwkAddrMapRecord( + ieee_addr=t.EUI64.convert("00:11:22:33:44:55:66:77"), + nwk_addr=0x1234, + index=1, + redirect_type=0, + redirect_ref=0, + _align=0 + ), + t.NwkAddrMapRecord( + ieee_addr=t.EUI64.convert("00:11:22:33:44:55:66:78"), + nwk_addr=0x5678, + index=2, + redirect_type=0, + redirect_ref=0, + _align=0 + ) + ] + ) + nvram_version = 2 + dataset_version = 1 + elif request.DatasetId == t.DatasetId.ZB_NVRAM_APS_SECURE_DATA: + status_code = t.StatusCodeGeneric.OK + dataset = t.DSApsSecureKeys( + header=10, + items=[ + t.ApsSecureEntry( + ieee_addr=t.EUI64.convert("00:11:22:33:44:55:66:77"), + key=t.KeyData(b'\x03' * 16), + _unknown_1=0 + ), + t.ApsSecureEntry( + ieee_addr=t.EUI64.convert("00:11:22:33:44:55:66:78"), + key=t.KeyData(b'\x04' * 16), + _unknown_1=0 + ) + ] + ) + nvram_version = 1 + dataset_version = 1 + else: + dataset = t.NVRAMDataset(b'') + nvram_version = 1 + dataset_version = 1 + + return c.NcpConfig.ReadNVRAM.Rsp( + TSN=request.TSN, + StatusCat=t.StatusCategory(1), + StatusCode=status_code, + NVRAMVersion=nvram_version, + DatasetId=t.DatasetId(request.DatasetId), + DatasetVersion=dataset_version, + Dataset=t.NVRAMDataset(dataset.serialize()) + ) + + def on_zdo_node_desc_req(self, req, NWKAddrOfInterest): + """Handle node description request.""" + if NWKAddrOfInterest != 0x0000: + return + + responses = [ + c.ZDO.NodeDescRsp.Callback( + Src=0x0000, + Status=t.ZDOStatus.SUCCESS, + NWK=0x0000, + NodeDescriptor=c.zdo.NullableNodeDescriptor( + byte1=0, + byte2=64, + mac_capability_flags=143, + manufacturer_code=0, + maximum_buffer_size=80, + maximum_incoming_transfer_size=160, + server_mask=1, # this differs + maximum_outgoing_transfer_size=160, + descriptor_capability_field=0, + ), + ), + ] + + if zdo_t.ZDOCmd.Node_Desc_rsp in self.zdo_callbacks: + responses.append( + c.ZDO.NodeDescReq.Callback( + Src=0x0000, + IsBroadcast=t.Bool.false, + ClusterId=zdo_t.ZDOCmd.Node_Desc_rsp, + SecurityUse=0, + TSN=req.TSN, + MacDst=0x0000, + Data=serialize_zdo_command( + command_id=zdo_t.ZDOCmd.Node_Desc_rsp, + Status=t.ZDOStatus.SUCCESS, + NWKAddrOfInterest=0x0000, + NodeDescriptor=zdo_t.NodeDescriptor( + **responses[0].NodeDescriptor.as_dict() + ), + ), + ) + ) + + return responses diff --git a/tests/test_commands.py b/tests/test_commands.py new file mode 100644 index 0000000..8939974 --- /dev/null +++ b/tests/test_commands.py @@ -0,0 +1,493 @@ +"""Test commands.""" +import dataclasses +import keyword +from collections import defaultdict + +import pytest + +import zigpy_zboss.commands as c +from zigpy_zboss import types as t + + +def _validate_schema(schema): + """Validate the schema for command parameters.""" + for index, param in enumerate(schema): + assert isinstance(param.name, str) + assert param.name.isidentifier() + assert not keyword.iskeyword(param.name) + assert isinstance(param.type, type) + assert isinstance(param.description, str) + + # All optional params must be together at the end + if param.optional: + assert all(p.optional for p in schema[index:]) + + +def test_commands_schema(): + """Test the schema of all commands.""" + commands_by_id = defaultdict(list) + + for commands in c.ALL_COMMANDS: + for cmd in commands: + if cmd.definition.control_type == t.ControlType.REQ: + assert cmd.type == cmd.Req.header.control_type + assert cmd.Rsp.header.control_type == t.ControlType.RSP + + assert isinstance(cmd.Req.header, t.HLCommonHeader) + assert isinstance(cmd.Rsp.header, t.HLCommonHeader) + + assert cmd.Req.Rsp is cmd.Rsp + assert cmd.Rsp.Req is cmd.Req + assert cmd.Ind is None + + _validate_schema(cmd.Req.schema) + _validate_schema(cmd.Rsp.schema) + + commands_by_id[cmd.Req.header].append(cmd.Req) + commands_by_id[cmd.Rsp.header].append(cmd.Rsp) + + elif cmd.type == t.ControlType.IND: + assert cmd.Req is None + assert cmd.Rsp is None + + assert cmd.type == cmd.Ind.header.control_type + + assert cmd.Ind.header.control_type == t.ControlType.IND + + assert isinstance(cmd.Ind.header, t.HLCommonHeader) + + _validate_schema(cmd.Ind.schema) + + commands_by_id[cmd.Ind.header].append(cmd.Ind) + else: + assert False, "Command has unknown type" # noqa: B011 + + duplicate_commands = { + cmd: commands for cmd, + commands in commands_by_id.items() if len(commands) > 1 + } + assert not duplicate_commands + + assert len(commands_by_id.keys()) == len(c.COMMANDS_BY_ID.keys()) + + +def test_command_param_binding(): + """Test if commands match correctly.""" + # Example for GetModuleVersion which only requires TSN + c.NcpConfig.GetModuleVersion.Req(TSN=1) + + # Example for invalid param name + with pytest.raises(KeyError): + c.NcpConfig.GetModuleVersion.Rsp(asd=123) + + # Example for valid param name but incorrect value (invalid type) + with pytest.raises(ValueError): + c.NcpConfig.GetModuleVersion.Rsp(TSN="invalid", + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=123456, + StackVersion=789012, + ProtocolVersion=345678 + ) + + # Example for correct command invocation + valid_rsp = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=123456, + StackVersion=789012, + ProtocolVersion=345678 + ) + assert isinstance(valid_rsp.FWVersion, t.uint32_t) + assert isinstance(valid_rsp.StackVersion, t.uint32_t) + assert isinstance(valid_rsp.ProtocolVersion, t.uint32_t) + + # Example for checking overflow in integer type + with pytest.raises(ValueError): + c.NcpConfig.GetModuleVersion.Rsp(TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=10 ** 20, + StackVersion=789012, + ProtocolVersion=345678) + + # Invalid type in a parameter that expects a specific enum or struct + with pytest.raises(ValueError): + c.NcpConfig.SetZigbeeRole.Req(TSN=10, + DeviceRole="invalid type") + + # Coerced numerical type for a command expecting specific struct or uint + a = c.NcpConfig.SetZigbeeRole.Req(TSN=10, + DeviceRole=t.DeviceRole.ZR) + b = c.NcpConfig.SetZigbeeRole.Req(TSN=10, + DeviceRole=t.DeviceRole(1)) + + assert a == b + assert a.DeviceRole == b.DeviceRole + + assert ( + type(a.DeviceRole) == # noqa: E721 + type(b.DeviceRole) == t.DeviceRole # noqa: E721 + ) + + # Parameters can be looked up by name + zigbee_role = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole.ZC + ) + assert zigbee_role.DeviceRole == t.DeviceRole.ZC + + # Invalid ones cannot + with pytest.raises(AttributeError): + print(zigbee_role.Oops) + + +def test_command_optional_params(): + """Test optional parameters.""" + # Basic response with required parameters only + basic_ieee_addr_rsp = c.ZDO.IeeeAddrReq.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + RemoteDevIEEE=t.EUI64([00, 11, 22, 33, 44, 55, 66, 77]), + RemoteDevNWK=t.NWK(0x1234) + ) + + # Full response including optional parameters + full_ieee_addr_rsp = c.ZDO.IeeeAddrReq.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + RemoteDevIEEE=t.EUI64([00, 11, 22, 33, 44, 55, 66, 77]), + RemoteDevNWK=t.NWK(0x1234), + NumAssocDev=5, + StartIndex=0, + AssocDevNWKList=[t.NWK(0x0001), t.NWK(0x0002)] + ) + + basic_data = basic_ieee_addr_rsp.to_frame().hl_packet.data + full_data = full_ieee_addr_rsp.to_frame().hl_packet.data + + # Check if full data contains optional parameters + assert len(full_data) >= len(basic_data) + + # Basic data should be a prefix of full data + assert full_data.startswith(basic_data) + + # Deserialization checks + IeeeAddrReq = c.ZDO.IeeeAddrReq.Rsp + assert ( + IeeeAddrReq.from_frame(basic_ieee_addr_rsp.to_frame()) + == basic_ieee_addr_rsp + ) + assert ( + IeeeAddrReq.from_frame(full_ieee_addr_rsp.to_frame()) + == full_ieee_addr_rsp + ) + + +def test_command_optional_params_failures(): + """Test optional parameters failures.""" + with pytest.raises(KeyError): + # Optional params cannot be skipped over + c.ZDO.IeeeAddrReq.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + RemoteDevIEEE=t.EUI64([00, 11, 22, 33, 44, 55, 66, 77]), + RemoteDevNWK=t.NWK(0x1234), + NumAssocDev=5, + # StartIndex=0, + AssocDevNWKList=[t.NWK(0x0001), t.NWK(0x0002)] + ) + + # Unless it's a partial command + partial = c.ZDO.IeeeAddrReq.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + RemoteDevIEEE=t.EUI64([00, 11, 22, 33, 44, 55, 66, 77]), + RemoteDevNWK=t.NWK(0x1234), + NumAssocDev=5, + # StartIndex=0, + AssocDevNWKList=[t.NWK(0x0001), t.NWK(0x0002)], + partial=True + ) + + # In which case, it cannot be serialized + with pytest.raises(ValueError): + partial.to_frame() + + +def test_simple_descriptor(): + """Test simple descriptor.""" + lvlist16_type = t.LVList[t.uint16_t] + + simple_descriptor = t.SimpleDescriptor() + simple_descriptor.endpoint = t.uint8_t(1) + simple_descriptor.profile = t.uint16_t(260) + simple_descriptor.device_type = t.uint16_t(257) + simple_descriptor.device_version = t.uint8_t(0) + simple_descriptor.input_clusters = lvlist16_type( + [0, 3, 4, 5, 6, 8, 2821, 1794] + ) + simple_descriptor.output_clusters_count = t.uint8_t(2) + simple_descriptor.input_clusters_count = t.uint8_t(8) + simple_descriptor.output_clusters = lvlist16_type([0x0001, 0x0002]) + + c1 = c.ZDO.SimpleDescriptorReq.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + SimpleDesc=simple_descriptor, + NwkAddr=t.NWK(0x1234) + ) + + sp_simple_descriptor = t.SimpleDescriptor() + sp_simple_descriptor.endpoint = t.uint8_t(1) + sp_simple_descriptor.profile = t.uint16_t(260) + sp_simple_descriptor.device_type = t.uint16_t(257) + sp_simple_descriptor.device_version = t.uint8_t(0) + sp_simple_descriptor.input_clusters = lvlist16_type( + [0, 3, 4, 5, 6, 8, 2821, 1794] + ) + sp_simple_descriptor.output_clusters_count = t.uint8_t(2) + sp_simple_descriptor.input_clusters_count = t.uint8_t(8) + sp_simple_descriptor.output_clusters = lvlist16_type([0x0001, 0x0002]) + + c2 = c.ZDO.SimpleDescriptorReq.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + SimpleDesc=sp_simple_descriptor, + NwkAddr=t.NWK(0x1234) + ) + + assert c1.to_frame() == c2.to_frame() + # assert c1 == c2 + + +def test_command_str_repr(): + """Test __str__ and __repr__ methods for commands.""" + command = c.NcpConfig.GetModuleVersion.Req(TSN=1) + + assert repr(command) == str(command) + assert str([command]) == f"[{str(command)}]" + + +def test_command_immutability(): + """Test that commands are immutable.""" + command1 = c.ZDO.IeeeAddrReq.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + RemoteDevNWK=t.NWK(0x1234), + NumAssocDev=5, + StartIndex=0, + partial=True + ) + + command2 = c.ZDO.IeeeAddrReq.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + RemoteDevNWK=t.NWK(0x1234), + NumAssocDev=5, + StartIndex=0, + partial=True + ) + + d = {command1: True} + + assert command1 == command2 + assert command2 in d + assert {command1: True} == {command2: True} + + with pytest.raises(RuntimeError): + command1.partial = False + + with pytest.raises(RuntimeError): + command1.StatusCode = t.StatusCodeGeneric.OK + + with pytest.raises(RuntimeError): + command1.NumAssocDev = 5 + + with pytest.raises(RuntimeError): + del command1.StartIndex + + assert command1 == command2 + + +def test_command_serialization(): + """Test command serialization.""" + command = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=123456, + StackVersion=789012, + ProtocolVersion=345678 + ) + frame = command.to_frame() + + assert frame.hl_packet.data == bytes.fromhex( + "0A010040E20100140A0C004E460500" + ) + + # Partial frames cannot be serialized + with pytest.raises(ValueError): + partial1 = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=123456, + # StackVersion=789012, + ProtocolVersion=345678, + partial=True + ) + + partial1.to_frame() + + # Partial frames cannot be serialized, even if all params are filled out + with pytest.raises(ValueError): + partial2 = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=123456, + StackVersion=789012, + ProtocolVersion=345678, + partial=True + ) + partial2.to_frame() + + +def test_command_equality(): + """Test command equality.""" + command1 = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=123456, + StackVersion=789012, + ProtocolVersion=345678 + ) + + command2 = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=123456, + StackVersion=789012, + ProtocolVersion=345678 + ) + + command3 = c.NcpConfig.GetModuleVersion.Rsp( + TSN=20, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=123456, + StackVersion=789012, + ProtocolVersion=345678 + ) + + assert command1 == command1 + assert command1.matches(command1) + assert command2 == command1 + assert command1 == command2 + + assert command1 != command3 + assert command3 != command1 + + assert command1.matches(command2) # Matching is a superset of equality + assert command2.matches(command1) + assert not command1.matches(command3) + assert not command3.matches(command1) + + assert not command1.matches( + c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + ) + ) + assert c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + partial=True + ).matches(command1) + + # parameters can be specified explicitly as None + assert c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + StackVersion=None, + partial=True + ).matches(command1) + assert c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + StackVersion=789012, + partial=True + ).matches(command1) + assert not c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + StackVersion=79000, + partial=True + ).matches(command1) + + # Different frame types do not match, even if they have the same structure + assert not c.ZDO.MgtLeave.Rsp(TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK).matches( + c.ZDO.PermitJoin.Rsp(partial=True) + ) + + +def test_command_deserialization(): + """Test command deserialization.""" + command = c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=123456, + StackVersion=789012, + ProtocolVersion=345678 + ) + + assert type(command).from_frame(command.to_frame()) == command + assert ( + command.to_frame() == + type(command).from_frame(command.to_frame()).to_frame() + ) + + # Deserialization fails if there is unparsed data at the end of the frame + frame = command.to_frame() + new_hl_packet = dataclasses.replace( + frame.hl_packet, data=frame.hl_packet.data + b"\x01" + ) + + # Create a new Frame instance with the updated hl_packet + bad_frame = dataclasses.replace(frame, hl_packet=new_hl_packet) + + with pytest.raises(ValueError): + type(command).from_frame(bad_frame) + + # Deserialization fails if you attempt to deserialize the wrong frame + with pytest.raises(ValueError): + c.ZDO.MgtLeave.Rsp(TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK).from_frame( + c.ZDO.PermitJoin.Rsp(TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK).to_frame() + ) diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..e96a502 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,60 @@ +"""Test config.""" +import pytest +from voluptuous import Invalid + +import zigpy_zboss.config as conf + + +def test_pin_states_same_lengths(): + """Test same lengths pin states.""" + # Bare schema works + conf.CONFIG_SCHEMA( + { + conf.CONF_DEVICE: {conf.CONF_DEVICE_PATH: "/dev/null"}, + } + ) + + # So does one with explicitly specified pin states + config = conf.CONFIG_SCHEMA( + { + conf.CONF_DEVICE: {conf.CONF_DEVICE_PATH: "/dev/null"}, + conf.CONF_ZBOSS_CONFIG: { + conf.CONF_CONNECT_RTS_STATES: ["on", True, 0, 0, 0, 1, 1], + conf.CONF_CONNECT_DTR_STATES: ["off", False, 1, 0, 0, 1, 1], + }, + } + ) + + assert config[conf.CONF_ZBOSS_CONFIG][conf.CONF_CONNECT_RTS_STATES] == [ + True, + True, + False, + False, + False, + True, + True, + ] + assert config[conf.CONF_ZBOSS_CONFIG][conf.CONF_CONNECT_DTR_STATES] == [ + False, + False, + True, + False, + False, + True, + True, + ] + + +def test_pin_states_different_lengths(): + """Test different lengths pin states.""" + # They must be the same length + with pytest.raises(Invalid): + conf.CONFIG_SCHEMA( + { + conf.CONF_DEVICE: {conf.CONF_DEVICE_PATH: "/dev/null"}, + conf.CONF_ZBOSS_CONFIG: { + conf.CONF_CONNECT_RTS_STATES: [1, 1, 0], + conf.CONF_CONNECT_DTR_STATES: [1, 1], + }, + } + ) diff --git a/tests/test_frame.py b/tests/test_frame.py new file mode 100644 index 0000000..d4397ae --- /dev/null +++ b/tests/test_frame.py @@ -0,0 +1,241 @@ +"""Test frame.""" +import pytest + +import zigpy_zboss.types as t +from zigpy_zboss.frames import (CRC8, ZBNCP_LL_BODY_SIZE_MAX, Frame, HLPacket, + InvalidFrame, LLHeader) + + +def test_frame_deserialization(): + """Test frame deserialization.""" + ll_signature = t.uint16_t(0xADDE).serialize() + + # Create an HLCommonHeader with specific fields + hl_data = t.Bytes(b"test_data").serialize() + hl_packet = hl_data + + ll_size = t.uint16_t(len(hl_packet) + 5).serialize() + ll_type = t.uint8_t(0x01).serialize() + ll_flags = t.LLFlags(0x00).serialize() + + ll_header_without_crc = ll_signature + ll_size + ll_type + ll_flags + ll_crc = CRC8(ll_header_without_crc[2:6]).digest().serialize() + ll_header = ll_header_without_crc + ll_crc + + frame_data = ll_header + hl_packet + extra_data = b"extra_data" + + # Deserialize frame + frame, rest = Frame.deserialize(frame_data + extra_data) + + # Assertions + assert rest == extra_data + assert frame.ll_header.signature == 0xADDE + assert frame.ll_header.size == len(hl_packet) + 5 + assert frame.ll_header.frame_type == 0x01 + assert frame.ll_header.flags == 0x00 + assert frame.ll_header.crc8 == CRC8(ll_header_without_crc[2:6]).digest() + assert frame.hl_packet.data == b"test_data" + + # Invalid frame signature + invalid_signature_frame_data = t.uint16_t(0xFFFF).serialize() + frame_data[ + 2:] + with pytest.raises(InvalidFrame, + match="Expected frame to start with Signature"): + Frame.deserialize(invalid_signature_frame_data) + + # Invalid CRC8 + ll_header = ll_header_without_crc + + frame_data_without_crc = ll_header + hl_packet + with pytest.raises(InvalidFrame, match="Invalid frame checksum"): + Frame.deserialize(frame_data_without_crc) + + +def test_ack_flag_deserialization(): + """Test frame deserialization with ACK flag.""" + ll_signature = t.uint16_t(0xADDE).serialize() + ll_size = t.uint16_t(5).serialize() # Only LLHeader size + ll_type = t.uint8_t(0x01).serialize() + ll_flags = t.LLFlags(t.LLFlags.isACK).serialize() + + ll_header_without_crc = ll_signature + ll_size + ll_type + ll_flags + ll_crc = CRC8(ll_header_without_crc[2:6]).digest().serialize() + ll_header = ll_header_without_crc + ll_crc + + frame_data = ll_header + extra_data = b"extra_data" + + frame, rest = Frame.deserialize(frame_data + extra_data) + + assert rest == extra_data + assert frame.ll_header.signature == 0xADDE + assert frame.ll_header.size == 5 + assert frame.ll_header.frame_type == 0x01 + assert frame.ll_header.flags == t.LLFlags.isACK + assert frame.ll_header.crc8 == CRC8(ll_header_without_crc[2:6]).digest() + assert frame.hl_packet is None + + +def test_first_frag_flag_deserialization(): + """Test frame deserialization with FirstFrag flag.""" + ll_signature = t.uint16_t(0xADDE).serialize() + + # Create an HLCommonHeader with specific fields + hl_header = t.HLCommonHeader( + version=0x01, type=t.ControlType.RSP, id=0x1234 + ) + hl_data = t.Bytes(b"test_data") + + # Create HLPacket and serialize + hl_packet = HLPacket(header=hl_header, data=hl_data) + hl_packet_data = hl_packet.serialize() + + # Create LLHeader with FirstFrag flag + ll_size = t.uint16_t(len(hl_packet_data) + 5).serialize() + ll_type = t.uint8_t(0x01).serialize() + ll_flags = t.LLFlags(t.LLFlags.FirstFrag).serialize() + + ll_header_without_crc = ll_signature + ll_size + ll_type + ll_flags + ll_crc = CRC8(ll_header_without_crc[2:6]).digest().serialize() + ll_header = ll_header_without_crc + ll_crc + + frame_data = ll_header + hl_packet_data + extra_data = b"extra_data" + + frame, rest = Frame.deserialize(frame_data + extra_data) + + assert rest == extra_data + assert frame.ll_header.signature == 0xADDE + assert frame.ll_header.size == len(hl_packet_data) + 5 + assert frame.ll_header.frame_type == 0x01 + assert frame.ll_header.flags == t.LLFlags.FirstFrag + assert frame.ll_header.crc8 == CRC8(ll_header_without_crc[2:6]).digest() + assert frame.hl_packet.header.version == 0x01 + assert frame.hl_packet.header.control_type == t.ControlType.RSP + assert frame.hl_packet.header.id == 0x1234 + assert frame.hl_packet.data == b"test_data" + + +def test_handle_tx_fragmentation(): + """Test the handle_tx_fragmentation method for proper fragmentation.""" + # Create an HLCommonHeader with specific fields + hl_header = t.HLCommonHeader( + version=0x01, type=t.ControlType.RSP, id=0x1234 + ) + large_data = b"a" * (ZBNCP_LL_BODY_SIZE_MAX * 2 + 50) + hl_data = t.Bytes(large_data) + + # Create an HLPacket with the large data + hl_packet = HLPacket(header=hl_header, data=hl_data) + frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet) + + fragments = frame.handle_tx_fragmentation() + + total_fragments = frame.count_fragments() + assert len(fragments) == total_fragments + + # Calculate the expected size of the first fragment + # Exclude the CRC16 for size calculation + serialized_hl_packet = hl_packet.serialize()[2:] + first_frag_size = ( + len(serialized_hl_packet) % ZBNCP_LL_BODY_SIZE_MAX + or ZBNCP_LL_BODY_SIZE_MAX + ) + + # Check the first fragment + first_fragment = fragments[0] + assert first_fragment.ll_header.flags == t.LLFlags.FirstFrag + assert first_fragment.ll_header.size == first_frag_size + 7 + assert len(first_fragment.hl_packet.data) == first_frag_size - 4 + + # Check the middle fragments + for middle_fragment in fragments[1:-1]: + assert middle_fragment.ll_header.flags == 0 + assert middle_fragment.ll_header.size == ZBNCP_LL_BODY_SIZE_MAX + 7 + assert len(middle_fragment.hl_packet.data) == ZBNCP_LL_BODY_SIZE_MAX + + # Check the last fragment + last_fragment = fragments[-1] + last_frag_size = ( + len(serialized_hl_packet) - + (first_frag_size + (total_fragments - 2) * ZBNCP_LL_BODY_SIZE_MAX) + ) + assert last_fragment.ll_header.flags == t.LLFlags.LastFrag + assert last_fragment.ll_header.size == last_frag_size + 7 + assert len(last_fragment.hl_packet.data) == last_frag_size + + +def test_handle_tx_fragmentation_edge_cases(): + """Test the handle_tx_fragmentation method for various edge cases.""" + # Data size exactly equal to ZBNCP_LL_BODY_SIZE_MAX + exact_size_data = b"a" * (ZBNCP_LL_BODY_SIZE_MAX - 2 - 2) + hl_header = t.HLCommonHeader(version=0x01, type=t.ControlType.RSP, + id=0x1234) + hl_packet = HLPacket(header=hl_header, data=t.Bytes(exact_size_data)) + frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet) + + # Perform fragmentation + fragments = frame.handle_tx_fragmentation() + assert len(fragments) == 1 # Should not fragment + + # Test with data size just above ZBNCP_LL_BODY_SIZE_MAX + just_above_size_data = b"a" * (ZBNCP_LL_BODY_SIZE_MAX + 1 - 2 - 2) + hl_packet = HLPacket(header=hl_header, data=t.Bytes(just_above_size_data)) + frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet) + fragments = frame.handle_tx_fragmentation() + assert len(fragments) == 2 # Should fragment into two + + # Test with data size much larger than ZBNCP_LL_BODY_SIZE_MAX + large_data = b"a" * ((ZBNCP_LL_BODY_SIZE_MAX * 5) + 50 - 2 - 2) + hl_packet = HLPacket(header=hl_header, data=t.Bytes(large_data)) + frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet) + fragments = frame.handle_tx_fragmentation() + assert len(fragments) == 6 # 5 full fragments and 1 partial fragment + + # Test with very small data + small_data = b"a" * 10 + hl_packet = HLPacket(header=hl_header, data=t.Bytes(small_data)) + frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet) + fragments = frame.handle_tx_fragmentation() + assert len(fragments) == 1 # Should not fragment + + +def test_handle_rx_fragmentation(): + """Test the handle_rx_fragmentation method for. + + proper reassembly of fragments. + """ + # Create an HLCommonHeader with specific fields + hl_header = t.HLCommonHeader( + version=0x01, type=t.ControlType.RSP, id=0x1234 + ) + large_data = b"a" * (ZBNCP_LL_BODY_SIZE_MAX * 2 + 50) + hl_data = t.Bytes(large_data) + + # Create an HLPacket with the large data + hl_packet = HLPacket(header=hl_header, data=hl_data) + frame = Frame(ll_header=LLHeader(), hl_packet=hl_packet) + + # Perform fragmentation + fragments = frame.handle_tx_fragmentation() + + # Verify that the correct number of fragments was created + total_fragments = frame.count_fragments() + assert len(fragments) == total_fragments + + # Reassemble the fragments using handle_rx_fragmentation + reassembled_frame = Frame.handle_rx_fragmentation(fragments) + + # Verify the reassembled frame + assert ( + reassembled_frame.ll_header.frame_type == t.TYPE_ZBOSS_NCP_API_HL + ) + assert ( + reassembled_frame.ll_header.flags == + (t.LLFlags.FirstFrag | t.LLFlags.LastFrag) + ) + + # Verify the reassembled data matches the original data + reassembled_data = reassembled_frame.hl_packet.data + assert reassembled_data == large_data diff --git a/tests/test_nvids.py b/tests/test_nvids.py new file mode 100644 index 0000000..015b576 --- /dev/null +++ b/tests/test_nvids.py @@ -0,0 +1,135 @@ +"""Test NVIDS.""" +from struct import pack + +import zigpy_zboss.types as t +from zigpy_zboss.types import nvids +from zigpy_zboss.types.nvids import (ApsSecureEntry, DSApsSecureKeys, + DSNwkAddrMap, NwkAddrMapHeader, + NwkAddrMapRecord) + + +def test_nv_ram_get_byte_size(): + """Test the get_byte_size method of the NVRAMStruct class.""" + class TestStruct(nvids.NVRAMStruct): + a: t.uint8_t + b: t.EUI64 + c: t.uint8_t + + data = TestStruct(a=1, b=[2], c=3) + + byte_size = data.get_byte_size() + + assert byte_size == 10, f"Expected byte size to be 10, but got {byte_size}" + + +def test_dsapssecurekeys(): + """Test the serialize/deserialize method of the DSApsSecureKeys class.""" + ieee_addr1 = t.EUI64([0, 1, 2, 3, 4, 5, 6, 7]) + key1 = t.KeyData([0x10] * 16) + unknown_1_1 = t.basic.uint32_t(12345678) + entry1 = ApsSecureEntry( + ieee_addr=ieee_addr1, key=key1, _unknown_1=unknown_1_1 + ) + entry_data1 = entry1.serialize() + + ieee_addr2 = t.EUI64([8, 9, 10, 11, 12, 13, 14, 15]) + key2 = t.KeyData([0x20] * 16) + unknown_1_2 = t.basic.uint32_t(87654321) + entry2 = ApsSecureEntry( + ieee_addr=ieee_addr2, key=key2, _unknown_1=unknown_1_2 + ) + entry_data2 = entry2.serialize() + + # Calculate total length for the LVList + entry_size = ApsSecureEntry.get_byte_size() + total_length = (entry_size * 2) + 4 + + length_bytes = pack(" int: + return self.method() + + CONSTANT1 = 1 + constant2 = "foo" + _constant3 = "bar" + + assert len(TestStruct.fields) == 2 + assert TestStruct.fields.a == t.CStructField(name="a", type=t.uint8_t) + assert TestStruct.fields.b == t.CStructField(name="b", type=t.uint16_t) + + assert TestStruct.CONSTANT1 == 1 + assert TestStruct.constant2 == "foo" + assert TestStruct._constant3 == "bar" + + assert TestStruct(a=1, b=2).method() == 3 + + +def test_struct_nesting(): + """Test struct nesting.""" + class Outer(t.CStruct): + e: t.uint32_t + + class TestStruct(t.CStruct): + class Inner(t.CStruct): + c: t.uint16_t + + a: t.uint8_t + b: Inner + d: Outer + + assert len(TestStruct.fields) == 3 + assert TestStruct.fields.a == t.CStructField(name="a", type=t.uint8_t) + assert TestStruct.fields.b == t.CStructField( + name="b", type=TestStruct.Inner + ) + assert TestStruct.fields.d == t.CStructField(name="d", type=Outer) + + assert len(TestStruct.Inner.fields) == 1 + assert TestStruct.Inner.fields.c == t.CStructField( + name="c", type=t.uint16_t + ) + + struct = TestStruct(a=1, b=TestStruct.Inner(c=2), d=Outer(e=3)) + assert struct.a == 1 + assert struct.b.c == 2 + assert struct.d.e == 3 + + +def test_struct_aligned_serialization_deserialization(): + """Test struct aligned serialization/deserialization.""" + class TestStruct(t.CStruct): + a: t.uint8_t + # One padding byte here + b: t.uint16_t + # No padding here + c: t.uint32_t # largest type, so the struct is 32 bit aligned + d: t.uint8_t + # Three padding bytes here + e: t.uint32_t + f: t.uint8_t + # Three more to make the struct 32 bit aligned + + assert TestStruct.get_alignment(align=False) == 1 + assert TestStruct.get_alignment(align=True) == 32 // 8 + assert TestStruct.get_size(align=False) == (1 + 2 + 4 + 1 + 4 + 1) + assert TestStruct.get_size(align=True) == ( + 1 + 2 + 4 + 1 + 4 + 1 + ) + (1 + 3 + 3) + + expected = b"" + expected += t.uint8_t(1).serialize() + expected += b"\xFF" + t.uint16_t(2).serialize() + expected += t.uint32_t(3).serialize() + expected += t.uint8_t(4).serialize() + expected += b"\xFF\xFF\xFF" + t.uint32_t(5).serialize() + expected += t.uint8_t(6).serialize() + expected += b"\xFF\xFF\xFF" + + struct = TestStruct(a=1, b=2, c=3, d=4, e=5, f=6) + assert struct.serialize(align=True) == expected + + struct2, remaining = TestStruct.deserialize(expected + b"test", align=True) + assert remaining == b"test" + assert struct == struct2 + + with pytest.raises(ValueError): + TestStruct.deserialize(expected[:-1], align=True) + + +def test_struct_aligned_nested_serialization_deserialization(): + """Test structed alined nested serialization/deserialization.""" + class Inner(t.CStruct): + _padding_byte = b"\xCD" + + c: t.uint8_t + d: t.uint32_t + e: t.uint8_t + + class TestStruct(t.CStruct): + _padding_byte = b"\xAB" + + a: t.uint8_t + b: Inner + f: t.uint16_t + + expected = b"" + expected += t.uint8_t(1).serialize() + + # Inner struct + expected += b"\xAB\xAB\xAB" + t.uint8_t(2).serialize() + expected += b"\xCD\xCD\xCD" + t.uint32_t(3).serialize() + expected += t.uint8_t(4).serialize() + expected += b"\xCD\xCD\xCD" # Aligned to 4 bytes + + expected += t.uint16_t(5).serialize() + expected += b"\xAB\xAB" # Also aligned to 4 bytes due to inner struct + + struct = TestStruct(a=1, b=Inner(c=2, d=3, e=4), f=5) + assert struct.serialize(align=True) == expected + + struct2, remaining = TestStruct.deserialize(expected + b"test", align=True) + assert remaining == b"test" + assert struct == struct2 + + +def test_struct_unaligned_serialization_deserialization(): + """Test struct unaligned serialization/deserialization.""" + class TestStruct(t.CStruct): + a: t.uint8_t + b: t.uint16_t + c: t.uint32_t + d: t.uint8_t + e: t.uint32_t + f: t.uint8_t + + expected = b"" + expected += t.uint8_t(1).serialize() + expected += t.uint16_t(2).serialize() + expected += t.uint32_t(3).serialize() + expected += t.uint8_t(4).serialize() + expected += t.uint32_t(5).serialize() + expected += t.uint8_t(6).serialize() + + struct = TestStruct(a=1, b=2, c=3, d=4, e=5, f=6) + + assert struct.serialize(align=False) == expected + + struct2, remaining = TestStruct.deserialize( + expected + b"test", align=False + ) + assert remaining == b"test" + assert struct == struct2 + + with pytest.raises(ValueError): + TestStruct.deserialize(expected[:-1], align=False) + + +def test_struct_equality(): + """Test struct equality.""" + class InnerStruct(t.CStruct): + c: t.EUI64 + + class TestStruct(t.CStruct): + a: t.uint8_t + b: InnerStruct + + class TestStruct2(t.CStruct): + a: t.uint8_t + b: InnerStruct + + s1 = TestStruct( + a=2, b=InnerStruct(c=t.EUI64.convert("00:00:00:00:00:00:00:00")) + ) + s2 = TestStruct( + a=2, b=InnerStruct(c=t.EUI64.convert("00:00:00:00:00:00:00:00")) + ) + s3 = TestStruct2( + a=2, b=InnerStruct(c=t.EUI64.convert("00:00:00:00:00:00:00:00")) + ) + + assert s1 == s2 + assert s1.replace(a=3) != s1 + assert s1.replace(a=3).replace(a=2) == s1 + + assert s1 != s3 + assert s1.serialize() == s3.serialize() + + assert TestStruct(s1) == s1 + assert TestStruct(a=s1.a, b=s1.b) == s1 + + with pytest.raises(ValueError): + TestStruct(s1, b=InnerStruct(s1.b)) + + with pytest.raises(ValueError): + TestStruct2(s1) + + +def test_struct_repr(): + """Test struct representation.""" + class TestStruct(t.CStruct): + a: t.uint8_t + b: t.uint32_t + + assert str(TestStruct(a=1, b=2)) == "TestStruct(a=1, b=2)" + assert str([TestStruct(a=1, b=2)]) == "[TestStruct(a=1, b=2)]" + + +def test_struct_bad_fields(): + """Test struct bad fields.""" + with pytest.raises(TypeError): + + class TestStruct(t.CStruct): + a: t.uint8_t + b: int + + +def test_struct_incomplete_serialization(): + """Test struct incomplete serialization.""" + class TestStruct(t.CStruct): + a: t.uint8_t + b: t.uint8_t + + TestStruct(a=1, b=2).serialize() + + with pytest.raises(ValueError): + TestStruct(a=1, b=None).serialize() + + with pytest.raises(ValueError): + TestStruct(a=1).serialize() + + struct = TestStruct(a=1, b=2) + struct.b = object() + + with pytest.raises(ValueError): + struct.serialize() diff --git a/tests/test_types_named.py b/tests/test_types_named.py new file mode 100644 index 0000000..d0e83d5 --- /dev/null +++ b/tests/test_types_named.py @@ -0,0 +1,34 @@ +"""Test named types.""" +import zigpy_zboss.types as t + + +def test_channel_entry(): + """Test channel entry. + + ChannelEntry class for proper serialization, + deserialization, equality, and representation. + """ + # Sample data for testing + page_data = b"\x01" # Page number as bytes + channel_mask_data = b"\x00\x10\x00\x00" # Sample channel mask as bytes + + data = page_data + channel_mask_data + + # Test deserialization + channel_entry, remaining_data = t.ChannelEntry.deserialize(data) + assert remaining_data == b'' # no extra data should remain + assert channel_entry.page == 1 + assert channel_entry.channel_mask == 0x00001000 + + # Test serialization + assert channel_entry.serialize() == data + + # Test equality + another_entry = t.ChannelEntry(page=1, channel_mask=0x00001000) + assert channel_entry == another_entry + assert channel_entry != t.ChannelEntry(page=0, channel_mask=0x00002000) + + # Test __repr__ + expected_repr = \ + "ChannelEntry(page=1, channel_mask=)" + assert repr(channel_entry) == expected_repr diff --git a/tests/test_uart.py b/tests/test_uart.py new file mode 100644 index 0000000..e151774 --- /dev/null +++ b/tests/test_uart.py @@ -0,0 +1,293 @@ +"""Test uart.""" +import pytest +from serial_asyncio import SerialTransport + +import zigpy_zboss.commands as c +import zigpy_zboss.config as conf +import zigpy_zboss.types as t +from zigpy_zboss import uart as zboss_uart +from zigpy_zboss.checksum import CRC8 +from zigpy_zboss.frames import Frame + + +@pytest.fixture +def connected_uart(mocker): + """Uart connected fixture.""" + zboss = mocker.Mock() + config = { + conf.CONF_DEVICE_PATH: "/dev/ttyFAKE0", + conf.CONF_DEVICE_BAUDRATE: 115200, + conf.CONF_DEVICE_FLOW_CONTROL: None} + + uart = zboss_uart.ZbossNcpProtocol(config, zboss) + uart.connection_made(mocker.Mock()) + + yield zboss, uart + + +def ll_checksum(frame): + """Return frame with new crc8 checksum calculation.""" + crc = CRC8(frame.ll_header.serialize()[2:6]).digest() + frame.ll_header = frame.ll_header.with_crc8(crc) + return frame + + +@pytest.fixture +def dummy_serial_conn(event_loop, mocker): + """Connect serial dummy.""" + device = "/dev/ttyACM0" + + serial_interface = mocker.Mock() + serial_interface.name = device + + def create_serial_conn(loop, protocol_factory, url, *args, **kwargs): + fut = event_loop.create_future() + assert url == device + + protocol = protocol_factory() + + # Our event loop doesn't really do anything + event_loop.add_writer = lambda *args, **kwargs: None + event_loop.add_reader = lambda *args, **kwargs: None + event_loop.remove_writer = lambda *args, **kwargs: None + event_loop.remove_reader = lambda *args, **kwargs: None + + transport = SerialTransport(event_loop, protocol, serial_interface) + + protocol.connection_made(transport) + + fut.set_result((transport, protocol)) + + return fut + + mocker.patch( + "zigpy.serial.pyserial_asyncio.create_serial_connection", + new=create_serial_conn + ) + + return device, serial_interface + + +def test_uart_rx_basic(connected_uart): + """Test UART basic receive.""" + zboss, uart = connected_uart + + test_command = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + test_frame = test_command.to_frame() + test_frame = ll_checksum(test_frame) + test_frame_bytes = Frame( + test_frame.ll_header, test_frame.hl_packet + ).serialize() + + uart.data_received(test_frame_bytes) + + zboss.frame_received.assert_called_once_with(test_frame) + + +def test_uart_str_repr(connected_uart): + """Test uart representation.""" + zboss, uart = connected_uart + + str(uart) + repr(uart) + + +def test_uart_rx_byte_by_byte(connected_uart): + """Test uart RX byte by byte.""" + zboss, uart = connected_uart + + test_command = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + test_frame = test_command.to_frame() + test_frame = ll_checksum(test_frame) + test_frame_bytes = Frame( + test_frame.ll_header, test_frame.hl_packet + ).serialize() + + for byte in test_frame_bytes: + uart.data_received(bytes([byte])) + + zboss.frame_received.assert_called_once_with(test_frame) + + +def test_uart_rx_byte_by_byte_garbage(connected_uart): + """Test uart RX byte by byte garbage.""" + zboss, uart = connected_uart + + test_command = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + test_frame = test_command.to_frame() + test_frame = ll_checksum(test_frame) + test_frame_bytes = Frame( + test_frame.ll_header, test_frame.hl_packet + ).serialize() + + data = b"" + data += bytes.fromhex("58 4a 72 35 51 da 60 ed 1f") + data += bytes.fromhex("03 6d b6") + data += bytes.fromhex("ee 90") + data += test_frame_bytes + data += bytes.fromhex("00 00") + data += bytes.fromhex("e4 4f 51 b2 39 4b 8d e3 ca 61") + data += bytes.fromhex("8c 56 8a 2c d8 22 64 9e 9d 7b") + + # The frame should be parsed identically regardless of framing + for byte in data: + uart.data_received(bytes([byte])) + + zboss.frame_received.assert_called_once_with(test_frame) + + +def test_uart_rx_big_garbage(connected_uart): + """Test uart RX big garbage.""" + zboss, uart = connected_uart + + test_command = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + test_frame = test_command.to_frame() + test_frame = ll_checksum(test_frame) + test_frame_bytes = Frame( + test_frame.ll_header, test_frame.hl_packet + ).serialize() + + data = b"" + data += bytes.fromhex("58 4a 72 35 51 da 60 ed 1f") + data += bytes.fromhex("03 6d b6") + data += bytes.fromhex("ee 90") + data += test_frame_bytes + data += bytes.fromhex("00 00") + data += bytes.fromhex("e4 4f 51 b2 39 4b 8d e3 ca 61") + data += bytes.fromhex("8c 56 8a 2c d8 22 64 9e 9d 7b") + + # The frame should be parsed identically regardless of framing + uart.data_received(data) + + zboss.frame_received.assert_called_once_with(test_frame) + + +def test_uart_rx_corrupted_fcs(connected_uart): + """Test uart RX corrupted.""" + zboss, uart = connected_uart + + test_command = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + test_frame = test_command.to_frame() + test_frame = ll_checksum(test_frame) + test_frame_bytes = Frame( + test_frame.ll_header, test_frame.hl_packet + ).serialize() + + # Almost, but not quite + uart.data_received(test_frame_bytes[:-1]) + uart.data_received(b"\x00") + + assert not zboss.frame_received.called + + +def test_uart_rx_sof_stress(connected_uart): + """Test uart RX signature stress.""" + zboss, uart = connected_uart + + test_command = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + test_frame = test_command.to_frame() + test_frame = ll_checksum(test_frame) + test_frame_bytes = Frame( + test_frame.ll_header, test_frame.hl_packet + ).serialize() + + # We include an almost-valid frame and many stray SoF markers + uart.data_received( + b"\xFE" + b"\xFE" + b"\xFE" + test_frame_bytes[:-1] + b"\x00" + ) + uart.data_received(b"\xFE\xFE\x00\xFE\x01") + uart.data_received( + b"\xFE" + b"\xFE" + b"\xFE" + test_frame_bytes + b"\x00\x00" + ) + + # We should see the valid frame exactly once + zboss.frame_received.assert_called_once_with(test_frame) + + +def test_uart_frame_received_error(connected_uart, mocker): + """Test uart frame received error.""" + zboss, uart = connected_uart + zboss.frame_received = mocker.Mock(side_effect=RuntimeError("An error")) + + test_command = c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ) + test_frame = test_command.to_frame() + test_frame = ll_checksum(test_frame) + test_frame_bytes = Frame( + test_frame.ll_header, test_frame.hl_packet + ).serialize() + + # Errors thrown by zboss.frame_received should + # not impact how many frames are handled + uart.data_received(test_frame_bytes * 3) + + # We should have received all three frames + assert zboss.frame_received.call_count == 3 + + +@pytest.mark.asyncio +async def test_connection_lost(dummy_serial_conn, mocker, event_loop): + """Test connection lost.""" + device, _ = dummy_serial_conn + + zboss = mocker.Mock() + conn_lost_fut = event_loop.create_future() + zboss.connection_lost = conn_lost_fut.set_result + + protocol = await zboss_uart.connect( + conf.SCHEMA_DEVICE({conf.CONF_DEVICE_PATH: device}), api=zboss + ) + + exception = RuntimeError("Uh oh, something broke") + protocol.connection_lost(exception) + + # Losing a connection propagates up to the ZBOSS object + assert (await conn_lost_fut) == exception + + +# ToFix: this is not testing the uart test_connection_made method +# @pytest.mark.asyncio +# async def test_connection_made(dummy_serial_conn, mocker): +# """Test connection made.""" +# device, _ = dummy_serial_conn +# zboss = mocker.Mock() + +# await zboss_uart.connect( +# conf.SCHEMA_DEVICE({conf.CONF_DEVICE_PATH: device}), api=zboss +# ) + +# zboss._uart.connection_made.assert_called_once_with() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..daade53 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,107 @@ +"""Test utils.""" +import zigpy_zboss.commands as c +import zigpy_zboss.types as t +from zigpy_zboss.utils import deduplicate_commands + + +def test_command_deduplication_simple(): + """Test command deduplication simple.""" + c1 = c.NcpConfig.GetModuleVersion.Req(TSN=10) + c2 = c.NcpConfig.NCPModuleReset.Req(TSN=10, Option=t.ResetOptions(0)) + + assert deduplicate_commands([]) == () + assert deduplicate_commands([c1]) == (c1,) + assert deduplicate_commands([c1, c1]) == (c1,) + assert deduplicate_commands([c1, c2]) == (c1, c2) + assert deduplicate_commands([c2, c1, c2]) == (c2, c1) + + +def test_command_deduplication_complex(): + """Test command deduplication complex.""" + result = deduplicate_commands( + [ + c.NcpConfig.GetModuleVersion.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + FWVersion=1, + StackVersion=2, + ProtocolVersion=3, + ), + # Duplicating matching commands shouldn't do anything + c.NcpConfig.GetModuleVersion.Rsp(partial=True), + c.NcpConfig.GetModuleVersion.Rsp(partial=True), + # Matching against different command types should also work + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ), + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=11, + StatusCat=t.StatusCategory(2), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(2) + ), + c.NcpConfig.GetNwkKeys.Rsp( + partial=True, + TSN=11, + StatusCat=t.StatusCategory(2), + StatusCode=t.StatusCodeGeneric.OK, + KeyNumber1=10, + ), + c.NcpConfig.GetNwkKeys.Rsp( + partial=True, + TSN=11, + StatusCat=t.StatusCategory(2), + StatusCode=t.StatusCodeGeneric.OK, + KeyNumber1=10, + KeyNumber2=20, + ), + c.NcpConfig.GetNwkKeys.Rsp( + partial=True, + TSN=11, + StatusCat=t.StatusCategory(2), + StatusCode=t.StatusCodeGeneric.OK, + KeyNumber1=10, + KeyNumber2=20, + KeyNumber3=30, + ), + c.NcpConfig.GetNwkKeys.Rsp( + partial=True, + TSN=11, + StatusCat=t.StatusCategory(2), + KeyNumber3=30, + ), + ] + ) + + assert set(result) == { + c.NcpConfig.GetModuleVersion.Rsp(partial=True), + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=10, + StatusCat=t.StatusCategory(1), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(1) + ), + c.NcpConfig.GetZigbeeRole.Rsp( + TSN=11, + StatusCat=t.StatusCategory(2), + StatusCode=t.StatusCodeGeneric.OK, + DeviceRole=t.DeviceRole(2) + ), + c.NcpConfig.GetNwkKeys.Rsp( + partial=True, + TSN=11, + StatusCat=t.StatusCategory(2), + StatusCode=t.StatusCodeGeneric.OK, + KeyNumber1=10, + ), + c.NcpConfig.GetNwkKeys.Rsp( + partial=True, + TSN=11, + StatusCat=t.StatusCategory(2), + KeyNumber3=30, + ), + } diff --git a/zigpy_zboss/api.py b/zigpy_zboss/api.py index 12b57ee..c01ab15 100644 --- a/zigpy_zboss/api.py +++ b/zigpy_zboss/api.py @@ -2,22 +2,22 @@ from __future__ import annotations import asyncio -import logging -import itertools import contextlib -import zigpy.state +import itertools +import logging +from collections import Counter, defaultdict + import async_timeout -import zigpy_zboss.types as t -import zigpy_zboss.config as conf +import zigpy.state +import zigpy_zboss.config as conf +import zigpy_zboss.types as t +from zigpy_zboss import commands as c from zigpy_zboss import uart from zigpy_zboss.frames import Frame -from zigpy_zboss import commands as c from zigpy_zboss.nvram import NVRAMHelper -from collections import Counter, defaultdict -from zigpy_zboss.utils import IndicationListener -from zigpy_zboss.utils import BaseResponseListener -from zigpy_zboss.utils import OneShotResponseListener +from zigpy_zboss.utils import (BaseResponseListener, IndicationListener, + OneShotResponseListener) LOGGER = logging.getLogger(__name__) LISTENER_LOGGER = LOGGER.getChild("listener") @@ -49,8 +49,6 @@ def __init__(self, config: conf.ConfigType): self._listeners = defaultdict(list) self._blocking_request_lock = asyncio.Lock() - self.capabilities = None - self.nvram = NVRAMHelper(self) self.network_info: zigpy.state.NetworkInformation = None self.node_info: zigpy.state.NodeInfo = None @@ -118,6 +116,11 @@ def close(self) -> None: self._app = None self.version = None + for _, listeners in self._listeners.items(): + for listener in listeners: + listener.cancel() + self._listeners.clear() + if self._uart is not None: self._uart.close() self._uart = None @@ -187,6 +190,10 @@ async def request( raise ValueError( f"Cannot send a command that isn't a request: {request!r}") + if self._uart is None: + raise RuntimeError( + "Coordinator is disconnected, cannot send request") + LOGGER.debug("Sending request: %s", request) frame = request.to_frame() @@ -315,14 +322,11 @@ def register_indication_listener( async def version(self): """Get NCP module version.""" - if self._app is not None: - tsn = self._app.get_sequence() - else: - tsn = 0 + tsn = self._app.get_sequence() if self._app is not None else 0 req = c.NcpConfig.GetModuleVersion.Req(TSN=tsn) res = await self.request(req) if res.StatusCode: - return + return None version = ['', '', ''] for idx, ver in enumerate( [res.FWVersion, res.StackVersion, res.ProtocolVersion]): diff --git a/zigpy_zboss/commands/__init__.py b/zigpy_zboss/commands/__init__.py index 3e15e99..7680cbc 100644 --- a/zigpy_zboss/commands/__init__.py +++ b/zigpy_zboss/commands/__init__.py @@ -1,10 +1,10 @@ """Module importing all the commands.""" from .af import AF from .aps import APS -from .zdo import ZDO -from .security import SEC -from .nwk_mgmt import NWK from .ncp_config import NcpConfig +from .nwk_mgmt import NWK +from .security import SEC +from .zdo import ZDO ALL_COMMANDS = [ AF, diff --git a/zigpy_zboss/commands/aps.py b/zigpy_zboss/commands/aps.py index 7dfcf9e..707d59d 100644 --- a/zigpy_zboss/commands/aps.py +++ b/zigpy_zboss/commands/aps.py @@ -1,5 +1,6 @@ """Module defining all APS commands.""" import zigpy.types + import zigpy_zboss.types as t diff --git a/zigpy_zboss/commands/zdo.py b/zigpy_zboss/commands/zdo.py index a96b40e..8c0ebc9 100644 --- a/zigpy_zboss/commands/zdo.py +++ b/zigpy_zboss/commands/zdo.py @@ -1,9 +1,9 @@ """Module defining all ZDO commands.""" from __future__ import annotations -from zigpy.zdo import types as zdo_t import zigpy.types import zigpy.zdo.types +from zigpy.zdo import types as zdo_t import zigpy_zboss.types as t diff --git a/zigpy_zboss/config.py b/zigpy_zboss/config.py index 662a222..4b71ec5 100644 --- a/zigpy_zboss/config.py +++ b/zigpy_zboss/config.py @@ -1,25 +1,15 @@ """Module responsible for configuration.""" -import typing import numbers +import typing import voluptuous as vol -from zigpy.config import ( # noqa: F401 - CONF_NWK, - CONF_DEVICE, - CONF_NWK_KEY, - CONFIG_SCHEMA, - SCHEMA_DEVICE, - CONF_NWK_PAN_ID, - CONF_NWK_CHANNEL, - CONF_DEVICE_PATH, - CONF_NWK_KEY_SEQ, - CONF_NWK_CHANNELS, - CONF_NWK_UPDATE_ID, - CONF_NWK_TC_ADDRESS, - CONF_NWK_TC_LINK_KEY, - CONF_NWK_EXTENDED_PAN_ID, - cv_boolean, -) +from zigpy.config import (CONF_DEVICE, CONF_DEVICE_PATH, # noqa: F401 + CONF_NWK, CONF_NWK_CHANNEL, CONF_NWK_CHANNELS, + CONF_NWK_EXTENDED_PAN_ID, CONF_NWK_KEY, + CONF_NWK_KEY_SEQ, CONF_NWK_PAN_ID, + CONF_NWK_TC_ADDRESS, CONF_NWK_TC_LINK_KEY, + CONF_NWK_UPDATE_ID, CONFIG_SCHEMA, SCHEMA_DEVICE, + cv_boolean) LOG_FILE_NAME = "zigpy-zboss.log" SERIAL_LOG_FILE_NAME = "serial-zigpy-zboss.log" diff --git a/zigpy_zboss/debug.py b/zigpy_zboss/debug.py index 450a3e6..4ef2214 100644 --- a/zigpy_zboss/debug.py +++ b/zigpy_zboss/debug.py @@ -1,10 +1,11 @@ """Module setting up a debugging serial connection with the NCP.""" -import serial import asyncio import logging +import logging.handlers + import async_timeout +import serial import serial_asyncio -import logging.handlers from zigpy_zboss import types as t diff --git a/zigpy_zboss/frames.py b/zigpy_zboss/frames.py index 14229d8..56ad739 100644 --- a/zigpy_zboss/frames.py +++ b/zigpy_zboss/frames.py @@ -4,10 +4,8 @@ import dataclasses import zigpy_zboss.types as t +from zigpy_zboss.checksum import CRC8, CRC16 from zigpy_zboss.exceptions import InvalidFrame -from zigpy_zboss.checksum import CRC8 -from zigpy_zboss.checksum import CRC16 - ZBNCP_LL_BODY_SIZE_MAX = 247 # Check zbncp_ll_pkt.h in ZBOSS NCP host src diff --git a/zigpy_zboss/nvram.py b/zigpy_zboss/nvram.py index 66fe19d..234fa42 100644 --- a/zigpy_zboss/nvram.py +++ b/zigpy_zboss/nvram.py @@ -1,8 +1,8 @@ """NCP NVRAM related helpers.""" import logging -import zigpy_zboss.types as t import zigpy_zboss.commands as c +import zigpy_zboss.types as t LOGGER = logging.getLogger(__name__) WRITE_DS_LENGTH = 280 @@ -24,7 +24,7 @@ async def read(self, nv_id: t.DatasetId, item_type): ) ) if res.StatusCode != 0: - return + return None if not res.DatasetId == nv_id: raise diff --git a/zigpy_zboss/tools/factory_reset_ncp.py b/zigpy_zboss/tools/factory_reset_ncp.py index be98a2f..9f4b6bd 100644 --- a/zigpy_zboss/tools/factory_reset_ncp.py +++ b/zigpy_zboss/tools/factory_reset_ncp.py @@ -1,11 +1,11 @@ """Script to factory reset the coordinator.""" +import asyncio import sys + import serial -import asyncio -from zigpy_zboss.api import ZBOSS from zigpy_zboss import types as t - +from zigpy_zboss.api import ZBOSS from zigpy_zboss.tools.config import get_config diff --git a/zigpy_zboss/tools/get_ncp_version.py b/zigpy_zboss/tools/get_ncp_version.py index 594890a..bb9ec0a 100644 --- a/zigpy_zboss/tools/get_ncp_version.py +++ b/zigpy_zboss/tools/get_ncp_version.py @@ -1,9 +1,9 @@ """Script to print the NCP firmware version.""" -import serial import asyncio -from zigpy_zboss.api import ZBOSS +import serial +from zigpy_zboss.api import ZBOSS from zigpy_zboss.tools.config import get_config diff --git a/zigpy_zboss/types/__init__.py b/zigpy_zboss/types/__init__.py index 210722b..a78217f 100644 --- a/zigpy_zboss/types/__init__.py +++ b/zigpy_zboss/types/__init__.py @@ -1,7 +1,7 @@ """Module importing all types.""" from .basic import * # noqa: F401, F403 +from .commands import * # noqa: F401, F403 +from .cstruct import * # noqa: F401, F403 from .named import * # noqa: F401, F403 from .nvids import * # noqa: F401, F403 -from .cstruct import * # noqa: F401, F403 from .structs import * # noqa: F401, F403 -from .commands import * # noqa: F401, F403 diff --git a/zigpy_zboss/types/basic.py b/zigpy_zboss/types/basic.py index 2eca2d7..d8f270a 100644 --- a/zigpy_zboss/types/basic.py +++ b/zigpy_zboss/types/basic.py @@ -1,8 +1,9 @@ """Module defining basic types.""" from __future__ import annotations + import typing -from zigpy.types import int8s, uint8_t, enum_factory # noqa: F401 +from zigpy.types import enum_factory, int8s, uint8_t # noqa: F401 from zigpy_zboss.types.cstruct import CStruct @@ -31,18 +32,9 @@ class bitmap16(enum.IntFlag): """Bitmap with 16 bits value.""" else: - from zigpy.types import ( # noqa: F401 - enum8, - enum16, - bitmap8, - bitmap16, - uint16_t, - uint24_t, - uint32_t, - uint40_t, - uint56_t, - uint64_t, - ) + from zigpy.types import (bitmap8, bitmap16, enum8, enum16, # noqa: F401 + uint16_t, uint24_t, uint32_t, uint40_t, uint56_t, + uint64_t) class enum24(enum_factory(uint24_t)): """Enum with 24 bits value.""" diff --git a/zigpy_zboss/types/commands.py b/zigpy_zboss/types/commands.py index 7406afc..8ec8b9a 100644 --- a/zigpy_zboss/types/commands.py +++ b/zigpy_zboss/types/commands.py @@ -1,12 +1,14 @@ """Module defining types used for commands.""" from __future__ import annotations + +import dataclasses import enum import logging -import dataclasses import zigpy.zdo.types -import zigpy_zboss.types as t +import zigpy_zboss.types.basic as t +import zigpy_zboss.types.named as t_named LOGGER = logging.getLogger(__name__) TYPE_ZBOSS_NCP_API_HL = t.uint8_t(0x06) @@ -421,7 +423,8 @@ def __init__(self, *, partial=False, **params): issubclass(param.type, (t.ShortBytes, t.LongBytes)), isinstance(value, list) and issubclass(param.type, list), - isinstance(value, bool) and issubclass(param.type, t.Bool), + isinstance( + value, bool) and issubclass(param.type, t_named.Bool), ] # fmt: on @@ -455,9 +458,7 @@ def to_frame(self, *, align=False): if self._partial: raise ValueError(f"Cannot serialize a partial frame: {self}") - from zigpy_zboss.frames import HLPacket - from zigpy_zboss.frames import LLHeader - from zigpy_zboss.frames import Frame + from zigpy_zboss.frames import Frame, HLPacket, LLHeader chunks = [] @@ -524,6 +525,11 @@ def from_frame(cls, frame, *, align=False) -> "CommandBase": else: # Otherwise, let the exception happen raise + if data: + raise ValueError( + f"Frame {frame} contains trailing data after parsing: {data}" + ) + return cls(**params) def matches(self, other: "CommandBase") -> bool: @@ -715,7 +721,8 @@ class Relationship(t.enum8): STATUS_SCHEMA = ( - t.Param("TSN", t.uint8_t, "Transmit Sequence Number"), - t.Param("StatusCat", StatusCategory, "Status category code"), - t.Param("StatusCode", StatusCodeGeneric, "Status code inside category"), + t_named.Param("TSN", t.uint8_t, "Transmit Sequence Number"), + t_named.Param("StatusCat", StatusCategory, "Status category code"), + t_named.Param( + "StatusCode", StatusCodeGeneric, "Status code inside category"), ) diff --git a/zigpy_zboss/types/cstruct.py b/zigpy_zboss/types/cstruct.py index fb306e7..fa560b0 100644 --- a/zigpy_zboss/types/cstruct.py +++ b/zigpy_zboss/types/cstruct.py @@ -1,9 +1,9 @@ """Module defining cstruct types.""" from __future__ import annotations -import typing -import inspect import dataclasses +import inspect +import typing import zigpy.types as zigpy_t diff --git a/zigpy_zboss/types/named.py b/zigpy_zboss/types/named.py index f6adf6a..4e50415 100644 --- a/zigpy_zboss/types/named.py +++ b/zigpy_zboss/types/named.py @@ -1,24 +1,13 @@ """Module defining named types.""" from __future__ import annotations -import typing -import logging import dataclasses +import logging +import typing -from zigpy.types import ( # noqa: F401 - NWK, - List, - Bool, - PanId, - EUI64, - Struct, - bitmap8, - KeyData, - Channels, - ClusterId, - ExtendedPanId, - CharacterString, -) +from zigpy.types import (EUI64, NWK, Bool, Channels, # noqa: F401 + CharacterString, ClusterId, ExtendedPanId, KeyData, + List, PanId, Struct, bitmap8) from . import basic diff --git a/zigpy_zboss/types/nvids.py b/zigpy_zboss/types/nvids.py index 492b6e3..9abd972 100644 --- a/zigpy_zboss/types/nvids.py +++ b/zigpy_zboss/types/nvids.py @@ -1,7 +1,10 @@ """Module defining zboss nvram types.""" from __future__ import annotations + import zigpy.types as t + import zigpy_zboss.types as zboss_t + from . import basic @@ -105,7 +108,7 @@ def serialize(self, *, align=False) -> bytes: header = self._header( byte_count=byte_count, entry_count=len(self), - version=t.uint8_t(0x02), + version=self.version, _align=t.uint16_t(0x0000), ) return header.serialize() + serialized_items diff --git a/zigpy_zboss/types/structs.py b/zigpy_zboss/types/structs.py index b86ec4f..803c82e 100644 --- a/zigpy_zboss/types/structs.py +++ b/zigpy_zboss/types/structs.py @@ -1,5 +1,6 @@ """Module defining struct types.""" import zigpy.types as t + from . import basic diff --git a/zigpy_zboss/uart.py b/zigpy_zboss/uart.py index a6daf75..af4e39d 100644 --- a/zigpy_zboss/uart.py +++ b/zigpy_zboss/uart.py @@ -1,15 +1,17 @@ """Module that connects and sends/receives bytes from the nRF52 SoC.""" -import typing import asyncio import logging -import zigpy.serial +import typing + import async_timeout +import zigpy.serial + import zigpy_zboss.config as conf from zigpy_zboss import types as t -from zigpy_zboss.frames import Frame from zigpy_zboss.checksum import CRC8 -from zigpy_zboss.logger import SERIAL_LOGGER from zigpy_zboss.exceptions import InvalidFrame +from zigpy_zboss.frames import Frame +from zigpy_zboss.logger import SERIAL_LOGGER LOGGER = logging.getLogger(__name__) ACK_TIMEOUT = 1 diff --git a/zigpy_zboss/utils.py b/zigpy_zboss/utils.py index 4346572..71f1961 100644 --- a/zigpy_zboss/utils.py +++ b/zigpy_zboss/utils.py @@ -1,10 +1,11 @@ """Module defining utility functions.""" from __future__ import annotations -import typing import asyncio -import logging import dataclasses +import logging +import typing + import zigpy_zboss.types as t LOGGER = logging.getLogger(__name__) diff --git a/zigpy_zboss/zigbee/application.py b/zigpy_zboss/zigbee/application.py index 73bdf4f..63aa52c 100644 --- a/zigpy_zboss/zigbee/application.py +++ b/zigpy_zboss/zigbee/application.py @@ -1,28 +1,30 @@ """ControllerApplication for ZBOSS NCP protocol based adapters.""" from __future__ import annotations -import logging import asyncio -import zigpy.util -import zigpy.state +import logging +from typing import Any, Dict + import zigpy.appdb +import zigpy.application import zigpy.config import zigpy.device import zigpy.endpoint import zigpy.exceptions +import zigpy.state import zigpy.types as t -import zigpy.application -import zigpy_zboss.types as t_zboss +import zigpy.util import zigpy.zdo.types as zdo_t -import zigpy_zboss.config as conf +from zigpy.exceptions import DeliveryError -from typing import Any, Dict -from zigpy_zboss.api import ZBOSS +import zigpy_zboss.config as conf +import zigpy_zboss.types as t_zboss from zigpy_zboss import commands as c -from zigpy.exceptions import DeliveryError -from .device import ZbossCoordinator, ZbossDevice +from zigpy_zboss.api import ZBOSS from zigpy_zboss.config import CONFIG_SCHEMA, SCHEMA_DEVICE +from .device import ZbossCoordinator, ZbossDevice + LOGGER = logging.getLogger(__name__) PROBE_TIMEOUT = 5 @@ -653,8 +655,8 @@ async def send_packet(self, packet: t.ZigbeePacket) -> None: DstAddr=dst_addr, ProfileID=packet.profile_id, ClusterId=packet.cluster_id, - DstEndpoint=packet.dst_ep, - SrcEndpoint=packet.src_ep, + DstEndpoint=packet.dst_ep or 0, + SrcEndpoint=packet.src_ep or 0, Radius=packet.radius or 0, DstAddrMode=dst_addr_mode, TxOptions=options, diff --git a/zigpy_zboss/zigbee/device.py b/zigpy_zboss/zigbee/device.py index 875b69c..6525f21 100644 --- a/zigpy_zboss/zigbee/device.py +++ b/zigpy_zboss/zigbee/device.py @@ -1,15 +1,16 @@ """Zigbee device object.""" import logging -import zigpy.util +from typing import Any + import zigpy.device import zigpy.endpoint import zigpy.types as t -import zigpy_zboss.types as t_zboss -from typing import Any -from zigpy_zboss import commands as c -from zigpy.zdo import types as zdo_t +import zigpy.util from zigpy.zdo import ZDO as ZigpyZDO +from zigpy.zdo import types as zdo_t +import zigpy_zboss.types as t_zboss +from zigpy_zboss import commands as c LOGGER = logging.getLogger(__name__)