diff --git a/adafruit_esp32spi/adafruit_esp32spi.py b/adafruit_esp32spi/adafruit_esp32spi.py index e2c16ea..aa97977 100644 --- a/adafruit_esp32spi/adafruit_esp32spi.py +++ b/adafruit_esp32spi/adafruit_esp32spi.py @@ -844,7 +844,8 @@ def socket_connected(self, socket_num): return self.socket_status(socket_num) == SOCKET_ESTABLISHED def socket_write(self, socket_num, buffer, conn_mode=TCP_MODE): - """Write the bytearray buffer to a socket""" + """Write the bytearray buffer to a socket. + Returns the number of bytes written""" if self._debug: print("Writing:", buffer) self._socknum_ll[0][0] = socket_num @@ -874,7 +875,7 @@ def socket_write(self, socket_num, buffer, conn_mode=TCP_MODE): resp = self._send_command_get_response(_SEND_UDP_DATA_CMD, self._socknum_ll) if resp[0][0] != 1: raise ConnectionError("Failed to send UDP data") - return + return sent if sent != len(buffer): self.socket_close(socket_num) @@ -886,6 +887,8 @@ def socket_write(self, socket_num, buffer, conn_mode=TCP_MODE): if resp[0][0] != 1: raise ConnectionError("Failed to verify data sent") + return sent + def socket_available(self, socket_num): """Determine how many bytes are waiting to be read on the socket""" self._socknum_ll[0][0] = socket_num diff --git a/adafruit_esp32spi/adafruit_esp32spi_socketpool.py b/adafruit_esp32spi/adafruit_esp32spi_socketpool.py index 6d8a32e..29c918a 100644 --- a/adafruit_esp32spi/adafruit_esp32spi_socketpool.py +++ b/adafruit_esp32spi/adafruit_esp32spi_socketpool.py @@ -13,7 +13,7 @@ from __future__ import annotations try: - from typing import TYPE_CHECKING, Optional + from typing import TYPE_CHECKING, Optional, Tuple if TYPE_CHECKING: from esp32spi.adafruit_esp32spi import ESP_SPIcontrol @@ -33,11 +33,15 @@ class SocketPool: """ESP32SPI SocketPool library""" - SOCK_STREAM = const(0) - SOCK_DGRAM = const(1) + # socketpool constants + SOCK_STREAM = const(1) + SOCK_DGRAM = const(2) AF_INET = const(2) - NO_SOCKET_AVAIL = const(255) + SOL_SOCKET = const(0xFFF) + SO_REUSEADDR = const(0x0004) + # implementation specific constants + NO_SOCKET_AVAIL = const(255) MAX_PACKET = const(4000) def __new__(cls, iface: ESP_SPIcontrol): @@ -72,7 +76,13 @@ def socket( # pylint: disable=redefined-builtin class Socket: """A simplified implementation of the Python 'socket' class, for connecting - through an interface to a remote device""" + through an interface to a remote device. Has properties specific to the + implementation. + + :param SocketPool socket_pool: The underlying socket pool. + :param Optional[int] socknum: Allows wrapping a Socket instance around a socket + number returned by the nina firmware. Used internally. + """ def __init__( # pylint: disable=redefined-builtin,too-many-arguments,unused-argument self, @@ -81,6 +91,7 @@ def __init__( # pylint: disable=redefined-builtin,too-many-arguments,unused-arg type: int = SocketPool.SOCK_STREAM, proto: int = 0, fileno: Optional[int] = None, + socknum: Optional[int] = None, ): if family != SocketPool.AF_INET: raise ValueError("Only AF_INET family supported") @@ -88,7 +99,8 @@ def __init__( # pylint: disable=redefined-builtin,too-many-arguments,unused-arg self._interface = self._socket_pool._interface self._type = type self._buffer = b"" - self._socknum = self._interface.get_socket() + self._socknum = socknum if socknum is not None else self._interface.get_socket() + self._bound = () self.settimeout(0) def __enter__(self): @@ -122,13 +134,14 @@ def send(self, data): conntype = self._interface.UDP_MODE else: conntype = self._interface.TCP_MODE - self._interface.socket_write(self._socknum, data, conn_mode=conntype) + sent = self._interface.socket_write(self._socknum, data, conn_mode=conntype) gc.collect() + return sent def sendto(self, data, address): """Connect and send some data to the socket.""" self.connect(address) - self.send(data) + return self.send(data) def recv(self, bufsize: int) -> bytes: """Reads some bytes from the connected remote address. Will only return @@ -151,12 +164,12 @@ def recv_into(self, buffer, nbytes: int = 0): if not 0 <= nbytes <= len(buffer): raise ValueError("nbytes must be 0 to len(buffer)") - last_read_time = time.monotonic() + last_read_time = time.monotonic_ns() num_to_read = len(buffer) if nbytes == 0 else nbytes num_read = 0 while num_to_read > 0: # we might have read socket data into the self._buffer with: - # esp32spi_wsgiserver: socket_readline + # adafruit_wsgi.esp32spi_wsgiserver: socket_readline if len(self._buffer) > 0: bytes_to_read = min(num_to_read, len(self._buffer)) buffer[num_read : num_read + bytes_to_read] = self._buffer[ @@ -170,7 +183,7 @@ def recv_into(self, buffer, nbytes: int = 0): num_avail = self._available() if num_avail > 0: - last_read_time = time.monotonic() + last_read_time = time.monotonic_ns() bytes_read = self._interface.socket_read( self._socknum, min(num_to_read, num_avail) ) @@ -181,15 +194,27 @@ def recv_into(self, buffer, nbytes: int = 0): # We got a message, but there are no more bytes to read, so we can stop. break # No bytes yet, or more bytes requested. - if self._timeout > 0 and time.monotonic() - last_read_time > self._timeout: + + if self._timeout == 0: # if in non-blocking mode, stop now. + break + + # Time out if there's a positive timeout set. + delta = (time.monotonic_ns() - last_read_time) // 1_000_000 + if self._timeout > 0 and delta > self._timeout: raise OSError(errno.ETIMEDOUT) return num_read def settimeout(self, value): - """Set the read timeout for sockets. - If value is 0 socket reads will block until a message is available. + """Set the read timeout for sockets in seconds. + ``0`` means non-blocking. ``None`` means block indefinitely. """ - self._timeout = value + if value is None: + self._timeout = -1 + else: + if value < 0: + raise ValueError("Timeout cannot be a negative number") + # internally in milliseconds as an int + self._timeout = int(value * 1000) def _available(self): """Returns how many bytes of data are available to be read (up to the MAX_PACKET length)""" @@ -224,3 +249,48 @@ def _connected(self): def close(self): """Close the socket, after reading whatever remains""" self._interface.socket_close(self._socknum) + + #################################################################### + # WORK IN PROGRESS + #################################################################### + + def accept(self): + """Accept a connection on a listening socket of type SOCK_STREAM, + creating a new socket of type SOCK_STREAM. Returns a tuple of + (new_socket, remote_address) + """ + client_sock_num = self._interface.socket_available(self._socknum) + if client_sock_num != SocketPool.NO_SOCKET_AVAIL: + sock = Socket(self._socket_pool, socknum=client_sock_num) + # get remote information (addr and port) + remote = self._interface.get_remote_data(client_sock_num) + ip_address = "{}.{}.{}.{}".format(*remote["ip_addr"]) + port = remote["port"] + client_address = (ip_address, port) + return sock, client_address + raise OSError(errno.ECONNRESET) + + def bind(self, address: Tuple[str, int]): + """Bind a socket to an address""" + self._bound = address + + def listen(self, backlog: int): # pylint: disable=unused-argument + """Set socket to listen for incoming connections. + :param int backlog: length of backlog queue for waiting connections (ignored) + """ + if not self._bound: + self._bound = (self._interface.ip_address, 80) + port = self._bound[1] + self._interface.start_server(port, self._socknum) + + def setblocking(self, flag: bool): + """Set the blocking behaviour of this socket. + :param bool flag: False means non-blocking, True means block indefinitely. + """ + if flag: + self.settimeout(None) + else: + self.settimeout(0) + + def setsockopt(self, *opts, **kwopts): + """Dummy call for compatibility."""