Skip to content

Commit ac9110a

Browse files
authored
Improve typing (gijzelaerr#516)
Co-authored-by: nikteliy <[email protected]>
1 parent 8612423 commit ac9110a

22 files changed

+765
-425
lines changed

.pre-commit-config.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ repos:
2020
rev: 'v1.10.0'
2121
hooks:
2222
- id: mypy
23-
additional_dependencies: [types-setuptools]
23+
additional_dependencies: [types-setuptools, types-click]
2424
files: ^snap7
2525

2626
- repo: https://github.com/astral-sh/ruff-pre-commit
2727
rev: 'v0.4.2'
2828
hooks:
2929
- id: ruff
3030
- id: ruff-format
31+
exclude: "snap7/protocol.py"

pyproject.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Homepage = "https://github.com/gijzelaerr/python-snap7"
3232
Documentation = "https://python-snap7.readthedocs.io/en/latest/"
3333

3434
[project.optional-dependencies]
35-
test = ["pytest", "mypy", "types-setuptools", "ruff"]
35+
test = ["pytest", "mypy", "types-setuptools", "ruff", "types-click"]
3636
cli = ["rich", "click" ]
3737
doc = ["sphinx", "sphinx_rtd_theme"]
3838

@@ -59,6 +59,9 @@ markers =[
5959

6060
[tool.mypy]
6161
ignore_missing_imports = true
62+
strict = true
63+
# https://github.com/python/mypy/issues/2427#issuecomment-1419206807
64+
disable_error_code = ["method-assign", "attr-defined"]
6265

6366
[tool.ruff]
6467
output-format = "full"

snap7/client/__init__.py

+48-40
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import re
66
import logging
77
from ctypes import CFUNCTYPE, byref, create_string_buffer, sizeof
8-
from ctypes import Array, c_byte, c_char_p, c_int, c_int32, c_uint16, c_ulong, c_void_p
8+
from ctypes import Array, _SimpleCData, c_byte, c_char_p, c_int, c_int32, c_uint16, c_ulong, c_void_p
99
from datetime import datetime
10-
from typing import Any, Callable, List, Optional, Tuple, Union
10+
from typing import Any, Callable, Hashable, List, Optional, Tuple, Union, Type
11+
from types import TracebackType
1112

1213
from ..common import check_error, ipv4, load_library
14+
from ..protocol import Snap7CliProtocol
1315
from ..types import S7SZL, Areas, BlocksList, S7CpInfo, S7CpuInfo, S7DataItem
1416
from ..types import S7OrderCode, S7Protection, S7SZLList, TS7BlockInfo, WordLen
1517
from ..types import S7Object, buffer_size, buffer_type, cpu_statuses, param_types
@@ -18,11 +20,11 @@
1820
logger = logging.getLogger(__name__)
1921

2022

21-
def error_wrap(func):
23+
def error_wrap(func: Callable[..., Any]) -> Callable[..., Any]:
2224
"""Parses a s7 error code returned the decorated function."""
2325

24-
def f(*args, **kw):
25-
code = func(*args, **kw)
26+
def f(*args: tuple[Any, ...], **kwargs: dict[Hashable, Any]) -> None:
27+
code = func(*args, **kwargs)
2628
check_error(code, context="client")
2729

2830
return f
@@ -47,10 +49,10 @@ class Client:
4749
>>> client.db_write(1, 0, data)
4850
"""
4951

50-
_lib: Any # since this is dynamically loaded from a DLL we don't have the type signature.
52+
_lib: Snap7CliProtocol
5153
_read_callback = None
5254
_callback = None
53-
_s7_client: Optional[S7Object] = None
55+
_s7_client: S7Object
5456

5557
def __init__(self, lib_location: Optional[str] = None):
5658
"""Creates a new `Client` instance.
@@ -66,22 +68,24 @@ def __init__(self, lib_location: Optional[str] = None):
6668
<snap7.client.Client object at 0x0000028B257128E0>
6769
"""
6870

69-
self._lib = load_library(lib_location)
71+
self._lib: Snap7CliProtocol = load_library(lib_location)
7072
self.create()
7173

72-
def __enter__(self):
74+
def __enter__(self) -> "Client":
7375
return self
7476

75-
def __exit__(self, exc_type, exc_val, exc_tb):
77+
def __exit__(
78+
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
79+
) -> None:
7680
self.destroy()
7781

78-
def __del__(self):
82+
def __del__(self) -> None:
7983
self.destroy()
8084

81-
def create(self):
85+
def create(self) -> None:
8286
"""Creates a SNAP7 client."""
8387
logger.info("creating snap7 client")
84-
self._lib.Cli_Create.restype = S7Object
88+
self._lib.Cli_Create.restype = S7Object # type: ignore[attr-defined]
8589
self._s7_client = S7Object(self._lib.Cli_Create())
8690

8791
def destroy(self) -> Optional[int]:
@@ -97,7 +101,7 @@ def destroy(self) -> Optional[int]:
97101
logger.info("destroying snap7 client")
98102
if self._lib and self._s7_client is not None:
99103
return self._lib.Cli_Destroy(byref(self._s7_client))
100-
self._s7_client = None
104+
self._s7_client = None # type: ignore[assignment]
101105
return None
102106

103107
def plc_stop(self) -> int:
@@ -199,7 +203,7 @@ def connect(self, address: str, rack: int, slot: int, tcpport: int = 102) -> int
199203
"""
200204
logger.info(f"connecting to {address}:{tcpport} rack {rack} slot {slot}")
201205

202-
self.set_param(RemotePort, tcpport)
206+
self.set_param(number=RemotePort, value=tcpport)
203207
return self._lib.Cli_ConnectTo(self._s7_client, c_char_p(address.encode()), c_int(rack), c_int(slot))
204208

205209
def db_read(self, db_number: int, start: int, size: int) -> bytearray:
@@ -441,7 +445,7 @@ def write_area(self, area: Areas, dbnumber: int, start: int, data: bytearray) ->
441445
cdata = (type_ * len(data)).from_buffer_copy(data)
442446
return self._lib.Cli_WriteArea(self._s7_client, area.value, dbnumber, start, size, wordlen.value, byref(cdata))
443447

444-
def read_multi_vars(self, items) -> Tuple[int, S7DataItem]:
448+
def read_multi_vars(self, items: Array[S7DataItem]) -> Tuple[int, Array[S7DataItem]]:
445449
"""Reads different kind of variables from a PLC simultaneously.
446450
447451
Args:
@@ -472,7 +476,7 @@ def list_blocks(self) -> BlocksList:
472476
logger.debug(f"blocks: {blocksList}")
473477
return blocksList
474478

475-
def list_blocks_of_type(self, blocktype: str, size: int) -> Union[int, Array]:
479+
def list_blocks_of_type(self, blocktype: str, size: int) -> Union[int, Array[c_uint16]]:
476480
"""This function returns the AG list of a specified block type.
477481
478482
Args:
@@ -592,11 +596,11 @@ def set_connection_params(self, address: str, local_tsap: int, remote_tsap: int)
592596
"""
593597
if not re.match(ipv4, address):
594598
raise ValueError(f"{address} is invalid ipv4")
595-
result = self._lib.Cli_SetConnectionParams(self._s7_client, address, c_uint16(local_tsap), c_uint16(remote_tsap))
599+
result = self._lib.Cli_SetConnectionParams(self._s7_client, address.encode(), c_uint16(local_tsap), c_uint16(remote_tsap))
596600
if result != 0:
597601
raise ValueError("The parameter was invalid")
598602

599-
def set_connection_type(self, connection_type: int):
603+
def set_connection_type(self, connection_type: int) -> None:
600604
"""Sets the connection resource type, i.e the way in which the Clients connects to a PLC.
601605
602606
Args:
@@ -659,7 +663,7 @@ def ab_write(self, start: int, data: bytearray) -> int:
659663
logger.debug(f"ab write: start: {start}: size: {size}: ")
660664
return self._lib.Cli_ABWrite(self._s7_client, start, size, byref(cdata))
661665

662-
def as_ab_read(self, start: int, size: int, data) -> int:
666+
def as_ab_read(self, start: int, size: int, data: "Array[_SimpleCData[Any]]") -> int:
663667
"""Reads a part of IPU area from a PLC asynchronously.
664668
665669
Args:
@@ -720,7 +724,7 @@ def as_copy_ram_to_rom(self, timeout: int = 1) -> int:
720724
check_error(result, context="client")
721725
return result
722726

723-
def as_ct_read(self, start: int, amount: int, data) -> int:
727+
def as_ct_read(self, start: int, amount: int, data: "Array[_SimpleCData[Any]]") -> int:
724728
"""Reads counters from a PLC asynchronously.
725729
726730
Args:
@@ -752,7 +756,7 @@ def as_ct_write(self, start: int, amount: int, data: bytearray) -> int:
752756
check_error(result, context="client")
753757
return result
754758

755-
def as_db_fill(self, db_number: int, filler) -> int:
759+
def as_db_fill(self, db_number: int, filler: int) -> int:
756760
"""Fills a DB in AG with a given byte.
757761
758762
Args:
@@ -766,7 +770,7 @@ def as_db_fill(self, db_number: int, filler) -> int:
766770
check_error(result, context="client")
767771
return result
768772

769-
def as_db_get(self, db_number: int, _buffer, size) -> bytearray:
773+
def as_db_get(self, db_number: int, _buffer: "Array[_SimpleCData[Any]]", size: "_SimpleCData[Any]") -> int:
770774
"""Uploads a DB from AG using DBRead.
771775
772776
Note:
@@ -784,7 +788,7 @@ def as_db_get(self, db_number: int, _buffer, size) -> bytearray:
784788
check_error(result, context="client")
785789
return result
786790

787-
def as_db_read(self, db_number: int, start: int, size: int, data) -> Array:
791+
def as_db_read(self, db_number: int, start: int, size: int, data: "Array[_SimpleCData[Any]]") -> int:
788792
"""Reads a part of a DB from a PLC.
789793
790794
Args:
@@ -807,7 +811,7 @@ def as_db_read(self, db_number: int, start: int, size: int, data) -> Array:
807811
check_error(result, context="client")
808812
return result
809813

810-
def as_db_write(self, db_number: int, start: int, size: int, data) -> int:
814+
def as_db_write(self, db_number: int, start: int, size: int, data: "Array[_SimpleCData[Any]]") -> int:
811815
"""Writes a part of a DB into a PLC.
812816
813817
Args:
@@ -943,7 +947,7 @@ def set_plc_datetime(self, dt: datetime) -> int:
943947

944948
return self._lib.Cli_SetPlcDateTime(self._s7_client, byref(buffer))
945949

946-
def check_as_completion(self, p_value) -> int:
950+
def check_as_completion(self, p_value: c_int) -> int:
947951
"""Method to check Status of an async request. Result contains if the check was successful, not the data value itself
948952
949953
Args:
@@ -952,7 +956,7 @@ def check_as_completion(self, p_value) -> int:
952956
Returns:
953957
Snap7 code. If 0 - Job is done successfully. If 1 - Job is either pending or contains s7errors
954958
"""
955-
result = self._lib.Cli_CheckAsCompletion(self._s7_client, p_value)
959+
result = self._lib.Cli_CheckAsCompletion(self._s7_client, byref(p_value))
956960
check_error(result, context="client")
957961
return result
958962

@@ -1000,7 +1004,7 @@ def wait_as_completion(self, timeout: int) -> int:
10001004
check_error(result, context="client")
10011005
return result
10021006

1003-
def _prepare_as_read_area(self, area: Areas, size: int) -> Tuple[WordLen, Array]:
1007+
def _prepare_as_read_area(self, area: Areas, size: int) -> Tuple[WordLen, "Array[_SimpleCData[int]]"]:
10041008
if area not in Areas:
10051009
raise ValueError(f"{area} is not implemented in types")
10061010
elif area == Areas.TM:
@@ -1013,7 +1017,9 @@ def _prepare_as_read_area(self, area: Areas, size: int) -> Tuple[WordLen, Array]
10131017
usrdata = (type_ * size)()
10141018
return wordlen, usrdata
10151019

1016-
def as_read_area(self, area: Areas, dbnumber: int, start: int, size: int, wordlen: WordLen, pusrdata) -> int:
1020+
def as_read_area(
1021+
self, area: Areas, dbnumber: int, start: int, size: int, wordlen: WordLen, pusrdata: "Array[_SimpleCData[Any]]"
1022+
) -> int:
10171023
"""Reads a data area from a PLC asynchronously.
10181024
With it you can read DB, Inputs, Outputs, Merkers, Timers and Counters.
10191025
@@ -1032,11 +1038,11 @@ def as_read_area(self, area: Areas, dbnumber: int, start: int, size: int, wordle
10321038
f"reading area: {area.name} dbnumber: {dbnumber} start: {start} amount: {size} "
10331039
f"wordlen: {wordlen.name}={wordlen.value}"
10341040
)
1035-
result = self._lib.Cli_AsReadArea(self._s7_client, area.value, dbnumber, start, size, wordlen.value, pusrdata)
1041+
result = self._lib.Cli_AsReadArea(self._s7_client, area.value, dbnumber, start, size, wordlen.value, byref(pusrdata))
10361042
check_error(result, context="client")
10371043
return result
10381044

1039-
def _prepare_as_write_area(self, area: Areas, data: bytearray) -> Tuple[WordLen, Array]:
1045+
def _prepare_as_write_area(self, area: Areas, data: bytearray) -> Tuple[WordLen, "Array[_SimpleCData[Any]]"]:
10401046
if area not in Areas:
10411047
raise ValueError(f"{area} is not implemented in types")
10421048
elif area == Areas.TM:
@@ -1049,7 +1055,9 @@ def _prepare_as_write_area(self, area: Areas, data: bytearray) -> Tuple[WordLen,
10491055
cdata = (type_ * len(data)).from_buffer_copy(data)
10501056
return wordlen, cdata
10511057

1052-
def as_write_area(self, area: Areas, dbnumber: int, start: int, size: int, wordlen: WordLen, pusrdata) -> int:
1058+
def as_write_area(
1059+
self, area: Areas, dbnumber: int, start: int, size: int, wordlen: WordLen, pusrdata: "Array[_SimpleCData[Any]]"
1060+
) -> int:
10531061
"""Writes a data area into a PLC asynchronously.
10541062
10551063
Args:
@@ -1072,7 +1080,7 @@ def as_write_area(self, area: Areas, dbnumber: int, start: int, size: int, wordl
10721080
check_error(res, context="client")
10731081
return res
10741082

1075-
def as_eb_read(self, start: int, size: int, data) -> int:
1083+
def as_eb_read(self, start: int, size: int, data: "Array[_SimpleCData[Any]]") -> int:
10761084
"""Reads a part of IPI area from a PLC asynchronously.
10771085
10781086
Args:
@@ -1124,7 +1132,7 @@ def as_full_upload(self, _type: str, block_num: int) -> int:
11241132
check_error(result, context="client")
11251133
return result
11261134

1127-
def as_list_blocks_of_type(self, blocktype: str, data, count) -> int:
1135+
def as_list_blocks_of_type(self, blocktype: str, data: "Array[_SimpleCData[Any]]", count: "_SimpleCData[Any]") -> int:
11281136
"""Returns the AG blocks list of a given type.
11291137
11301138
Args:
@@ -1145,7 +1153,7 @@ def as_list_blocks_of_type(self, blocktype: str, data, count) -> int:
11451153
check_error(result, context="client")
11461154
return result
11471155

1148-
def as_mb_read(self, start: int, size: int, data) -> int:
1156+
def as_mb_read(self, start: int, size: int, data: "Array[_SimpleCData[Any]]") -> int:
11491157
"""Reads a part of Merkers area from a PLC.
11501158
11511159
Args:
@@ -1177,7 +1185,7 @@ def as_mb_write(self, start: int, size: int, data: bytearray) -> int:
11771185
check_error(result, context="client")
11781186
return result
11791187

1180-
def as_read_szl(self, ssl_id: int, index: int, s7_szl: S7SZL, size) -> int:
1188+
def as_read_szl(self, ssl_id: int, index: int, s7_szl: S7SZL, size: "_SimpleCData[Any]") -> int:
11811189
"""Reads a partial list of given ID and Index.
11821190
11831191
Args:
@@ -1193,7 +1201,7 @@ def as_read_szl(self, ssl_id: int, index: int, s7_szl: S7SZL, size) -> int:
11931201
check_error(result, context="client")
11941202
return result
11951203

1196-
def as_read_szl_list(self, szl_list, items_count) -> int:
1204+
def as_read_szl_list(self, szl_list: S7SZLList, items_count: "_SimpleCData[Any]") -> int:
11971205
"""Reads the list of partial lists available in the CPU.
11981206
11991207
Args:
@@ -1207,7 +1215,7 @@ def as_read_szl_list(self, szl_list, items_count) -> int:
12071215
check_error(result, context="client")
12081216
return result
12091217

1210-
def as_tm_read(self, start: int, amount: int, data) -> bytearray:
1218+
def as_tm_read(self, start: int, amount: int, data: "Array[_SimpleCData[Any]]") -> int:
12111219
"""Reads timers from a PLC.
12121220
12131221
Args:
@@ -1239,7 +1247,7 @@ def as_tm_write(self, start: int, amount: int, data: bytearray) -> int:
12391247
check_error(result)
12401248
return result
12411249

1242-
def as_upload(self, block_num: int, _buffer, size) -> int:
1250+
def as_upload(self, block_num: int, _buffer: "Array[_SimpleCData[Any]]", size: "_SimpleCData[Any]") -> int:
12431251
"""Uploads a block from AG.
12441252
12451253
Note:
@@ -1363,7 +1371,7 @@ def error_text(self, error: int) -> str:
13631371
text_length = c_int(256)
13641372
error_code = c_int32(error)
13651373
text = create_string_buffer(buffer_size)
1366-
response = self._lib.Cli_ErrorText(error_code, byref(text), text_length)
1374+
response = self._lib.Cli_ErrorText(error_code, text, text_length)
13671375
check_error(response)
13681376
result = bytearray(text)[: text_length.value].decode().strip("\x00")
13691377
return result

0 commit comments

Comments
 (0)