diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index b3e485d4..d5a9f011 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -61,6 +61,7 @@ if t.TYPE_CHECKING: from ..._api import TelemetryAPI + from ...addressing import Address # Set up logger @@ -135,9 +136,12 @@ class AsyncBolt: # results for it. most_recent_qid = None + address_callback = None + advertised_address: Address | None = None + def __init__( self, - unresolved_address, + address, sock, max_connection_lifetime, *, @@ -149,12 +153,12 @@ def __init__( notifications_disabled_classifications=None, telemetry_disabled=False, ): - self.unresolved_address = unresolved_address + self._address = address self.socket = sock self.local_port = self.socket.getsockname()[1] self.server_info = ServerInfo( ResolvedAddress( - sock.getpeername(), host_name=unresolved_address.host + sock.getpeername(), host_name=address._unresolved.host ), self.PROTOCOL_VERSION, ) @@ -200,6 +204,15 @@ def __del__(self): if not asyncio.iscoroutinefunction(self.close): self.close() + @property + def address(self): + return self._address + + @address.setter + def address(self, value): + self._address = value + self.server_info._address = value._unresolved + @abc.abstractmethod def _get_server_state_manager(self) -> ServerStateManagerBase: ... 
@@ -308,6 +321,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x5, AsyncBolt5x6, AsyncBolt5x7, + AsyncBolt5x8, ) handlers = { @@ -325,6 +339,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5, AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, AsyncBolt5x7.PROTOCOL_VERSION: AsyncBolt5x7, + AsyncBolt5x8.PROTOCOL_VERSION: AsyncBolt5x8, } if protocol_version is None: @@ -461,7 +476,10 @@ async def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 7): + if protocol_version == (5, 8): + from ._bolt5 import AsyncBolt5x8 + bolt_cls = AsyncBolt5x8 + elif protocol_version == (5, 7): from ._bolt5 import AsyncBolt5x7 bolt_cls = AsyncBolt5x7 elif protocol_version == (5, 6): @@ -954,12 +972,12 @@ async def send_all(self): if self.closed(): raise ServiceUnavailable( "Failed to write to closed connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if self.defunct(): raise ServiceUnavailable( "Failed to write to defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) await self._send_all() @@ -977,12 +995,12 @@ async def fetch_message(self): if self._closed: raise ServiceUnavailable( "Failed to read from closed connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if self._defunct: raise ServiceUnavailable( "Failed to read from defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if not self.responses: return 0, 0 @@ -1014,14 +1032,14 @@ async def fetch_all(self): async def _set_defunct_read(self, error=None, silent=False): message = ( "Failed to read from defunct connection " - f"{self.unresolved_address!r} 
({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) await self._set_defunct(message, error=error, silent=silent) async def _set_defunct_write(self, error=None, silent=False): message = ( "Failed to write data to connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) await self._set_defunct(message, error=error, silent=silent) @@ -1060,7 +1078,7 @@ async def _set_defunct(self, message, error=None, silent=False): # connection again. await self.close() if self.pool and not self._get_server_state_manager().failed(): - await self.pool.deactivate(address=self.unresolved_address) + await self.pool.deactivate(address=self.address) # Iterate through the outstanding responses, and if any correspond # to COMMIT requests then raise an error to signal that we are diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 08e75abb..55638074 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -579,12 +579,12 @@ async def _process_message(self, tag, fields): await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - await self.pool.deactivate(address=self.unresolved_address) + await self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -595,7 +595,7 @@ async def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 202d5570..70066b3e 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -494,12 +494,12 
@@ async def _process_message(self, tag, fields): await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - await self.pool.deactivate(address=self.unresolved_address) + await self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -511,7 +511,7 @@ async def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 06336193..b1a36be8 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -24,6 +24,7 @@ from ..._codec.hydration import v2 as hydration_v2 from ..._exceptions import BoltProtocolError from ..._meta import BOLT_AGENT_DICT +from ...addressing import Address from ...api import ( READ_ACCESS, Version, @@ -496,12 +497,12 @@ async def _process_message(self, tag, fields): await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - await self.pool.deactivate(address=self.unresolved_address) + await self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -513,7 +514,7 @@ async def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 @@ -1204,12 +1205,12 @@ async def _process_message(self, tag, fields): await 
response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - await self.pool.deactivate(address=self.unresolved_address) + await self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -1221,7 +1222,51 @@ async def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 + + +class AsyncBolt5x8(AsyncBolt5x7): + PROTOCOL_VERSION = Version(5, 8) + + def logon(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + logged_auth_dict = dict(self.auth_dict) + if "credentials" in logged_auth_dict: + logged_auth_dict["credentials"] = "*******" + log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) + self._append( + b"\x6a", + (self.auth_dict,), + response=LogonResponse( + self, "logon", hydration_hooks, on_success=self._logon_success + ), + dehydration_hooks=dehydration_hooks, + ) + + async def _logon_success(self, meta: object) -> None: + if not isinstance(meta, dict): + log.warning( + "[#%04X] _: " + "LOGON expected dictionary metadata, got %r", + self.local_port, + meta, + ) + return + address = meta.get("advertised_address", ...) 
+ if address is ...: + return + if not isinstance(address, str): + log.warning( + "[#%04X] _: " + "LOGON expected string advertised_address, got %r", + self.local_port, + address, + ) + return + self.advertised_address = Address.parse(address, default_port=7687) + await AsyncUtil.callback(self.address_callback, self) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 7a520abe..bd695e6d 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -119,7 +119,7 @@ async def _acquire_from_pool(self, address): return None # no free connection available def _remove_connection(self, connection): - address = connection.unresolved_address + address = connection.address with self.lock: log.debug( "[#%04X] _: remove connection from pool %r %s", @@ -133,6 +133,7 @@ def _remove_connection(self, connection): # connection isn't in the pool anymore. with suppress(ValueError): self.connections.get(address, []).remove(connection) + self._log_pool_stats() async def _acquire_from_pool_checked( self, address, health_check, deadline @@ -194,11 +195,13 @@ async def connection_creator(): self.connections_reservations[address] -= 1 released_reservation = True self.connections[address].append(connection) + self._log_pool_stats() return connection finally: if not released_reservation: with self.lock: self.connections_reservations[address] -= 1 + self._log_pool_stats() max_pool_size = self.pool_config.max_connection_pool_size infinite_pool_size = max_pool_size < 0 or max_pool_size == float("inf") @@ -210,6 +213,7 @@ async def connection_creator(): if infinite_pool_size or pool_size < max_pool_size: # there's room for a new connection self.connections_reservations[address] += 1 + self._log_pool_stats() return connection_creator return None @@ -347,7 +351,12 @@ async def health_check(connection_, deadline_): f"{deadline.original_timeout!r}s (timeout)" ) log.debug("[#0000] _: trying to hand out new connection") - return await connection_creator() + 
connection = await connection_creator() + await self._on_new_connection(connection) + return connection + + async def _on_new_connection(self, connection): + return @abc.abstractmethod async def acquire( @@ -497,6 +506,7 @@ async def deactivate(self, address): connections.remove(conn) if not self.connections[address]: del self.connections[address] + self._log_pool_stats() await self._close_connections(closable_connections) @@ -508,7 +518,7 @@ async def on_write_failure(self, address, database): async def on_neo4j_error(self, error, connection): assert isinstance(error, Neo4jError) if error._unauthenticates_all_connections(): - address = connection.unresolved_address + address = connection.address log.debug( "[#0000] _: mark all connections to %r as " "unauthenticated", @@ -540,10 +550,30 @@ async def close(self): for address in list(self.connections) for connection in self.connections.pop(address, ()) ] + self._log_pool_stats() await self._close_connections(connections) except TypeError: pass + def _log_pool_stats(self): + level = logging.DEBUG + if log.isEnabledFor(level): + with self.lock: + addresses = sorted( + set(self.connections.keys()) + | set(self.connections_reservations.keys()) + ) + stats = { + address: { + "connections": len(self.connections.get(address, ())), + "reservations": self.connections_reservations.get( + address, 0 + ), + } + for address in addresses + } + log.log(level, "[#0000] _: stats %r", stats) + class AsyncBoltPool(AsyncIOPool): is_direct_pool = True @@ -855,6 +885,8 @@ async def _update_routing_table_from( ) if callable(database_callback): database_callback(new_database) + + await self.update_connection_pool(database=new_database) return True await self.deactivate(router) return False @@ -943,6 +975,9 @@ async def update_routing_table( raise ServiceUnavailable("Unable to retrieve routing information") async def update_connection_pool(self, *, database): + log.debug( + "[#0000] _: update connection pool, database=%r", database + ) 
async with self.refresh_lock: routing_tables = [await self.get_or_create_routing_table(database)] for db in self.routing_tables: @@ -952,6 +987,11 @@ async def update_connection_pool(self, *, database): servers = set.union(*(rt.servers() for rt in routing_tables)) for address in list(self.connections): if address._unresolved not in servers: + log.debug( + "[#0000] _: deactivating address (not used in any " + "routing table): %r", + address, + ) await super().deactivate(address) async def ensure_routing_table_is_fresh( @@ -1013,7 +1053,6 @@ async def ensure_routing_table_is_fresh( acquisition_timeout=acquisition_timeout, database_callback=database_callback, ) - await self.update_connection_pool(database=database) return True @@ -1046,6 +1085,10 @@ async def _select_address(self, *, access_mode, database): ) return choice(addresses_by_usage[min(addresses_by_usage)]) + async def _on_new_connection(self, connection): + await self._move_connection(connection) + connection.address_callback = self._move_connection + async def acquire( self, access_mode, @@ -1149,3 +1192,32 @@ async def on_write_failure(self, address, database): if table is not None: table.writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) + + async def _move_connection(self, connection): + to_addr = connection.advertised_address + if to_addr is None: + return + from_addr = connection.address + if from_addr == to_addr: + return + log.debug( + "[#%04X] _: moving connection from %r to %r", + connection.local_port, + from_addr, + to_addr, + ) + with self.lock: + old_pool = self.connections[from_addr] + new_pool = self.connections[to_addr] + try: + old_pool.remove(connection) + except ValueError: + log.debug( + "[#%04X] _: abort move (connection not in pool)", + connection.local_port, + ) + return + new_pool.append(connection) + connection.address = connection.advertised_address + self._log_pool_stats() + self.cond.notify_all() diff --git a/src/neo4j/_async/work/result.py 
b/src/neo4j/_async/work/result.py index ac62fa1f..1a5b00fc 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -123,7 +123,7 @@ def __init__( self._on_error = on_error self._on_closed = on_closed self._metadata: dict = {} - self._address: Address = self._connection.unresolved_address + self._address: Address = self._connection.address self._keys: tuple[str, ...] = () self._had_record = False self._record_buffer: deque[Record] = deque() diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 2ad1790a..49c28c8f 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -61,6 +61,7 @@ if t.TYPE_CHECKING: from ..._api import TelemetryAPI + from ...addressing import Address # Set up logger @@ -135,9 +136,12 @@ class Bolt: # results for it. most_recent_qid = None + address_callback = None + advertised_address: Address | None = None + def __init__( self, - unresolved_address, + address, sock, max_connection_lifetime, *, @@ -149,12 +153,12 @@ def __init__( notifications_disabled_classifications=None, telemetry_disabled=False, ): - self.unresolved_address = unresolved_address + self._address = address self.socket = sock self.local_port = self.socket.getsockname()[1] self.server_info = ServerInfo( ResolvedAddress( - sock.getpeername(), host_name=unresolved_address.host + sock.getpeername(), host_name=address._unresolved.host ), self.PROTOCOL_VERSION, ) @@ -200,6 +204,15 @@ def __del__(self): if not asyncio.iscoroutinefunction(self.close): self.close() + @property + def address(self): + return self._address + + @address.setter + def address(self, value): + self._address = value + self.server_info._address = value._unresolved + @abc.abstractmethod def _get_server_state_manager(self) -> ServerStateManagerBase: ... 
@@ -308,6 +321,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x5, Bolt5x6, Bolt5x7, + Bolt5x8, ) handlers = { @@ -325,6 +339,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x5.PROTOCOL_VERSION: Bolt5x5, Bolt5x6.PROTOCOL_VERSION: Bolt5x6, Bolt5x7.PROTOCOL_VERSION: Bolt5x7, + Bolt5x8.PROTOCOL_VERSION: Bolt5x8, } if protocol_version is None: @@ -461,7 +476,10 @@ def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 7): + if protocol_version == (5, 8): + from ._bolt5 import Bolt5x8 + bolt_cls = Bolt5x8 + elif protocol_version == (5, 7): from ._bolt5 import Bolt5x7 bolt_cls = Bolt5x7 elif protocol_version == (5, 6): @@ -954,12 +972,12 @@ def send_all(self): if self.closed(): raise ServiceUnavailable( "Failed to write to closed connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if self.defunct(): raise ServiceUnavailable( "Failed to write to defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) self._send_all() @@ -977,12 +995,12 @@ def fetch_message(self): if self._closed: raise ServiceUnavailable( "Failed to read from closed connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if self._defunct: raise ServiceUnavailable( "Failed to read from defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) if not self.responses: return 0, 0 @@ -1014,14 +1032,14 @@ def fetch_all(self): def _set_defunct_read(self, error=None, silent=False): message = ( "Failed to read from defunct connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) self._set_defunct(message, error=error, silent=silent) 
def _set_defunct_write(self, error=None, silent=False): message = ( "Failed to write data to connection " - f"{self.unresolved_address!r} ({self.server_info.address!r})" + f"{self.address!r} ({self.server_info.address!r})" ) self._set_defunct(message, error=error, silent=silent) @@ -1060,7 +1078,7 @@ def _set_defunct(self, message, error=None, silent=False): # connection again. self.close() if self.pool and not self._get_server_state_manager().failed(): - self.pool.deactivate(address=self.unresolved_address) + self.pool.deactivate(address=self.address) # Iterate through the outstanding responses, and if any correspond # to COMMIT requests then raise an error to signal that we are diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index e3cfd142..9d67b6e2 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -579,12 +579,12 @@ def _process_message(self, tag, fields): response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - self.pool.deactivate(address=self.unresolved_address) + self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -595,7 +595,7 @@ def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 69bb6dd6..6dc9e7cb 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -494,12 +494,12 @@ def _process_message(self, tag, fields): response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - self.pool.deactivate(address=self.unresolved_address) + 
self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -511,7 +511,7 @@ def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 4138a9d5..d6a1a518 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -24,6 +24,7 @@ from ..._codec.hydration import v2 as hydration_v2 from ..._exceptions import BoltProtocolError from ..._meta import BOLT_AGENT_DICT +from ...addressing import Address from ...api import ( READ_ACCESS, Version, @@ -496,12 +497,12 @@ def _process_message(self, tag, fields): response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - self.pool.deactivate(address=self.unresolved_address) + self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -513,7 +514,7 @@ def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 @@ -1204,12 +1205,12 @@ def _process_message(self, tag, fields): response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - self.pool.deactivate(address=self.unresolved_address) + self.pool.deactivate(address=self.address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - 
address=self.unresolved_address, + address=self.address, database=self.last_database, ) raise @@ -1221,7 +1222,51 @@ def _process_message(self, tag, fields): sig_int = ord(summary_signature) raise BoltProtocolError( f"Unexpected response message with signature {sig_int:02X}", - self.unresolved_address, + self.address, ) return len(details), 1 + + +class Bolt5x8(Bolt5x7): + PROTOCOL_VERSION = Version(5, 8) + + def logon(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + logged_auth_dict = dict(self.auth_dict) + if "credentials" in logged_auth_dict: + logged_auth_dict["credentials"] = "*******" + log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) + self._append( + b"\x6a", + (self.auth_dict,), + response=LogonResponse( + self, "logon", hydration_hooks, on_success=self._logon_success + ), + dehydration_hooks=dehydration_hooks, + ) + + def _logon_success(self, meta: object) -> None: + if not isinstance(meta, dict): + log.warning( + "[#%04X] _: " + "LOGON expected dictionary metadata, got %r", + self.local_port, + meta, + ) + return + address = meta.get("advertised_address", ...) 
+ if address is ...: + return + if not isinstance(address, str): + log.warning( + "[#%04X] _: " + "LOGON expected string advertised_address, got %r", + self.local_port, + address, + ) + return + self.advertised_address = Address.parse(address, default_port=7687) + Util.callback(self.address_callback, self) diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 1570e745..9e851f2f 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -116,7 +116,7 @@ def _acquire_from_pool(self, address): return None # no free connection available def _remove_connection(self, connection): - address = connection.unresolved_address + address = connection.address with self.lock: log.debug( "[#%04X] _: remove connection from pool %r %s", @@ -130,6 +130,7 @@ def _remove_connection(self, connection): # connection isn't in the pool anymore. with suppress(ValueError): self.connections.get(address, []).remove(connection) + self._log_pool_stats() def _acquire_from_pool_checked( self, address, health_check, deadline @@ -191,11 +192,13 @@ def connection_creator(): self.connections_reservations[address] -= 1 released_reservation = True self.connections[address].append(connection) + self._log_pool_stats() return connection finally: if not released_reservation: with self.lock: self.connections_reservations[address] -= 1 + self._log_pool_stats() max_pool_size = self.pool_config.max_connection_pool_size infinite_pool_size = max_pool_size < 0 or max_pool_size == float("inf") @@ -207,6 +210,7 @@ def connection_creator(): if infinite_pool_size or pool_size < max_pool_size: # there's room for a new connection self.connections_reservations[address] += 1 + self._log_pool_stats() return connection_creator return None @@ -344,7 +348,12 @@ def health_check(connection_, deadline_): f"{deadline.original_timeout!r}s (timeout)" ) log.debug("[#0000] _: trying to hand out new connection") - return connection_creator() + connection = connection_creator() + 
self._on_new_connection(connection) + return connection + + def _on_new_connection(self, connection): + return @abc.abstractmethod def acquire( @@ -494,6 +503,7 @@ def deactivate(self, address): connections.remove(conn) if not self.connections[address]: del self.connections[address] + self._log_pool_stats() self._close_connections(closable_connections) @@ -505,7 +515,7 @@ def on_write_failure(self, address, database): def on_neo4j_error(self, error, connection): assert isinstance(error, Neo4jError) if error._unauthenticates_all_connections(): - address = connection.unresolved_address + address = connection.address log.debug( "[#0000] _: mark all connections to %r as " "unauthenticated", @@ -537,10 +547,30 @@ def close(self): for address in list(self.connections) for connection in self.connections.pop(address, ()) ] + self._log_pool_stats() self._close_connections(connections) except TypeError: pass + def _log_pool_stats(self): + level = logging.DEBUG + if log.isEnabledFor(level): + with self.lock: + addresses = sorted( + set(self.connections.keys()) + | set(self.connections_reservations.keys()) + ) + stats = { + address: { + "connections": len(self.connections.get(address, ())), + "reservations": self.connections_reservations.get( + address, 0 + ), + } + for address in addresses + } + log.log(level, "[#0000] _: stats %r", stats) + class BoltPool(IOPool): is_direct_pool = True @@ -852,6 +882,8 @@ def _update_routing_table_from( ) if callable(database_callback): database_callback(new_database) + + self.update_connection_pool(database=new_database) return True self.deactivate(router) return False @@ -940,6 +972,9 @@ def update_routing_table( raise ServiceUnavailable("Unable to retrieve routing information") def update_connection_pool(self, *, database): + log.debug( + "[#0000] _: update connection pool, database=%r", database + ) with self.refresh_lock: routing_tables = [self.get_or_create_routing_table(database)] for db in self.routing_tables: @@ -949,6 +984,11 @@ 
def update_connection_pool(self, *, database): servers = set.union(*(rt.servers() for rt in routing_tables)) for address in list(self.connections): if address._unresolved not in servers: + log.debug( + "[#0000] _: deactivating address (not used in any " + "routing table): %r", + address, + ) super().deactivate(address) def ensure_routing_table_is_fresh( @@ -1010,7 +1050,6 @@ def ensure_routing_table_is_fresh( acquisition_timeout=acquisition_timeout, database_callback=database_callback, ) - self.update_connection_pool(database=database) return True @@ -1043,6 +1082,10 @@ def _select_address(self, *, access_mode, database): ) return choice(addresses_by_usage[min(addresses_by_usage)]) + def _on_new_connection(self, connection): + self._move_connection(connection) + connection.address_callback = self._move_connection + def acquire( self, access_mode, @@ -1146,3 +1189,32 @@ def on_write_failure(self, address, database): if table is not None: table.writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) + + def _move_connection(self, connection): + to_addr = connection.advertised_address + if to_addr is None: + return + from_addr = connection.address + if from_addr == to_addr: + return + log.debug( + "[#%04X] _: moving connection from %r to %r", + connection.local_port, + from_addr, + to_addr, + ) + with self.lock: + old_pool = self.connections[from_addr] + new_pool = self.connections[to_addr] + try: + old_pool.remove(connection) + except ValueError: + log.debug( + "[#%04X] _: abort move (connection not in pool)", + connection.local_port, + ) + return + new_pool.append(connection) + connection.address = connection.advertised_address + self._log_pool_stats() + self.cond.notify_all() diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index 27164cf8..cf4f1ce0 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -123,7 +123,7 @@ def __init__( self._on_error = on_error self._on_closed = on_closed 
self._metadata: dict = {} - self._address: Address = self._connection.unresolved_address + self._address: Address = self._connection.address self._keys: tuple[str, ...] = () self._had_record = False self._record_buffer: deque[Record] = deque() diff --git a/testkit/testkit.json b/testkit/testkit.json index 93190035..307f1091 100644 --- a/testkit/testkit.json +++ b/testkit/testkit.json @@ -1,6 +1,6 @@ { "testkit": { "uri": "https://github.com/neo4j-drivers/testkit.git", - "ref": "5.0" + "ref": "advertised-address" } } diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index bca7f0ca..d236ff11 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -59,6 +59,7 @@ "Feature:Bolt:5.5": true, "Feature:Bolt:5.6": true, "Feature:Bolt:5.7": true, + "Feature:Bolt:5.8": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", @@ -78,6 +79,7 @@ "ConfHint:connection.recv_timeout_seconds": true, + "Backend:DNSResolver": true, "Backend:MockTime": true, "Backend:RTFetch": true, "Backend:RTForceUpdate": true diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index 9bf96779..56ac79bd 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -55,7 +55,9 @@ def __init__(self, *args, **kwargs): self.attach_mock( mock.AsyncMock(spec=AsyncAuthManager), "auth_manager" ) - self.unresolved_address = next(iter(args), "localhost") + self.address = next(iter(args), "localhost") + self.advertised_address = None + self.address_callback = None self.callbacks = [] diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index b0ddbc96..5fbb49ab 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): 
expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -69,7 +69,8 @@ def test_class_method_protocol_handlers(): ((5, 5), 1), ((5, 6), 1), ((5, 7), 1), - ((5, 8), 0), + ((5, 8), 1), + ((5, 9), 0), ((6, 0), 0), ], ) @@ -92,7 +93,7 @@ def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( handshake - == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -143,6 +144,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ((5, 5), "neo4j._async.io._bolt5.AsyncBolt5x5"), ((5, 6), "neo4j._async.io._bolt5.AsyncBolt5x6"), ((5, 7), "neo4j._async.io._bolt5.AsyncBolt5x7"), + ((5, 8), "neo4j._async.io._bolt5.AsyncBolt5x8"), ), ) @mark_async_test @@ -181,7 +183,7 @@ async def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 8), + (5, 9), (6, 0), ), ) @@ -189,7 +191,7 @@ async def test_version_negotiation( async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" ) address = ("localhost", 7687) diff --git a/tests/unit/async_/io/test_class_bolt5x8.py b/tests/unit/async_/io/test_class_bolt5x8.py new file mode 100644 index 00000000..408e641a --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x8.py @@ -0,0 +1,907 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.io._bolt5 import AsyncBolt5x8 +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_async_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), 
{"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + 
+ +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, 
test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + 
address, + socket, + AsyncPoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = await socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + await socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + 
sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_async_test +async def 
test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + await connection.send_all() + + _, fields = await socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_async_test +async def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + 
expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_async_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + 
(1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +async def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + await connection.send_all() + _tag, fields = await sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + 
connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_async_test +async def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... 
+ else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + await sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + await connection.send_all() + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_async_test +async def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + await 
sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + await connection.send_all() + with pytest.raises(Neo4jError): + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. 
{i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata + + +@pytest.mark.parametrize( + ("advertised_address", "expected_call"), + ( + (..., None), + (None, Warning), + (1.2, Warning), + ("example.com", neo4j.Address(("example.com", 7687))), + ("example.com:1234", neo4j.Address(("example.com", 1234))), + ), +) +@mark_async_test +async def test_address_callback( + advertised_address, expected_call, fake_socket_pair, caplog +): + cb_calls = [] + + async def cb(connection_): + assert connection_ is connection + assert connection.address == address + cb_calls.append(connection_.advertised_address) + + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + success_meta = {} + if advertised_address is not ...: + success_meta["advertised_address"] = advertised_address + await sockets.server.send_message(b"\x70", success_meta) + + connection = AsyncBolt5x8(address, sockets.client, 0) + connection.address_callback = cb + + connection.logon() + await connection.send_all() + + if type(expected_call) is type and issubclass(expected_call, Warning): + with caplog.at_level(logging.WARNING): + await connection.fetch_all() + warning_logs = [rec.message for rec in caplog.records] + assert len(warning_logs) == 1 + assert "NON-FATAL PROTOCOL VIOLATION" in warning_logs[0] + assert not cb_calls + return + + await connection.fetch_all() + + if expected_call is None: + assert not cb_calls + return + + assert cb_calls == [expected_call] + assert connection.address == address diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 80014266..77950813 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -25,6 +25,7 @@ WorkspaceConfig, ) from neo4j._deadline import Deadline +from 
neo4j.addressing import Address from neo4j.auth_management import AsyncAuthManagers from neo4j.exceptions import ( ClientError, @@ -37,6 +38,8 @@ class AsyncFakeBoltPool(AsyncIOPool): is_direct_pool = False + __on_open = None + def __init__(self, connection_gen, address, *, auth=None, **config): self.buffered_connection_mocks = [] config["auth"] = static_auth(None) @@ -54,6 +57,8 @@ async def opener(addr, auth, timeout): else: mock = connection_gen() mock.address = addr + if self.__on_open is not None: + self.__on_open(mock) return mock super().__init__(opener, self.pool_config, self.workspace_config) @@ -273,3 +278,49 @@ async def test_liveness_check( cx1.reset.reset_mock() await pool.release(cx1) cx1.reset.assert_not_called() + + +@pytest.fixture +async def simple_pool_factory(async_fake_connection_generator): + pools = [] + + def factory(**config): + pool_ = AsyncFakeBoltPool( + async_fake_connection_generator, + ("127.0.0.1", 7687), + **config, + ) + pools.append(pool_) + return pool_ + + yield factory + + for pool in pools: + await pool.close() + + +async def test_configures_no_address_cb_on_connection(simple_pool_factory): + pool = simple_pool_factory() + cx = await pool.acquire("r", Deadline(3), "test_db", None, None, None) + + assert cx.address_callback is None + + +async def test_does_not_move_connection_to_advertised_address_after_open( + simple_pool_factory, +): + advertised_address = Address(("example.com", 1234)) + + def on_open(connection): + assert connection.address != advertised_address # sanity check + connection.advertised_address = advertised_address + + pool = simple_pool_factory() + pool._AsyncFakeBoltPool__on_open = on_open + cx = await pool.acquire("r", Deadline(3), "test_db", None, None, None) + + # assert the connection has NOT been moved to the advertised address + assert cx.address == pool.address + assert len(pool.connections[pool.address]) == 1 + assert len(pool.connections[advertised_address]) == 0 + assert cx in pool.connections[pool.address] diff --git 
a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index c0be16ad..a1f20252 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -33,7 +33,10 @@ WorkspaceConfig, ) from neo4j._deadline import Deadline -from neo4j.addressing import ResolvedAddress +from neo4j.addressing import ( + Address, + ResolvedAddress, +) from neo4j.auth_management import AsyncAuthManagers from neo4j.exceptions import ( Neo4jError, @@ -55,7 +58,7 @@ @pytest.fixture def custom_routing_opener(async_fake_connection_generator, mocker): - def make_opener(failures=None, get_readers=None): + def make_opener(failures=None, get_readers=None, on_open=None): def routing_side_effect(*args, **kwargs): nonlocal failures res = next(failures, None) @@ -84,7 +87,7 @@ def routing_side_effect(*args, **kwargs): async def open_(addr, auth, timeout): connection = async_fake_connection_generator() - connection.unresolved_address = addr + connection.address = addr connection.timeout = timeout connection.auth = auth route_mock = mocker.AsyncMock() @@ -92,6 +95,10 @@ async def open_(addr, auth, timeout): route_mock.side_effect = routing_side_effect connection.attach_mock(route_mock, "route") opener_.connections.append(connection) + + if on_open is not None: + on_open(connection) + return connection failures = iter(failures or []) @@ -188,9 +195,9 @@ async def test_chooses_right_connection_type(opener, type_): ) await pool.release(cx1) if type_ == "r": - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS else: - assert cx1.unresolved_address == WRITER1_ADDRESS + assert cx1.address == WRITER1_ADDRESS @mark_async_test @@ -206,7 +213,7 @@ async def test_reuses_connection(opener): @mark_async_test async def test_closes_stale_connections(opener, break_on_close): async def break_connection(): - await pool.deactivate(cx1.unresolved_address) + await pool.deactivate(cx1.address) if cx_close_mock_side_effect: 
res = cx_close_mock_side_effect() @@ -218,7 +225,7 @@ async def break_connection(): pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) - assert cx1 in pool.connections[cx1.unresolved_address] + assert cx1 in pool.connections[cx1.address] # simulate connection going stale (e.g. exceeding idle timeout) and then # breaking when the pool tries to close the connection cx1.stale.return_value = True @@ -233,16 +240,16 @@ async def break_connection(): else: cx1.close.assert_called_once() assert cx2 is not cx1 - assert cx2.unresolved_address == cx1.unresolved_address - assert cx1 not in pool.connections[cx1.unresolved_address] - assert cx2 in pool.connections[cx2.unresolved_address] + assert cx2.address == cx1.address + assert cx1 not in pool.connections[cx1.address] + assert cx2 in pool.connections[cx2.address] @mark_async_test async def test_does_not_close_stale_connections_in_use(opener): pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - assert cx1 in pool.connections[cx1.unresolved_address] + assert cx1 in pool.connections[cx1.address] # simulate connection going stale (e.g. 
exceeding idle timeout) while being # in use cx1.stale.return_value = True @@ -250,9 +257,9 @@ async def test_does_not_close_stale_connections_in_use(opener): await pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 - assert cx2.unresolved_address == cx1.unresolved_address - assert cx1 in pool.connections[cx1.unresolved_address] - assert cx2 in pool.connections[cx2.unresolved_address] + assert cx2.address == cx1.address + assert cx1 in pool.connections[cx1.address] + assert cx2 in pool.connections[cx2.address] await pool.release(cx1) # now that cx1 is back in the pool and still stale, @@ -263,9 +270,9 @@ async def test_does_not_close_stale_connections_in_use(opener): await pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 - assert cx3.unresolved_address == cx1.unresolved_address - assert cx1 not in pool.connections[cx1.unresolved_address] - assert cx3 in pool.connections[cx2.unresolved_address] + assert cx3.address == cx1.address + assert cx1 not in pool.connections[cx1.address] + assert cx3 in pool.connections[cx2.address] @mark_async_test @@ -314,7 +321,7 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection( cx1 = await pool._acquire( READER1_ADDRESS, None, Deadline(30), liveness_timeout ) - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS cx1.reset.assert_not_called() @@ -330,7 +337,7 @@ async def test_acquire_performs_liveness_check_on_existing_connection( ) # make sure we assume the right state - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -367,7 +374,7 @@ def liveness_side_effect(*args, **kwargs): ) # make sure we assume the right state - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -384,11 +391,11 @@ def liveness_side_effect(*args, 
**kwargs): READER1_ADDRESS, None, Deadline(30), liveness_timeout ) assert cx1 is not cx2 - assert cx1.unresolved_address == cx2.unresolved_address + assert cx1.address == cx2.address cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx2.reset.assert_not_called() - assert cx1 not in pool.connections[cx1.unresolved_address] - assert cx2 in pool.connections[cx1.unresolved_address] + assert cx1 not in pool.connections[cx1.address] + assert cx2 in pool.connections[cx1.address] @pytest.mark.parametrize( @@ -412,8 +419,8 @@ def liveness_side_effect(*args, **kwargs): ) # make sure we assume the right state - assert cx1.unresolved_address == READER1_ADDRESS - assert cx2.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS + assert cx2.address == READER1_ADDRESS assert cx1 is not cx2 cx1.is_idle_for.assert_not_called() cx2.is_idle_for.assert_not_called() @@ -439,8 +446,8 @@ def liveness_side_effect(*args, **kwargs): cx1.reset.assert_awaited_once() cx3.is_idle_for.assert_called_once_with(liveness_timeout) cx3.reset.assert_awaited_once() - assert cx1 not in pool.connections[cx1.unresolved_address] - assert cx3 in pool.connections[cx1.unresolved_address] + assert cx1 not in pool.connections[cx1.address] + assert cx3 in pool.connections[cx1.address] @mark_async_test @@ -701,7 +708,7 @@ def get_readers(database): opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS await pool.release(cx1) cx1.close.assert_not_called() @@ -712,7 +719,7 @@ def get_readers(database): readers["db1"] = [str(READER2_ADDRESS)] cx2 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) - assert cx2.unresolved_address == READER2_ADDRESS + assert cx2.address == READER2_ADDRESS cx1.close.assert_awaited_once() assert len(pool.connections[READER1_ADDRESS]) == 0 @@ -740,14 +747,14 @@ def 
get_readers(database): ) cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) await pool.release(cx1) - assert cx1.unresolved_address in {READER1_ADDRESS, READER2_ADDRESS} + assert cx1.address in {READER1_ADDRESS, READER2_ADDRESS} reader1_connection_count = len(pool.connections[READER1_ADDRESS]) reader2_connection_count = len(pool.connections[READER2_ADDRESS]) assert reader1_connection_count + reader2_connection_count == 1 cx2 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None) await pool.release(cx2) - assert cx2.unresolved_address == READER1_ADDRESS + assert cx2.address == READER1_ADDRESS cx1.close.assert_not_called() cx2.close.assert_not_called() assert len(pool.connections[READER1_ADDRESS]) == 1 @@ -759,7 +766,7 @@ def get_readers(database): cx3 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None) await pool.release(cx3) - assert cx3.unresolved_address == READER3_ADDRESS + assert cx3.address == READER3_ADDRESS cx1.close.assert_not_called() cx2.close.assert_not_called() @@ -767,3 +774,32 @@ def get_readers(database): assert len(pool.connections[READER1_ADDRESS]) == 1 assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count assert len(pool.connections[READER3_ADDRESS]) == 1 + + +@mark_async_test +async def test_configures_address_cb_on_connection(opener): + pool = _simple_pool(opener) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + + assert cx.address_callback == pool._move_connection + + +@mark_async_test +async def test_moves_connection_to_advertised_address_after_open( + custom_routing_opener, +): + advertised_address = Address(("example.com", 1234)) + + def on_open(connection): + assert connection.address != advertised_address # sanity check + connection.advertised_address = advertised_address + + opener = custom_routing_opener(on_open=on_open) + pool = _simple_pool(opener) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + + # assert has been moved + assert 
cx.address == advertised_address + assert len(pool.connections[READER1_ADDRESS]) == 0 + assert len(pool.connections[advertised_address]) == 1 + assert cx in pool.connections[advertised_address] diff --git a/tests/unit/async_/test_conf.py b/tests/unit/async_/test_conf.py index 1902a5d6..c9d20c6e 100644 --- a/tests/unit/async_/test_conf.py +++ b/tests/unit/async_/test_conf.py @@ -36,7 +36,6 @@ AsyncClientCertificateProviders, ClientCertificate, ) -from neo4j.debug import watch from neo4j.exceptions import ConfigurationError from ..._async_compat import mark_async_test @@ -45,8 +44,6 @@ # python -m pytest tests/unit/test_conf.py -s -v -watch("neo4j") - test_pool_config = { "connection_timeout": 30.0, "keep_alive": True, diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index 1291d95e..8f824883 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -149,7 +149,7 @@ def __init__( self.run_meta = run_meta self.summary_meta = summary_meta AsyncConnectionStub.server_info.update({"server": "Neo4j/4.3.0"}) - self.unresolved_address = None + self.address = None self._new_hydration_scope_called = False async def send_all(self): diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index 93b2c243..63baf227 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -24,12 +24,9 @@ READ_ACCESS, WRITE_ACCESS, ) -from neo4j.debug import watch from neo4j.exceptions import ConfigurationError -watch("neo4j") - test_session_config = { "connection_acquisition_timeout": 60.0, "max_transaction_retry_time": 30.0, diff --git a/tests/unit/common/work/test_summary.py b/tests/unit/common/work/test_summary.py index 74f2059a..46c85c53 100644 --- a/tests/unit/common/work/test_summary.py +++ b/tests/unit/common/work/test_summary.py @@ -890,6 +890,7 @@ def test_summary_result_counters(summary_args_kwargs, counters_set) -> None: ((5, 5), "t_first"), ((5, 6), "t_first"), 
((5, 7), "t_first"), + ((5, 8), "t_first"), ), ) def test_summary_result_available_after( @@ -927,6 +928,7 @@ def test_summary_result_available_after( ((5, 5), "t_last"), ((5, 6), "t_last"), ((5, 7), "t_last"), + ((5, 8), "t_last"), ), ) def test_summary_result_consumed_after( diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index 8785badb..f66694ec 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -55,7 +55,9 @@ def __init__(self, *args, **kwargs): self.attach_mock( mock.MagicMock(spec=AuthManager), "auth_manager" ) - self.unresolved_address = next(iter(args), "localhost") + self.address = next(iter(args), "localhost") + self.advertised_address = None + self.address_callback = None self.callbacks = [] diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index f3b06303..7c2fc7f0 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -69,7 +69,8 @@ def test_class_method_protocol_handlers(): ((5, 5), 1), ((5, 6), 1), ((5, 7), 1), - ((5, 8), 0), + ((5, 8), 1), + ((5, 9), 0), ((6, 0), 0), ], ) @@ -92,7 +93,7 @@ def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( handshake - == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -143,6 +144,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ((5, 5), "neo4j._sync.io._bolt5.Bolt5x5"), ((5, 6), "neo4j._sync.io._bolt5.Bolt5x6"), ((5, 7), "neo4j._sync.io._bolt5.Bolt5x7"), + ((5, 8), "neo4j._sync.io._bolt5.Bolt5x8"), ), ) @mark_sync_test @@ 
-181,7 +183,7 @@ def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 8), + (5, 9), (6, 0), ), ) @@ -189,7 +191,7 @@ def test_version_negotiation( def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8')" ) address = ("localhost", 7687) diff --git a/tests/unit/sync/io/test_class_bolt5x8.py b/tests/unit/sync/io/test_class_bolt5x8.py new file mode 100644 index 00000000..09b83ab4 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x8.py @@ -0,0 +1,907 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j._sync.config import PoolConfig +from neo4j._sync.io._bolt5 import Bolt5x8 +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_sync_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = 
Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert 
fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=666, 
qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, + socket, + PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, 
False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + 
sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_sync_test +def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, + socket, + PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + connection.send_all() + + _, fields = socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + 
+@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_sync_test +def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = 
neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_sync_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + _tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + 
itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + 
upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_sync_test +def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + connection.send_all() + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... 
+ else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_sync_test +def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + connection.send_all() + with pytest.raises(Neo4jError): + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be 
helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. {i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata + + +@pytest.mark.parametrize( + ("advertised_address", "expected_call"), + ( + (..., None), + (None, Warning), + (1.2, Warning), + ("example.com", neo4j.Address(("example.com", 7687))), + ("example.com:1234", neo4j.Address(("example.com", 1234))), + ), +) +@mark_sync_test +def test_address_callback( + advertised_address, expected_call, fake_socket_pair, caplog +): + cb_calls = [] + + def cb(connection_): + assert connection_ is connection + assert connection.address == address + cb_calls.append(connection_.advertised_address) + + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + success_meta = {} + if advertised_address is not ...: + success_meta["advertised_address"] = advertised_address + sockets.server.send_message(b"\x70", success_meta) + + connection = Bolt5x8(address, sockets.client, 0) + connection.address_callback = cb + + connection.logon() + connection.send_all() + + if type(expected_call) is type and issubclass(expected_call, Warning): + with caplog.at_level(logging.WARNING): + connection.fetch_all() + warning_logs = [rec.message for rec in caplog.records] + assert len(warning_logs) == 1 + assert "NON-FATAL PROTOCOL VIOLATION" in warning_logs[0] + assert not cb_calls + return + + connection.fetch_all() + + if expected_call is None: + assert not cb_calls + return + + assert cb_calls == [expected_call] + assert connection.address == address diff --git 
a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index a899ae49..933efa0b 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -25,6 +25,7 @@ from neo4j._sync.config import PoolConfig from neo4j._sync.io import Bolt from neo4j._sync.io._pool import IOPool +from neo4j.addressing import Address from neo4j.auth_management import AuthManagers from neo4j.exceptions import ( ClientError, @@ -37,6 +38,8 @@ class FakeBoltPool(IOPool): is_direct_pool = False + __on_open = None + def __init__(self, connection_gen, address, *, auth=None, **config): self.buffered_connection_mocks = [] config["auth"] = static_auth(None) @@ -54,6 +57,8 @@ def opener(addr, auth, timeout): else: mock = connection_gen() mock.address = addr + if self.__on_open is not None: + self.__on_open(mock) return mock super().__init__(opener, self.pool_config, self.workspace_config) @@ -273,3 +278,49 @@ def test_liveness_check( cx1.reset.reset_mock() pool.release(cx1) cx1.reset.assert_not_called() + + +@pytest.fixture +def simple_pool_factory(fake_connection_generator): + pools = [] + + def factory(**config): + pool_ = FakeBoltPool( + fake_connection_generator, + ("127.0.0.1", 7687), + **config, + ) + pools.append(pool_) + return pool_ + + yield factory + + for pool in pools: + pool.close() + + +def test_configures_no_address_cb_on_connection(simple_pool_factory): + pool = simple_pool_factory() + cx = pool.acquire("r", Deadline(3), "test_db", None, None, None) + + assert cx.address_callback is None + + +def test_does_not_move_connection_to_advertised_address_after_open( + simple_pool_factory, +): + advertised_address = Address(("example.com", 1234)) + + def on_open(connection): + assert connection.address != advertised_address # sanity check + connection.advertised_address = advertised_address + + pool = simple_pool_factory() + pool._FakeBoltPool__on_open = on_open + cx = pool.acquire("r", Deadline(3), "test_db", None, None, None) + + # assert has 
been moved + assert cx.address == pool.address + assert len(pool.connections[pool.address]) == 1 + assert len(pool.connections[advertised_address]) == 0 + assert cx in pool.connections[pool.address] diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 89b4d16b..e7bb9475 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -33,7 +33,10 @@ Bolt, Neo4jPool, ) -from neo4j.addressing import ResolvedAddress +from neo4j.addressing import ( + Address, + ResolvedAddress, +) from neo4j.auth_management import AuthManagers from neo4j.exceptions import ( Neo4jError, @@ -55,7 +58,7 @@ @pytest.fixture def custom_routing_opener(fake_connection_generator, mocker): - def make_opener(failures=None, get_readers=None): + def make_opener(failures=None, get_readers=None, on_open=None): def routing_side_effect(*args, **kwargs): nonlocal failures res = next(failures, None) @@ -84,7 +87,7 @@ def routing_side_effect(*args, **kwargs): def open_(addr, auth, timeout): connection = fake_connection_generator() - connection.unresolved_address = addr + connection.address = addr connection.timeout = timeout connection.auth = auth route_mock = mocker.MagicMock() @@ -92,6 +95,10 @@ def open_(addr, auth, timeout): route_mock.side_effect = routing_side_effect connection.attach_mock(route_mock, "route") opener_.connections.append(connection) + + if on_open is not None: + on_open(connection) + return connection failures = iter(failures or []) @@ -188,9 +195,9 @@ def test_chooses_right_connection_type(opener, type_): ) pool.release(cx1) if type_ == "r": - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS else: - assert cx1.unresolved_address == WRITER1_ADDRESS + assert cx1.address == WRITER1_ADDRESS @mark_sync_test @@ -206,7 +213,7 @@ def test_reuses_connection(opener): @mark_sync_test def test_closes_stale_connections(opener, break_on_close): def break_connection(): - 
pool.deactivate(cx1.unresolved_address) + pool.deactivate(cx1.address) if cx_close_mock_side_effect: res = cx_close_mock_side_effect() @@ -218,7 +225,7 @@ def break_connection(): pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx1) - assert cx1 in pool.connections[cx1.unresolved_address] + assert cx1 in pool.connections[cx1.address] # simulate connection going stale (e.g. exceeding idle timeout) and then # breaking when the pool tries to close the connection cx1.stale.return_value = True @@ -233,16 +240,16 @@ def break_connection(): else: cx1.close.assert_called_once() assert cx2 is not cx1 - assert cx2.unresolved_address == cx1.unresolved_address - assert cx1 not in pool.connections[cx1.unresolved_address] - assert cx2 in pool.connections[cx2.unresolved_address] + assert cx2.address == cx1.address + assert cx1 not in pool.connections[cx1.address] + assert cx2 in pool.connections[cx2.address] @mark_sync_test def test_does_not_close_stale_connections_in_use(opener): pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - assert cx1 in pool.connections[cx1.unresolved_address] + assert cx1 in pool.connections[cx1.address] # simulate connection going stale (e.g. 
exceeding idle timeout) while being # in use cx1.stale.return_value = True @@ -250,9 +257,9 @@ def test_does_not_close_stale_connections_in_use(opener): pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 - assert cx2.unresolved_address == cx1.unresolved_address - assert cx1 in pool.connections[cx1.unresolved_address] - assert cx2 in pool.connections[cx2.unresolved_address] + assert cx2.address == cx1.address + assert cx1 in pool.connections[cx1.address] + assert cx2 in pool.connections[cx2.address] pool.release(cx1) # now that cx1 is back in the pool and still stale, @@ -263,9 +270,9 @@ def test_does_not_close_stale_connections_in_use(opener): pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 - assert cx3.unresolved_address == cx1.unresolved_address - assert cx1 not in pool.connections[cx1.unresolved_address] - assert cx3 in pool.connections[cx2.unresolved_address] + assert cx3.address == cx1.address + assert cx1 not in pool.connections[cx1.address] + assert cx3 in pool.connections[cx2.address] @mark_sync_test @@ -314,7 +321,7 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection( cx1 = pool._acquire( READER1_ADDRESS, None, Deadline(30), liveness_timeout ) - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS cx1.reset.assert_not_called() @@ -330,7 +337,7 @@ def test_acquire_performs_liveness_check_on_existing_connection( ) # make sure we assume the right state - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -367,7 +374,7 @@ def liveness_side_effect(*args, **kwargs): ) # make sure we assume the right state - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -384,11 +391,11 @@ def liveness_side_effect(*args, **kwargs): READER1_ADDRESS, None, Deadline(30), 
liveness_timeout ) assert cx1 is not cx2 - assert cx1.unresolved_address == cx2.unresolved_address + assert cx1.address == cx2.address cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx2.reset.assert_not_called() - assert cx1 not in pool.connections[cx1.unresolved_address] - assert cx2 in pool.connections[cx1.unresolved_address] + assert cx1 not in pool.connections[cx1.address] + assert cx2 in pool.connections[cx1.address] @pytest.mark.parametrize( @@ -412,8 +419,8 @@ def liveness_side_effect(*args, **kwargs): ) # make sure we assume the right state - assert cx1.unresolved_address == READER1_ADDRESS - assert cx2.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS + assert cx2.address == READER1_ADDRESS assert cx1 is not cx2 cx1.is_idle_for.assert_not_called() cx2.is_idle_for.assert_not_called() @@ -439,8 +446,8 @@ def liveness_side_effect(*args, **kwargs): cx1.reset.assert_called_once() cx3.is_idle_for.assert_called_once_with(liveness_timeout) cx3.reset.assert_called_once() - assert cx1 not in pool.connections[cx1.unresolved_address] - assert cx3 in pool.connections[cx1.unresolved_address] + assert cx1 not in pool.connections[cx1.address] + assert cx3 in pool.connections[cx1.address] @mark_sync_test @@ -701,7 +708,7 @@ def get_readers(database): opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) - assert cx1.unresolved_address == READER1_ADDRESS + assert cx1.address == READER1_ADDRESS pool.release(cx1) cx1.close.assert_not_called() @@ -712,7 +719,7 @@ def get_readers(database): readers["db1"] = [str(READER2_ADDRESS)] cx2 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) - assert cx2.unresolved_address == READER2_ADDRESS + assert cx2.address == READER2_ADDRESS cx1.close.assert_called_once() assert len(pool.connections[READER1_ADDRESS]) == 0 @@ -740,14 +747,14 @@ def get_readers(database): ) cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) 
pool.release(cx1) - assert cx1.unresolved_address in {READER1_ADDRESS, READER2_ADDRESS} + assert cx1.address in {READER1_ADDRESS, READER2_ADDRESS} reader1_connection_count = len(pool.connections[READER1_ADDRESS]) reader2_connection_count = len(pool.connections[READER2_ADDRESS]) assert reader1_connection_count + reader2_connection_count == 1 cx2 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None) pool.release(cx2) - assert cx2.unresolved_address == READER1_ADDRESS + assert cx2.address == READER1_ADDRESS cx1.close.assert_not_called() cx2.close.assert_not_called() assert len(pool.connections[READER1_ADDRESS]) == 1 @@ -759,7 +766,7 @@ def get_readers(database): cx3 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None) pool.release(cx3) - assert cx3.unresolved_address == READER3_ADDRESS + assert cx3.address == READER3_ADDRESS cx1.close.assert_not_called() cx2.close.assert_not_called() @@ -767,3 +774,32 @@ def get_readers(database): assert len(pool.connections[READER1_ADDRESS]) == 1 assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count assert len(pool.connections[READER3_ADDRESS]) == 1 + + +@mark_sync_test +def test_configures_address_cb_on_connection(opener): + pool = _simple_pool(opener) + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + + assert cx.address_callback == pool._move_connection + + +@mark_sync_test +def test_moves_connection_to_advertised_address_after_open( + custom_routing_opener, +): + advertised_address = Address(("example.com", 1234)) + + def on_open(connection): + assert connection.address != advertised_address # sanity check + connection.advertised_address = advertised_address + + opener = custom_routing_opener(on_open=on_open) + pool = _simple_pool(opener) + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + + # assert has been moved + assert cx.address == advertised_address + assert len(pool.connections[READER1_ADDRESS]) == 0 + assert len(pool.connections[advertised_address]) == 1 + 
assert cx in pool.connections[advertised_address] diff --git a/tests/unit/sync/test_conf.py b/tests/unit/sync/test_conf.py index 65481811..bc589a38 100644 --- a/tests/unit/sync/test_conf.py +++ b/tests/unit/sync/test_conf.py @@ -36,7 +36,6 @@ ClientCertificate, ClientCertificateProviders, ) -from neo4j.debug import watch from neo4j.exceptions import ConfigurationError from ..._async_compat import mark_sync_test @@ -45,8 +44,6 @@ # python -m pytest tests/unit/test_conf.py -s -v -watch("neo4j") - test_pool_config = { "connection_timeout": 30.0, "keep_alive": True, diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py index 623d5014..c9a1bb48 100644 --- a/tests/unit/sync/work/test_result.py +++ b/tests/unit/sync/work/test_result.py @@ -149,7 +149,7 @@ def __init__( self.run_meta = run_meta self.summary_meta = summary_meta ConnectionStub.server_info.update({"server": "Neo4j/4.3.0"}) - self.unresolved_address = None + self.address = None self._new_hydration_scope_called = False def send_all(self):