From bb208e7ddcd72d0e704a74fe1f243c4bf12b6258 Mon Sep 17 00:00:00 2001 From: mhh Date: Thu, 5 Oct 2023 19:13:46 +0200 Subject: [PATCH 1/6] Refactor: Duplication in parameters across get_messages and get_posts methods; Client classes are bloated and have unused parameters Solution: add MessageFilter & PostFilter; remove limit parameter; --- src/aleph/sdk/base.py | 164 +++--------------- src/aleph/sdk/client.py | 255 ++++------------------------ src/aleph/sdk/exceptions.py | 4 +- src/aleph/sdk/models.py | 51 ------ src/aleph/sdk/models/__init__.py | 0 src/aleph/sdk/models/common.py | 29 ++++ src/aleph/sdk/models/message.py | 102 +++++++++++ src/aleph/sdk/models/post.py | 122 +++++++++++++ tests/integration/itest_forget.py | 14 +- tests/unit/conftest.py | 74 +++++++- tests/unit/test_asynchronous_get.py | 15 +- tests/unit/test_chain_ethereum.py | 4 +- tests/unit/test_chain_solana.py | 4 +- tests/unit/test_synchronous_get.py | 6 +- 14 files changed, 412 insertions(+), 432 deletions(-) delete mode 100644 src/aleph/sdk/models.py create mode 100644 src/aleph/sdk/models/__init__.py create mode 100644 src/aleph/sdk/models/common.py create mode 100644 src/aleph/sdk/models/message.py create mode 100644 src/aleph/sdk/models/post.py diff --git a/src/aleph/sdk/base.py b/src/aleph/sdk/base.py index a5b2c266..9d80bc64 100644 --- a/src/aleph/sdk/base.py +++ b/src/aleph/sdk/base.py @@ -2,7 +2,6 @@ import logging from abc import ABC, abstractmethod -from datetime import datetime from pathlib import Path from typing import ( Any, @@ -26,42 +25,33 @@ from aleph_message.models.execution.program import Encoding from aleph_message.status import MessageStatus -from aleph.sdk.models import PostsResponse -from aleph.sdk.types import GenericMessage, StorageEnum +from .models.message import MessageFilter +from .models.post import PostFilter, PostsResponse +from .types import GenericMessage, StorageEnum DEFAULT_PAGE_SIZE = 200 class BaseAlephClient(ABC): @abstractmethod - async def fetch_aggregate( - self, - address: str, - key: str, - limit: int = 100, - ) -> Dict[str, Dict]: + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: """ Fetch a value from the aggregate store by owner address and item key. :param address: Address of the owner of the aggregate :param key: Key of the aggregate - :param limit: Maximum number of items to fetch (Default: 100) """ pass @abstractmethod async def fetch_aggregates( - self, - address: str, - keys: Optional[Iterable[str]] = None, - limit: int = 100, + self, address: str, keys: Optional[Iterable[str]] = None ) -> Dict[str, Dict]: """ Fetch key-value pairs from the aggregate store by owner address. :param address: Address of the owner of the aggregate :param keys: Keys of the aggregates to fetch (Default: all items) - :param limit: Maximum number of items to fetch (Default: 100) """ pass @@ -70,15 +60,7 @@ async def get_posts( self, pagination: int = DEFAULT_PAGE_SIZE, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: @@ -87,15 +69,7 @@ async def get_posts( :param pagination: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) - :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param post_filter: Filter to apply to the posts (Default: None) :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ @@ -103,44 +77,20 @@ async def get_posts( async def get_posts_iterator( self, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ) -> AsyncIterable[PostMessage]: """ Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates but will always return all posts. - :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param post_filter: Filter to apply to the posts (Default: None) """ page = 1 resp = None while resp is None or len(resp.posts) > 0: resp = await self.get_posts( page=page, - types=types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + post_filter=post_filter, ) page += 1 for post in resp.posts: @@ -165,18 +115,7 @@ async def get_messages( self, pagination: int = DEFAULT_PAGE_SIZE, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: @@ -185,18 +124,7 @@ async def get_messages( :param pagination: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) - :param message_type: [DEPRECATED] Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param message_types: Filter by message types, can be any combination of "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by aggregate key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param message_filter: Filter to apply to the messages :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ @@ -204,50 +132,20 @@ async def get_messages( async def get_messages_iterator( self, - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates but will always return all messages. - :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by content key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param message_filter: Filter to apply to the messages """ page = 1 resp = None while resp is None or len(resp.messages) > 0: resp = await self.get_messages( page=page, - message_type=message_type, - content_types=content_types, - content_keys=content_keys, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + message_filter=message_filter, ) page += 1 for message in resp.messages: @@ -272,34 +170,12 @@ async def get_message( @abstractmethod def watch_messages( self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Iterate over current and future matching messages asynchronously. - :param message_type: [DEPRECATED] Type of message to watch - :param message_types: Types of messages to watch - :param content_types: Content types to watch - :param content_keys: Filter by aggregate key - :param refs: References to watch - :param addresses: Addresses to watch - :param tags: Tags to watch - :param hashes: Hashes to watch - :param channels: Channels to watch - :param chains: Chains to watch - :param start_date: Start date from when to watch - :param end_date: End date until when to watch + :param message_filter: Filter to apply to the messages """ pass @@ -318,7 +194,7 @@ async def create_post( sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: """ - Create a POST message on the Aleph network. It is associated with a channel and owned by an account. + Create a POST message on the aleph.im network. It is associated with a channel and owned by an account. :param post_content: The content of the message :param post_type: An arbitrary content type that helps to describe the post_content @@ -368,7 +244,7 @@ async def create_store( sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: """ - Create a STORE message to store a file on the Aleph network. + Create a STORE message to store a file on the aleph.im network. Can be passed either a file path, an IPFS hash or the file's content as raw bytes. @@ -422,7 +298,7 @@ async def create_program( :param persistent: Whether the program should be persistent or not (Default: False) :param encoding: Encoding to use (Default: Encoding.zip) :param volumes: Volumes to mount - :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver + :param subscriptions: Patterns of aleph.im messages to forward to the program's event receiver :param metadata: Metadata to attach to the message """ pass diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index f79f0ceb..837811b7 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -5,8 +5,6 @@ import queue import threading import time -import warnings -from datetime import datetime from io import BytesIO from pathlib import Path from typing import ( @@ -61,7 +59,8 @@ MessageNotFoundError, MultipleMessagesError, ) -from .models import MessagesResponse, Post, PostsResponse +from .models.message import MessageFilter, MessagesResponse +from .models.post import Post, PostFilter, PostsResponse from .utils import check_unix_socket_valid, get_message_type_value logger = logging.getLogger(__name__) @@ -141,18 +140,7 @@ def get_messages( self, pagination: int = 200, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[List[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: bool = True, invalid_messages_log_level: int = logging.NOTSET, ) -> MessagesResponse: @@ -160,18 +148,7 @@ def get_messages( self.async_session.get_messages, pagination=pagination, page=page, - message_type=message_type, - message_types=message_types, - content_types=content_types, - content_keys=content_keys, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + message_filter=message_filter, ignore_invalid_messages=ignore_invalid_messages, invalid_messages_log_level=invalid_messages_log_level, ) @@ -210,29 +187,13 @@ def get_posts( self, pagination: int = 200, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ) -> PostsResponse: return self._wrap( self.async_session.get_posts, pagination=pagination, page=page, - types=types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + post_filter=post_filter, ) def download_file(self, file_hash: str) -> bytes: @@ -246,7 +207,7 @@ def download_file_ipfs(self, file_hash: str) -> bytes: def download_file_to_buffer( self, file_hash: str, output_buffer: Writable[bytes] - ) -> bytes: + ) -> None: return self._wrap( self.async_session.download_file_to_buffer, file_hash=file_hash, @@ -255,7 +216,7 @@ def download_file_to_buffer( def download_file_ipfs_to_buffer( self, file_hash: str, output_buffer: Writable[bytes] - ) -> bytes: + ) -> None: return self._wrap( self.async_session.download_file_ipfs_to_buffer, file_hash=file_hash, @@ -264,16 +225,7 @@ def download_file_ipfs_to_buffer( def watch_messages( self, - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> Iterable[AlephMessage]: """ Iterate over current and future matching messages synchronously. @@ -286,18 +238,7 @@ def watch_messages( args=( output_queue, self.async_session.api_server, - ( - message_type, - content_types, - refs, - addresses, - tags, - hashes, - channels, - chains, - start_date, - end_date, - ), + message_filter, {}, ), ) @@ -528,15 +469,8 @@ async def __aenter__(self) -> "AlephClient": async def __aexit__(self, exc_type, exc_val, exc_tb): await self.http_session.close() - async def fetch_aggregate( - self, - address: str, - key: str, - limit: int = 100, - ) -> Dict[str, Dict]: + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: params: Dict[str, Any] = {"keys": key} - if limit: - params["limit"] = limit async with self.http_session.get( f"/api/v0/aggregates/{address}.json", params=params @@ -546,17 +480,12 @@ async def fetch_aggregate( return data.get(key) async def fetch_aggregates( - self, - address: str, - keys: Optional[Iterable[str]] = None, - limit: int = 100, + self, address: str, keys: Optional[Iterable[str]] = None ) -> Dict[str, Dict]: keys_str = ",".join(keys) if keys else "" params: Dict[str, Any] = {} if keys_str: params["keys"] = keys_str - if limit: - params["limit"] = limit async with self.http_session.get( f"/api/v0/aggregates/{address}.json", @@ -570,15 +499,7 @@ async def get_posts( self, pagination: int = 200, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: @@ -591,31 +512,11 @@ async def get_posts( else invalid_messages_log_level ) - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if types is not None: - params["types"] = ",".join(types) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date + if not post_filter: + post_filter = PostFilter() + params = post_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(pagination) async with self.http_session.get("/api/v0/posts.json", params=params) as resp: resp.raise_for_status() @@ -722,18 +623,7 @@ async def get_messages( self, pagination: int = 200, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: @@ -746,43 +636,11 @@ async def get_messages( else invalid_messages_log_level ) - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if message_type is not None: - warnings.warn( - "The message_type parameter is deprecated, please use message_types instead.", - DeprecationWarning, - ) - params["msgType"] = message_type.value - if message_types is not None: - params["msgTypes"] = ",".join([t.value for t in message_types]) - print(params["msgTypes"]) - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - + if not message_filter: + message_filter = MessageFilter() + params = message_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(pagination) async with self.http_session.get( "/api/v0/messages.json", params=params ) as resp: @@ -825,8 +683,10 @@ async def get_message( channel: Optional[str] = None, ) -> GenericMessage: messages_response = await self.get_messages( - hashes=[item_hash], - channels=[channel] if channel else None, + message_filter=MessageFilter( + hashes=[item_hash], + channels=[channel] if channel else None, + ) ) if len(messages_response.messages) < 1: raise MessageNotFoundError(f"No such hash {item_hash}") @@ -846,54 +706,11 @@ async def get_message( async def watch_messages( self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: - params: Dict[str, Any] = dict() - - if message_type is not None: - warnings.warn( - "The message_type parameter is deprecated, please use message_types instead.", - DeprecationWarning, - ) - params["msgType"] = message_type.value - if message_types is not None: - params["msgTypes"] = ",".join([t.value for t in message_types]) - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date + if not message_filter: + message_filter = MessageFilter() + params = message_filter.as_http_params() async with self.http_session.ws_connect( "/api/ws0/messages", params=params @@ -1059,7 +876,7 @@ async def _handle_broadcast_deprecated_response( async def _broadcast_deprecated(self, message_dict: Mapping[str, Any]) -> None: """ - Broadcast a message on the Aleph network using the deprecated + Broadcast a message on the aleph.im network using the deprecated /ipfs/pubsub/pub/ endpoint. """ @@ -1097,7 +914,7 @@ async def _broadcast( sync: bool, ) -> MessageStatus: """ - Broadcast a message on the Aleph network. + Broadcast a message on the aleph.im network. Uses the POST /messages/ endpoint or the deprecated /ipfs/pubsub/pub/ endpoint if the first method is not available. @@ -1273,7 +1090,7 @@ async def create_program( # Register the different ways to trigger a VM if subscriptions: - # Trigger on HTTP calls and on Aleph message subscriptions. + # Trigger on HTTP calls and on aleph.im message subscriptions. triggers = { "http": True, "persistent": persistent, @@ -1309,7 +1126,7 @@ async def create_program( "runtime": { "ref": runtime, "use_latest": True, - "comment": "Official Aleph runtime" + "comment": "Official aleph.im runtime" if runtime == settings.DEFAULT_RUNTIME_ID else "", }, diff --git a/src/aleph/sdk/exceptions.py b/src/aleph/sdk/exceptions.py index 51762925..5f09e1bc 100644 --- a/src/aleph/sdk/exceptions.py +++ b/src/aleph/sdk/exceptions.py @@ -21,7 +21,7 @@ class MultipleMessagesError(QueryError): class BroadcastError(Exception): """ - Data could not be broadcast to the Aleph network. + Data could not be broadcast to the aleph.im network. """ pass @@ -29,7 +29,7 @@ class BroadcastError(Exception): class InvalidMessageError(BroadcastError): """ - The message could not be broadcast because it does not follow the Aleph + The message could not be broadcast because it does not follow the aleph.im message specification. """ diff --git a/src/aleph/sdk/models.py b/src/aleph/sdk/models.py deleted file mode 100644 index f5b1072b..00000000 --- a/src/aleph/sdk/models.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Any, Dict, List, Optional, Union - -from aleph_message.models import AlephMessage, BaseMessage, ChainRef, ItemHash -from pydantic import BaseModel, Field - - -class PaginationResponse(BaseModel): - pagination_page: int - pagination_total: int - pagination_per_page: int - pagination_item: str - - -class MessagesResponse(PaginationResponse): - """Response from an Aleph node API on the path /api/v0/messages.json""" - - messages: List[AlephMessage] - pagination_item = "messages" - - -class Post(BaseMessage): - """ - A post is a type of message that can be updated. Over the get_posts API - we get the latest version of a post. - """ - - hash: ItemHash = Field(description="Hash of the content (sha256 by default)") - original_item_hash: ItemHash = Field( - description="Hash of the original content (sha256 by default)" - ) - original_signature: Optional[str] = Field( - description="Cryptographic signature of the original message by the sender" - ) - original_type: str = Field( - description="The original, user-generated 'content-type' of the POST message" - ) - content: Dict[str, Any] = Field( - description="The content.content of the POST message" - ) - type: str = Field(description="The content.type of the POST message") - address: str = Field(description="The address of the sender of the POST message") - ref: Optional[Union[str, ChainRef]] = Field( - description="Other message referenced by this one" - ) - - -class PostsResponse(PaginationResponse): - """Response from an Aleph node API on the path /api/v0/posts.json""" - - posts: List[Post] - pagination_item = "posts" diff --git a/src/aleph/sdk/models/__init__.py b/src/aleph/sdk/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aleph/sdk/models/common.py b/src/aleph/sdk/models/common.py new file mode 100644 index 00000000..c7e0dc30 --- /dev/null +++ b/src/aleph/sdk/models/common.py @@ -0,0 +1,29 @@ +from datetime import datetime +from typing import Iterable, Optional, Union + +from pydantic import BaseModel + + +class PaginationResponse(BaseModel): + pagination_page: int + pagination_total: int + pagination_per_page: int + pagination_item: str + + +def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]: + if values: + return ",".join(values) + else: + return None + + +def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]: + if date is None: + return None + elif isinstance(date, float): + return date + elif hasattr(date, "timestamp"): + return date.timestamp() + else: + raise TypeError(f"Invalid type: `{type(date)}`") diff --git a/src/aleph/sdk/models/message.py b/src/aleph/sdk/models/message.py new file mode 100644 index 00000000..4ba6a1b2 --- /dev/null +++ b/src/aleph/sdk/models/message.py @@ -0,0 +1,102 @@ +from datetime import datetime +from typing import Dict, Iterable, List, Optional, Union + +from aleph_message.models import AlephMessage, MessageType + +from .common import PaginationResponse, _date_field_to_float, serialize_list + + +class MessagesResponse(PaginationResponse): + """Response from an aleph.im node API on the path /api/v0/messages.json""" + + messages: List[AlephMessage] + pagination_item = "messages" + + +class MessageFilter: + """ + A collection of filters that can be applied on message queries. + :param message_types: Filter by message type + :param content_types: Filter by content type + :param content_keys: Filter by content key + :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) + :param addresses: Addresses of the posts to fetch (Default: all addresses) + :param tags: Tags of the posts to fetch (Default: all tags) + :param hashes: Specific item_hashes to fetch + :param channels: Channels of the posts to fetch (Default: all channels) + :param chains: Filter by sender address chain + :param start_date: Earliest date to fetch messages from + :param end_date: Latest date to fetch messages from + """ + + message_types: Optional[Iterable[MessageType]] + content_types: Optional[Iterable[str]] + content_keys: Optional[Iterable[str]] + refs: Optional[Iterable[str]] + addresses: Optional[Iterable[str]] + tags: Optional[Iterable[str]] + hashes: Optional[Iterable[str]] + channels: Optional[Iterable[str]] + chains: Optional[Iterable[str]] + start_date: Optional[Union[datetime, float]] + end_date: Optional[Union[datetime, float]] + + def __init__( + self, + message_types: Optional[Iterable[MessageType]] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + self.message_types = message_types + self.content_types = content_types + self.content_keys = content_keys + self.refs = refs + self.addresses = addresses + self.tags = tags + self.hashes = hashes + self.channels = channels + self.chains = chains + self.start_date = start_date + self.end_date = end_date + + def as_http_params(self) -> Dict[str, str]: + """Convert the filters into a dict that can be used by an `aiohttp` client + as `params` to build the HTTP query string. + """ + + partial_result = { + "msgType": serialize_list( + [type.value for type in self.message_types] + if self.message_types + else None + ), + "contentTypes": serialize_list(self.content_types), + "contentKeys": serialize_list(self.content_keys), + "refs": serialize_list(self.refs), + "addresses": serialize_list(self.addresses), + "tags": serialize_list(self.tags), + "hashes": serialize_list(self.hashes), + "channels": serialize_list(self.channels), + "chains": serialize_list(self.chains), + "startDate": _date_field_to_float(self.start_date), + "endDate": _date_field_to_float(self.end_date), + } + + # Ensure all values are strings. + result: Dict[str, str] = {} + + # Drop empty values + for key, value in partial_result.items(): + if value: + assert isinstance(value, str), f"Value must be a string: `{value}`" + result[key] = value + + return result diff --git a/src/aleph/sdk/models/post.py b/src/aleph/sdk/models/post.py new file mode 100644 index 00000000..09a301c2 --- /dev/null +++ b/src/aleph/sdk/models/post.py @@ -0,0 +1,122 @@ +from datetime import datetime +from typing import Any, Dict, Iterable, List, Optional, Union + +from aleph_message.models import Chain, ItemHash, ItemType, MessageConfirmation +from pydantic import BaseModel, Field + +from .common import PaginationResponse, _date_field_to_float, serialize_list + + +class Post(BaseModel): + """ + A post is a type of message that can be updated. Over the get_posts API + we get the latest version of a post. + """ + + chain: Chain = Field(description="Blockchain this post is associated with") + item_hash: ItemHash = Field(description="Unique hash for this post") + sender: str = Field(description="Address of the sender") + type: str = Field(description="Type of the POST message") + channel: Optional[str] = Field(description="Channel this post is associated with") + confirmed: bool = Field(description="Whether the post is confirmed or not") + content: Dict[str, Any] = Field(description="The content of the POST message") + item_content: Optional[str] = Field( + description="The POSTs content field as serialized JSON, if of type inline" + ) + item_type: ItemType = Field( + description="Type of the item content, usually 'inline' or 'storage' for POSTs" + ) + signature: Optional[str] = Field( + description="Cryptographic signature of the message by the sender" + ) + size: int = Field(description="Size of the post") + time: float = Field(description="Timestamp of the post") + confirmations: List[MessageConfirmation] = Field( + description="Number of confirmations" + ) + original_item_hash: ItemHash = Field(description="Hash of the original content") + original_signature: Optional[str] = Field( + description="Cryptographic signature of the original message" + ) + original_type: str = Field(description="The original type of the message") + hash: ItemHash = Field(description="Hash of the original item") + ref: Optional[Union[str, Any]] = Field( + description="Other message referenced by this one" + ) + + class Config: + allow_extra = False + + +class PostsResponse(PaginationResponse): + """Response from an aleph.im node API on the path /api/v0/posts.json""" + + posts: List[Post] + pagination_item = "posts" + + +class PostFilter: + """ + A collection of filters that can be applied on post queries. + + """ + + types: Optional[Iterable[str]] + refs: Optional[Iterable[str]] + addresses: Optional[Iterable[str]] + tags: Optional[Iterable[str]] + hashes: Optional[Iterable[str]] + channels: Optional[Iterable[str]] + chains: Optional[Iterable[str]] + start_date: Optional[Union[datetime, float]] + end_date: Optional[Union[datetime, float]] + + def __init__( + self, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + self.types = types + self.refs = refs + self.addresses = addresses + self.tags = tags + self.hashes = hashes + self.channels = channels + self.chains = chains + self.start_date = start_date + self.end_date = end_date + + def as_http_params(self) -> Dict[str, str]: + """Convert the filters into a dict that can be used by an `aiohttp` client + as `params` to build the HTTP query string. + """ + + partial_result = { + "types": serialize_list(self.types), + "refs": serialize_list(self.refs), + "addresses": serialize_list(self.addresses), + "tags": serialize_list(self.tags), + "hashes": serialize_list(self.hashes), + "channels": serialize_list(self.channels), + "chains": serialize_list(self.chains), + "startDate": _date_field_to_float(self.start_date), + "endDate": _date_field_to_float(self.end_date), + } + + # Ensure all values are strings. + result: Dict[str, str] = {} + + # Drop empty values + for key, value in partial_result.items(): + if value: + assert isinstance(value, str), f"Value must be a string: `{value}`" + result[key] = value + + return result diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py index 29b6c6d9..6b44f76f 100644 --- a/tests/integration/itest_forget.py +++ b/tests/integration/itest_forget.py @@ -1,8 +1,10 @@ from typing import Callable, Dict import pytest +from aleph_message.models import PostMessage from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.models.message import MessageFilter from aleph.sdk.types import Account from .config import REFERENCE_NODE, TARGET_NODE, TEST_CHANNEL @@ -106,11 +108,9 @@ async def test_forget_a_forget_message(fixture_account): async with AuthenticatedAlephClient( account=fixture_account, api_server=TARGET_NODE ) as session: - get_post_response = await session.get_posts(hashes=[post_hash]) - assert len(get_post_response.posts) == 1 - post = get_post_response.posts[0] + get_post_message: PostMessage = await session.get_message(post_hash) - forget_message_hash = post.forgotten_by[0] + forget_message_hash = get_post_message.forgotten_by[0] forget_message, forget_status = await session.forget( hashes=[forget_message_hash], reason="I want to remember this post. Maybe I can forget I forgot it?", @@ -120,8 +120,10 @@ async def test_forget_a_forget_message(fixture_account): print(forget_message) get_forget_message_response = await session.get_messages( - hashes=[forget_message_hash], - channels=[TEST_CHANNEL], + message_filter=MessageFilter( + hashes=[forget_message_hash], + channels=[TEST_CHANNEL], + ) ) assert len(get_forget_message_response.messages) == 1 forget_message = get_forget_message_response.messages[0] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4f62c0c5..a51b1483 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,8 +1,10 @@ import json from pathlib import Path from tempfile import NamedTemporaryFile +from typing import Any, Callable, Dict, List import pytest as pytest +from aleph_message.models import AggregateMessage, AlephMessage, PostMessage import aleph.sdk.chains.ethereum as ethereum import aleph.sdk.chains.sol as solana @@ -46,7 +48,77 @@ def substrate_account() -> substrate.DOTAccount: @pytest.fixture -def messages(): +def json_messages(): messages_path = Path(__file__).parent / "messages.json" with open(messages_path) as f: return json.load(f) + + +@pytest.fixture +def aleph_messages() -> List[AlephMessage]: + return [ + AggregateMessage.parse_obj( + { + "item_hash": "5b26d949fe05e38f535ef990a89da0473f9d700077cced228f2d36e73fca1fd6", + "type": "AGGREGATE", + "chain": "ETH", + "sender": "0x51A58800b26AA1451aaA803d1746687cB88E0501", + "signature": "0xca5825b6b93390482b436cb7f28b4628f8c9f56dc6af08260c869b79dd6017c94248839bd9fd0ffa1230dc3b1f4f7572a8d1f6fed6c6e1fb4d70ccda0ab5d4f21b", + "item_type": "inline", + "item_content": '{"address":"0x51A58800b26AA1451aaA803d1746687cB88E0501","key":"0xce844d79e5c0c325490c530aa41e8f602f0b5999binance","content":{"1692026263168":{"version":"x25519-xsalsa20-poly1305","nonce":"RT4Lbqs7Xzk+op2XC+VpXgwOgg21BotN","ephemPublicKey":"CVW8ECE3m8BepytHMTLan6/jgIfCxGdnKmX47YirF08=","ciphertext":"VuGJ9vMkJSbaYZCCv6Zemx4ixeb+9IW8H1vFB9vLtz1a8d87R4BfYUisLoCQxRkeUXqfW0/KIGQ5idVjr8Yj7QnKglW5AJ8UX7wEWMhiRFLatpWP8P9FI2n8Z7Rblu7Oz/OeKnuljKL3KsalcUQSsFa/1qACsIoycPZ6Wq6t1mXxVxxJWzClLyKRihv1pokZGT9UWxh7+tpoMGlRdYainyAt0/RygFw+r8iCMOilHnyv4ndLkKQJXyttb0tdNr/gr57+9761+trioGSysLQKZQWW6Ih6aE8V9t3BenfzYwiCnfFw3YAAKBPMdm9QdIETyrOi7YhD/w==","sha256":"bbeb499f681aed2bc18b6f3b6a30d25254bd30fbfde43444e9085f3bcd075c3c"}},"time":1692026263.662}', + "content": { + "key": "0xce844d79e5c0c325490c530aa41e8f602f0b5999binance", + "time": 1692026263.662, + "address": "0x51A58800b26AA1451aaA803d1746687cB88E0501", + "content": { + "hello": "world", + }, + }, + "time": 1692026263.662, + "channel": "UNSLASHED", + "size": 734, + "confirmations": [], + "confirmed": False, + } + ), + PostMessage.parse_obj( + { + "item_hash": "70f3798fdc68ce0ee03715a5547ee24e2c3e259bf02e3f5d1e4bf5a6f6a5e99f", + "type": "POST", + "chain": "SOL", + "sender": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4", + "signature": "0x91616ee45cfba55742954ff87ebf86db4988bcc5e3334b49a4caa6436e28e28d4ab38667cbd4bfb8903abf8d71f70d9ceb2c0a8d0a15c04fc1af5657f0050c101b", + "item_type": "storage", + "item_content": None, + "content": { + "time": 1692026021.1257718, + "type": "aleph-network-metrics", + "address": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4", + "ref": "0123456789abcdef", + "content": { + "tags": ["mainnet"], + "hello": "world", + "version": "1.0", + }, + }, + "time": 1692026021.132849, + "channel": "aleph-scoring", + "size": 122537, + "confirmations": [], + "confirmed": False, + } + ), + ] + + +@pytest.fixture +def raw_messages_response(aleph_messages) -> Callable[[int], Dict[str, Any]]: + return lambda page: { + "messages": [message.dict() for message in aleph_messages] + if int(page) == 1 + else [], + "pagination_item": "messages", + "pagination_page": int(page), + "pagination_per_page": max(len(aleph_messages), 20), + "pagination_total": len(aleph_messages) if page == 1 else 0, + } diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index db788e0b..72c47706 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -3,11 +3,12 @@ from unittest.mock import AsyncMock import pytest -from aleph_message.models import MessagesResponse +from aleph_message.models import MessagesResponse, MessageType from aleph.sdk.client import AlephClient from aleph.sdk.conf import settings -from aleph.sdk.models import PostsResponse +from aleph.sdk.models.message import MessageFilter +from aleph.sdk.models.post import PostFilter, PostsResponse def make_mock_session(get_return_value: Dict[str, Any]) -> AlephClient: @@ -67,7 +68,12 @@ async def test_fetch_aggregates(): @pytest.mark.asyncio async def test_get_posts(): async with AlephClient(api_server=settings.API_HOST) as session: - response: PostsResponse = await session.get_posts() + response: PostsResponse = await session.get_posts( + pagination=2, + post_filter=PostFilter( + channels=["TEST"], + ), + ) posts = response.posts assert len(posts) > 1 @@ -78,6 +84,9 @@ async def test_get_messages(): async with AlephClient(api_server=settings.API_HOST) as session: response: MessagesResponse = await session.get_messages( pagination=2, + message_filter=MessageFilter( + message_types=[MessageType.post], + ), ) messages = response.messages diff --git a/tests/unit/test_chain_ethereum.py b/tests/unit/test_chain_ethereum.py index dea58c69..9a602b3d 100644 --- a/tests/unit/test_chain_ethereum.py +++ b/tests/unit/test_chain_ethereum.py @@ -82,8 +82,8 @@ async def test_verify_signature(ethereum_account): @pytest.mark.asyncio -async def test_verify_signature_with_processed_message(ethereum_account, messages): - message = messages[1] +async def test_verify_signature_with_processed_message(ethereum_account, json_messages): + message = json_messages[1] verify_signature( message["signature"], message["sender"], get_verification_buffer(message) ) diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py index 5088158a..07b67602 100644 --- a/tests/unit/test_chain_solana.py +++ b/tests/unit/test_chain_solana.py @@ -103,8 +103,8 @@ async def test_verify_signature(solana_account): @pytest.mark.asyncio -async def test_verify_signature_with_processed_message(solana_account, messages): - message = messages[0] +async def test_verify_signature_with_processed_message(solana_account, json_messages): + message = json_messages[0] signature = json.loads(message["signature"])["signature"] verify_signature(signature, message["sender"], get_verification_buffer(message)) diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py index eee26dcf..0788a1ab 100644 --- a/tests/unit/test_synchronous_get.py +++ b/tests/unit/test_synchronous_get.py @@ -2,14 +2,16 @@ from aleph.sdk.client import AlephClient from aleph.sdk.conf import settings +from aleph.sdk.models.message import MessageFilter def test_get_post_messages(): with AlephClient(api_server=settings.API_HOST) as session: - # TODO: Remove deprecated message_type parameter after message_types changes on pyaleph are deployed response: MessagesResponse = session.get_messages( pagination=2, - message_type=MessageType.post, + message_filter=MessageFilter( + message_types=[MessageType.post], + ), ) messages = response.messages From 347b0e21a67b5593f35aed93f9bcae1acebcbb98 Mon Sep 17 00:00:00 2001 From: mhh Date: Fri, 6 Oct 2023 12:43:10 +0200 Subject: [PATCH 2/6] Fix: Integration tests fail after changes to client and node behavior; aggregate fetching could result in unhandled parsing errors Solution: Adjust integration tests for current functionality; add raise_for_status to GET aggregates requests --- src/aleph/sdk/client.py | 2 + tests/integration/config.py | 2 +- tests/integration/itest_forget.py | 110 ++++++++++++++-------------- tests/integration/itest_posts.py | 21 +++--- tests/integration/toolkit.py | 12 ++- tests/unit/test_asynchronous_get.py | 3 + 6 files changed, 83 insertions(+), 67 deletions(-) diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index 837811b7..1c1b0f5e 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -475,6 +475,7 @@ async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: async with self.http_session.get( f"/api/v0/aggregates/{address}.json", params=params ) as resp: + resp.raise_for_status() result = await resp.json() data = result.get("data", dict()) return data.get(key) @@ -491,6 +492,7 @@ async def fetch_aggregates( f"/api/v0/aggregates/{address}.json", params=params, ) as resp: + resp.raise_for_status() result = await resp.json() data = result.get("data", dict()) return data diff --git a/tests/integration/config.py b/tests/integration/config.py index 4ec95a27..3e613c18 100644 --- a/tests/integration/config.py +++ b/tests/integration/config.py @@ -1,3 +1,3 @@ -TARGET_NODE = "http://163.172.70.92:4024" +TARGET_NODE = "https://api1.aleph.im" REFERENCE_NODE = "https://api2.aleph.im" TEST_CHANNEL = "INTEGRATION_TESTS" diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py index 6b44f76f..a457bdda 100644 --- a/tests/integration/itest_forget.py +++ b/tests/integration/itest_forget.py @@ -1,34 +1,20 @@ -from typing import Callable, Dict +import asyncio +from typing import Tuple import pytest -from aleph_message.models import PostMessage +from aleph_message.models import ItemHash from aleph.sdk.client import AuthenticatedAlephClient from aleph.sdk.models.message import MessageFilter from aleph.sdk.types import Account from .config import REFERENCE_NODE, TARGET_NODE, TEST_CHANNEL -from .toolkit import try_until +from .toolkit import has_messages, has_no_messages, try_until async def create_and_forget_post( account: Account, emitter_node: str, receiver_node: str, channel=TEST_CHANNEL -) -> str: - async def wait_matching_posts( - item_hash: str, - condition: Callable[[Dict], bool], - timeout: int = 5, - ): - async with AuthenticatedAlephClient( - account=account, api_server=receiver_node - ) as rx_session: - return await try_until( - rx_session.get_posts, - condition, - timeout=timeout, - hashes=[item_hash], - ) - +) -> Tuple[ItemHash, ItemHash]: async with AuthenticatedAlephClient( account=account, api_server=emitter_node ) as tx_session: @@ -38,13 +24,17 @@ async def wait_matching_posts( channel="INTEGRATION_TESTS", ) - # Wait for the message to appear on the receiver. We don't check the values, - # they're checked in other integration tests. - get_post_response = await wait_matching_posts( - post_message.item_hash, - lambda response: len(response["posts"]) > 0, - ) - print(get_post_response) + async with AuthenticatedAlephClient( + account=account, api_server=receiver_node + ) as rx_session: + await try_until( + rx_session.get_messages, + has_messages, + timeout=5, + message_filter=MessageFilter( + hashes=[post_message.item_hash], + ), + ) post_hash = post_message.item_hash reason = "This well thought-out content offends me!" @@ -56,27 +46,34 @@ async def wait_matching_posts( reason=reason, channel=channel, ) - assert forget_message.sender == account.get_address() assert forget_message.content.reason == reason assert forget_message.content.hashes == [post_hash] - - print(forget_message) + forget_hash = forget_message.item_hash # Wait until the message is forgotten - forgotten_posts = await wait_matching_posts( - post_hash, - lambda response: "forgotten_by" in response["posts"][0], - timeout=15, - ) + async with AuthenticatedAlephClient( + account=account, api_server=receiver_node + ) as rx_session: + await try_until( + rx_session.get_messages, + has_messages, + timeout=5, + message_filter=MessageFilter( + hashes=[forget_hash], + ), + ) - assert len(forgotten_posts["posts"]) == 1 - forgotten_post = forgotten_posts["posts"][0] - assert forgotten_post["forgotten_by"] == [forget_message.item_hash] - assert forgotten_post["item_content"] is None - print(forgotten_post) + await try_until( + rx_session.get_messages, + has_no_messages, + timeout=5, + message_filter=MessageFilter( + hashes=[post_hash], + ), + ) - return post_hash + return post_hash, forget_hash @pytest.mark.asyncio @@ -85,7 +82,7 @@ async def test_create_and_forget_post_on_target(fixture_account): Create a post on the target node, then forget it and check that the change is propagated to the reference node. """ - _ = await create_and_forget_post(fixture_account, TARGET_NODE, REFERENCE_NODE) + _, _ = await create_and_forget_post(fixture_account, TARGET_NODE, REFERENCE_NODE) @pytest.mark.asyncio @@ -94,7 +91,7 @@ async def test_create_and_forget_post_on_reference(fixture_account): Create a post on the reference node, then forget it and check that the change is propagated to the target node. """ - _ = await create_and_forget_post(fixture_account, REFERENCE_NODE, TARGET_NODE) + _, _ = await create_and_forget_post(fixture_account, REFERENCE_NODE, TARGET_NODE) @pytest.mark.asyncio @@ -104,26 +101,33 @@ async def test_forget_a_forget_message(fixture_account): """ # TODO: this test should be moved to the PyAleph API tests, once a framework is in place. - post_hash = await create_and_forget_post(fixture_account, TARGET_NODE, TARGET_NODE) + post_hash, forget_hash = await create_and_forget_post( + fixture_account, TARGET_NODE, REFERENCE_NODE + ) async with AuthenticatedAlephClient( account=fixture_account, api_server=TARGET_NODE - ) as session: - get_post_message: PostMessage = await session.get_message(post_hash) - - forget_message_hash = get_post_message.forgotten_by[0] - forget_message, forget_status = await session.forget( - hashes=[forget_message_hash], + ) as tx_session: + forget_message, forget_status = await tx_session.forget( + hashes=[forget_hash], reason="I want to remember this post. Maybe I can forget I forgot it?", channel=TEST_CHANNEL, ) print(forget_message) - get_forget_message_response = await session.get_messages( + # wait 5 seconds + await asyncio.sleep(5) + + async with AuthenticatedAlephClient( + account=fixture_account, api_server=REFERENCE_NODE + ) as rx_session: + get_forget_message_response = await try_until( + rx_session.get_messages, + has_messages, + timeout=5, message_filter=MessageFilter( - hashes=[forget_message_hash], - channels=[TEST_CHANNEL], - ) + hashes=[forget_hash], + ), ) assert len(get_forget_message_response.messages) == 1 forget_message = get_forget_message_response.messages[0] diff --git a/tests/integration/itest_posts.py b/tests/integration/itest_posts.py index f30dc2b6..59b96b1b 100644 --- a/tests/integration/itest_posts.py +++ b/tests/integration/itest_posts.py @@ -1,20 +1,18 @@ import pytest -from aleph_message.models import MessagesResponse from aleph.sdk.client import AuthenticatedAlephClient -from tests.integration.toolkit import try_until +from aleph.sdk.models.message import MessageFilter +from tests.integration.toolkit import has_messages, try_until from .config import REFERENCE_NODE, TARGET_NODE -async def create_message_on_target( - fixture_account, emitter_node: str, receiver_node: str -): +async def create_message_on_target(account, emitter_node: str, receiver_node: str): """ Create a POST message on the target node, then fetch it from the reference node. """ async with AuthenticatedAlephClient( - account=fixture_account, api_server=emitter_node + account=account, api_server=emitter_node ) as tx_session: post_message, message_status = await tx_session.create_post( post_content=None, @@ -22,17 +20,16 @@ async def create_message_on_target( channel="INTEGRATION_TESTS", ) - def response_contains_messages(response: MessagesResponse) -> bool: - return len(response.messages) > 0 - async with AuthenticatedAlephClient( - account=fixture_account, api_server=receiver_node + account=account, api_server=receiver_node ) as rx_session: responses = await try_until( rx_session.get_messages, - response_contains_messages, + has_messages, timeout=5, - hashes=[post_message.item_hash], + message_filter=MessageFilter( + hashes=[post_message.item_hash], + ), ) message_from_target = responses.messages[0] diff --git a/tests/integration/toolkit.py b/tests/integration/toolkit.py index 70bc3bbb..62a5f841 100644 --- a/tests/integration/toolkit.py +++ b/tests/integration/toolkit.py @@ -2,6 +2,8 @@ import time from typing import Awaitable, Callable, TypeVar +from aleph.sdk.models.message import MessagesResponse + T = TypeVar("T") @@ -9,7 +11,7 @@ async def try_until( coroutine: Callable[..., Awaitable[T]], condition: Callable[[T], bool], timeout: float, - time_between_attempts: float = 0.5, + time_between_attempts: float = 1, *args, **kwargs, ) -> T: @@ -23,3 +25,11 @@ async def try_until( await asyncio.sleep(time_between_attempts) else: raise TimeoutError(f"No success in {timeout} seconds.") + + +def has_messages(response: MessagesResponse) -> bool: + return len(response.messages) > 0 + + +def has_no_messages(response: MessagesResponse) -> bool: + return len(response.messages) == 0 diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index 72c47706..2db88ef3 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -23,6 +23,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def status(self): return 200 + def raise_for_status(self): + ... + async def json(self): return get_return_value From b66993a49925a7d36da60a217115dc6831fcb9de Mon Sep 17 00:00:00 2001 From: mhh Date: Fri, 6 Oct 2023 19:10:16 +0200 Subject: [PATCH 3/6] Refactor: client.py became too large/difficult to understand; .py files and their contained functions did not correspond to another Solution: refactor client.py into own module; rename pagination to page_size where possible; refactor models module --- src/aleph/sdk/client/__init__.py | 12 + .../{client.py => client/authenticated.py} | 536 +----------------- src/aleph/sdk/{ => client}/base.py | 14 +- src/aleph/sdk/client/client.py | 499 ++++++++++++++++ src/aleph/sdk/client/utils.py | 21 + src/aleph/sdk/models/common.py | 21 - src/aleph/sdk/models/message.py | 3 +- src/aleph/sdk/models/post.py | 3 +- src/aleph/sdk/models/utils.py | 20 + tests/unit/test_asynchronous_get.py | 6 +- tests/unit/test_synchronous_get.py | 4 +- 11 files changed, 575 insertions(+), 564 deletions(-) create mode 100644 src/aleph/sdk/client/__init__.py rename src/aleph/sdk/{client.py => client/authenticated.py} (56%) rename src/aleph/sdk/{ => client}/base.py (97%) create mode 100644 src/aleph/sdk/client/client.py create mode 100644 src/aleph/sdk/client/utils.py create mode 100644 src/aleph/sdk/models/utils.py diff --git a/src/aleph/sdk/client/__init__.py b/src/aleph/sdk/client/__init__.py new file mode 100644 index 00000000..8b0db873 --- /dev/null +++ b/src/aleph/sdk/client/__init__.py @@ -0,0 +1,12 @@ +from .authenticated import AuthenticatedAlephClient, AuthenticatedUserSessionSync +from .base import BaseAlephClient, BaseAuthenticatedAlephClient +from .client import AlephClient, UserSessionSync + +__all__ = [ + "BaseAlephClient", + "BaseAuthenticatedAlephClient", + "AlephClient", + "AuthenticatedAlephClient", + "UserSessionSync", + "AuthenticatedUserSessionSync", +] diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client/authenticated.py similarity index 56% rename from src/aleph/sdk/client.py rename to src/aleph/sdk/client/authenticated.py index 1c1b0f5e..093dbe76 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client/authenticated.py @@ -1,37 +1,18 @@ -import asyncio import hashlib import json import logging -import queue -import threading import time -from io import BytesIO from pathlib import Path -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Dict, - Iterable, - List, - Mapping, - NoReturn, - Optional, - Tuple, - Type, - TypeVar, - Union, -) +from typing import Any, Dict, List, Mapping, NoReturn, Optional, Tuple, Union import aiohttp +from aleph_message import parse_message from aleph_message.models import ( AggregateContent, AggregateMessage, AlephMessage, ForgetContent, ForgetMessage, - ItemHash, ItemType, MessageType, PostContent, @@ -40,28 +21,16 @@ ProgramMessage, StoreContent, StoreMessage, - parse_message, ) from aleph_message.models.execution.base import Encoding from aleph_message.status import MessageStatus -from pydantic import ValidationError from pydantic.json import pydantic_encoder -from aleph.sdk.types import Account, GenericMessage, StorageEnum -from aleph.sdk.utils import Writable, copy_async_readable_to_buffer - -from .base import BaseAlephClient, BaseAuthenticatedAlephClient -from .conf import settings -from .exceptions import ( - BroadcastError, - FileTooLarge, - InvalidMessageError, - MessageNotFoundError, - MultipleMessagesError, -) -from .models.message import MessageFilter, MessagesResponse -from .models.post import Post, PostFilter, PostsResponse -from .utils import check_unix_socket_valid, get_message_type_value +from ..conf import settings +from ..exceptions import BroadcastError, InvalidMessageError +from ..types import Account, StorageEnum +from .base import BaseAuthenticatedAlephClient +from .client import AlephClient, UserSessionSync logger = logging.getLogger(__name__) @@ -71,181 +40,6 @@ logger.info("Could not import library 'magic', MIME type detection disabled") magic = None # type:ignore -T = TypeVar("T") - - -def async_wrapper(f): - """ - Copies the docstring of wrapped functions. - """ - - wrapped = getattr(AuthenticatedAlephClient, f.__name__) - f.__doc__ = wrapped.__doc__ - - -def wrap_async(func: Callable[..., Awaitable[T]]) -> Callable[..., T]: - """Wrap an asynchronous function into a synchronous one, - for easy use in synchronous code. - """ - - def func_caller(*args, **kwargs): - loop = asyncio.get_event_loop() - return loop.run_until_complete(func(*args, **kwargs)) - - # Copy wrapped function interface: - func_caller.__doc__ = func.__doc__ - func_caller.__annotations__ = func.__annotations__ - func_caller.__defaults__ = func.__defaults__ - func_caller.__kwdefaults__ = func.__kwdefaults__ - return func_caller - - -async def run_async_watcher( - *args, output_queue: queue.Queue, api_server: Optional[str], **kwargs -): - async with AlephClient(api_server=api_server) as session: - async for message in session.watch_messages(*args, **kwargs): - output_queue.put(message) - - -def watcher_thread(output_queue: queue.Queue, api_server: Optional[str], args, kwargs): - asyncio.run( - run_async_watcher( - output_queue=output_queue, api_server=api_server, *args, **kwargs - ) - ) - - -class UserSessionSync: - """ - A sync version of `UserSession`, used in sync code. - - This class is returned by the context manager of `UserSession` and is - intended as a wrapper around the methods of `UserSession` and not as a public class. - The methods are fully typed to enable static type checking, but most (all) methods - should look like this (using args and kwargs for brevity, but the functions should - be fully typed): - - >>> def func(self, *args, **kwargs): - >>> return self._wrap(self.async_session.func)(*args, **kwargs) - """ - - def __init__(self, async_session: "AlephClient"): - self.async_session = async_session - - def _wrap(self, method: Callable[..., Awaitable[T]], *args, **kwargs): - return wrap_async(method)(*args, **kwargs) - - def get_messages( - self, - pagination: int = 200, - page: int = 1, - message_filter: Optional[MessageFilter] = None, - ignore_invalid_messages: bool = True, - invalid_messages_log_level: int = logging.NOTSET, - ) -> MessagesResponse: - return self._wrap( - self.async_session.get_messages, - pagination=pagination, - page=page, - message_filter=message_filter, - ignore_invalid_messages=ignore_invalid_messages, - invalid_messages_log_level=invalid_messages_log_level, - ) - - # @async_wrapper - def get_message( - self, - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, - channel: Optional[str] = None, - ) -> GenericMessage: - return self._wrap( - self.async_session.get_message, - item_hash=item_hash, - message_type=message_type, - channel=channel, - ) - - def fetch_aggregate( - self, - address: str, - key: str, - limit: int = 100, - ) -> Dict[str, Dict]: - return self._wrap(self.async_session.fetch_aggregate, address, key, limit) - - def fetch_aggregates( - self, - address: str, - keys: Optional[Iterable[str]] = None, - limit: int = 100, - ) -> Dict[str, Dict]: - return self._wrap(self.async_session.fetch_aggregates, address, keys, limit) - - def get_posts( - self, - pagination: int = 200, - page: int = 1, - post_filter: Optional[PostFilter] = None, - ) -> PostsResponse: - return self._wrap( - self.async_session.get_posts, - pagination=pagination, - page=page, - post_filter=post_filter, - ) - - def download_file(self, file_hash: str) -> bytes: - return self._wrap(self.async_session.download_file, file_hash=file_hash) - - def download_file_ipfs(self, file_hash: str) -> bytes: - return self._wrap( - self.async_session.download_file_ipfs, - file_hash=file_hash, - ) - - def download_file_to_buffer( - self, file_hash: str, output_buffer: Writable[bytes] - ) -> None: - return self._wrap( - self.async_session.download_file_to_buffer, - file_hash=file_hash, - output_buffer=output_buffer, - ) - - def download_file_ipfs_to_buffer( - self, file_hash: str, output_buffer: Writable[bytes] - ) -> None: - return self._wrap( - self.async_session.download_file_ipfs_to_buffer, - file_hash=file_hash, - output_buffer=output_buffer, - ) - - def watch_messages( - self, - message_filter: Optional[MessageFilter] = None, - ) -> Iterable[AlephMessage]: - """ - Iterate over current and future matching messages synchronously. - - Runs the `watch_messages` asynchronous generator in a thread. - """ - output_queue: queue.Queue[AlephMessage] = queue.Queue() - thread = threading.Thread( - target=watcher_thread, - args=( - output_queue, - self.async_session.api_server, - message_filter, - {}, - ), - ) - thread.start() - while True: - yield output_queue.get() - class AuthenticatedUserSessionSync(UserSessionSync): async_session: "AuthenticatedAlephClient" @@ -414,322 +208,6 @@ def submit( ) -class AlephClient(BaseAlephClient): - api_server: str - http_session: aiohttp.ClientSession - - def __init__( - self, - api_server: Optional[str] = None, - api_unix_socket: Optional[str] = None, - allow_unix_sockets: bool = True, - timeout: Optional[aiohttp.ClientTimeout] = None, - ): - """AlephClient can use HTTP(S) or HTTP over Unix sockets. - Unix sockets are used when running inside a virtual machine, - and can be shared across containers in a more secure way than TCP ports. - """ - self.api_server = api_server or settings.API_HOST - if not self.api_server: - raise ValueError("Missing API host") - - unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET - if unix_socket_path and allow_unix_sockets: - check_unix_socket_valid(unix_socket_path) - connector = aiohttp.UnixConnector(path=unix_socket_path) - else: - connector = None - - # ClientSession timeout defaults to a private sentinel object and may not be None. - self.http_session = ( - aiohttp.ClientSession( - base_url=self.api_server, connector=connector, timeout=timeout - ) - if timeout - else aiohttp.ClientSession( - base_url=self.api_server, - connector=connector, - ) - ) - - def __enter__(self) -> UserSessionSync: - return UserSessionSync(async_session=self) - - def __exit__(self, exc_type, exc_val, exc_tb): - close_fut = self.http_session.close() - try: - loop = asyncio.get_running_loop() - loop.run_until_complete(close_fut) - except RuntimeError: - asyncio.run(close_fut) - - async def __aenter__(self) -> "AlephClient": - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.http_session.close() - - async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: - params: Dict[str, Any] = {"keys": key} - - async with self.http_session.get( - f"/api/v0/aggregates/{address}.json", params=params - ) as resp: - resp.raise_for_status() - result = await resp.json() - data = result.get("data", dict()) - return data.get(key) - - async def fetch_aggregates( - self, address: str, keys: Optional[Iterable[str]] = None - ) -> Dict[str, Dict]: - keys_str = ",".join(keys) if keys else "" - params: Dict[str, Any] = {} - if keys_str: - params["keys"] = keys_str - - async with self.http_session.get( - f"/api/v0/aggregates/{address}.json", - params=params, - ) as resp: - resp.raise_for_status() - result = await resp.json() - data = result.get("data", dict()) - return data - - async def get_posts( - self, - pagination: int = 200, - page: int = 1, - post_filter: Optional[PostFilter] = None, - ignore_invalid_messages: Optional[bool] = True, - invalid_messages_log_level: Optional[int] = logging.NOTSET, - ) -> PostsResponse: - ignore_invalid_messages = ( - True if ignore_invalid_messages is None else ignore_invalid_messages - ) - invalid_messages_log_level = ( - logging.NOTSET - if invalid_messages_log_level is None - else invalid_messages_log_level - ) - - if not post_filter: - post_filter = PostFilter() - params = post_filter.as_http_params() - params["page"] = str(page) - params["pagination"] = str(pagination) - - async with self.http_session.get("/api/v0/posts.json", params=params) as resp: - resp.raise_for_status() - response_json = await resp.json() - posts_raw = response_json["posts"] - - posts: List[Post] = [] - for post_raw in posts_raw: - try: - posts.append(Post.parse_obj(post_raw)) - except ValidationError as e: - if not ignore_invalid_messages: - raise e - if invalid_messages_log_level: - logger.log(level=invalid_messages_log_level, msg=e) - return PostsResponse( - posts=posts, - pagination_page=response_json["pagination_page"], - pagination_total=response_json["pagination_total"], - pagination_per_page=response_json["pagination_per_page"], - pagination_item=response_json["pagination_item"], - ) - - async def download_file_to_buffer( - self, - file_hash: str, - output_buffer: Writable[bytes], - ) -> None: - """ - Download a file from the storage engine and write it to the specified output buffer. - :param file_hash: The hash of the file to retrieve. - :param output_buffer: Writable binary buffer. The file will be written to this buffer. - """ - - async with self.http_session.get( - f"/api/v0/storage/raw/{file_hash}" - ) as response: - if response.status == 200: - await copy_async_readable_to_buffer( - response.content, output_buffer, chunk_size=16 * 1024 - ) - if response.status == 413: - ipfs_hash = ItemHash(file_hash) - if ipfs_hash.item_type == ItemType.ipfs: - return await self.download_file_ipfs_to_buffer( - file_hash, output_buffer - ) - else: - raise FileTooLarge(f"The file from {file_hash} is too large") - - async def download_file_ipfs_to_buffer( - self, - file_hash: str, - output_buffer: Writable[bytes], - ) -> None: - """ - Download a file from the storage engine and write it to the specified output buffer. - - :param file_hash: The hash of the file to retrieve. - :param output_buffer: The binary output buffer to write the file data to. - """ - async with aiohttp.ClientSession() as session: - async with session.get( - f"https://ipfs.aleph.im/ipfs/{file_hash}" - ) as response: - if response.status == 200: - await copy_async_readable_to_buffer( - response.content, output_buffer, chunk_size=16 * 1024 - ) - else: - response.raise_for_status() - - async def download_file( - self, - file_hash: str, - ) -> bytes: - """ - Get a file from the storage engine as raw bytes. - - Warning: Downloading large files can be slow and memory intensive. - - :param file_hash: The hash of the file to retrieve. - """ - buffer = BytesIO() - await self.download_file_to_buffer(file_hash, output_buffer=buffer) - return buffer.getvalue() - - async def download_file_ipfs( - self, - file_hash: str, - ) -> bytes: - """ - Get a file from the ipfs storage engine as raw bytes. - - Warning: Downloading large files can be slow. - - :param file_hash: The hash of the file to retrieve. - """ - buffer = BytesIO() - await self.download_file_ipfs_to_buffer(file_hash, output_buffer=buffer) - return buffer.getvalue() - - async def get_messages( - self, - pagination: int = 200, - page: int = 1, - message_filter: Optional[MessageFilter] = None, - ignore_invalid_messages: Optional[bool] = True, - invalid_messages_log_level: Optional[int] = logging.NOTSET, - ) -> MessagesResponse: - ignore_invalid_messages = ( - True if ignore_invalid_messages is None else ignore_invalid_messages - ) - invalid_messages_log_level = ( - logging.NOTSET - if invalid_messages_log_level is None - else invalid_messages_log_level - ) - - if not message_filter: - message_filter = MessageFilter() - params = message_filter.as_http_params() - params["page"] = str(page) - params["pagination"] = str(pagination) - async with self.http_session.get( - "/api/v0/messages.json", params=params - ) as resp: - resp.raise_for_status() - response_json = await resp.json() - messages_raw = response_json["messages"] - - # All messages may not be valid according to the latest specification in - # aleph-message. This allows the user to specify how errors should be handled. - messages: List[AlephMessage] = [] - for message_raw in messages_raw: - try: - message = parse_message(message_raw) - messages.append(message) - except KeyError as e: - if not ignore_invalid_messages: - raise e - logger.log( - level=invalid_messages_log_level, - msg=f"KeyError: Field '{e.args[0]}' not found", - ) - except ValidationError as e: - if not ignore_invalid_messages: - raise e - if invalid_messages_log_level: - logger.log(level=invalid_messages_log_level, msg=e) - - return MessagesResponse( - messages=messages, - pagination_page=response_json["pagination_page"], - pagination_total=response_json["pagination_total"], - pagination_per_page=response_json["pagination_per_page"], - pagination_item=response_json["pagination_item"], - ) - - async def get_message( - self, - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, - channel: Optional[str] = None, - ) -> GenericMessage: - messages_response = await self.get_messages( - message_filter=MessageFilter( - hashes=[item_hash], - channels=[channel] if channel else None, - ) - ) - if len(messages_response.messages) < 1: - raise MessageNotFoundError(f"No such hash {item_hash}") - if len(messages_response.messages) != 1: - raise MultipleMessagesError( - f"Multiple messages found for the same item_hash `{item_hash}`" - ) - message: GenericMessage = messages_response.messages[0] - if message_type: - expected_type = get_message_type_value(message_type) - if message.type != expected_type: - raise TypeError( - f"The message type '{message.type}' " - f"does not match the expected type '{expected_type}'" - ) - return message - - async def watch_messages( - self, - message_filter: Optional[MessageFilter] = None, - ) -> AsyncIterable[AlephMessage]: - if not message_filter: - message_filter = MessageFilter() - params = message_filter.as_http_params() - - async with self.http_session.ws_connect( - "/api/ws0/messages", params=params - ) as ws: - logger.debug("Websocket connected") - async for msg in ws: - if msg.type == aiohttp.WSMsgType.TEXT: - if msg.data == "close cmd": - await ws.close() - break - else: - data = json.loads(msg.data) - yield parse_message(data) - elif msg.type == aiohttp.WSMsgType.ERROR: - break - - class AuthenticatedAlephClient(AlephClient, BaseAuthenticatedAlephClient): account: Account diff --git a/src/aleph/sdk/base.py b/src/aleph/sdk/client/base.py similarity index 97% rename from src/aleph/sdk/base.py rename to src/aleph/sdk/client/base.py index 9d80bc64..e5cb9c0a 100644 --- a/src/aleph/sdk/base.py +++ b/src/aleph/sdk/client/base.py @@ -25,9 +25,9 @@ from aleph_message.models.execution.program import Encoding from aleph_message.status import MessageStatus -from .models.message import MessageFilter -from .models.post import PostFilter, PostsResponse -from .types import GenericMessage, StorageEnum +from ..models.message import MessageFilter +from ..models.post import PostFilter, PostsResponse +from ..types import GenericMessage, StorageEnum DEFAULT_PAGE_SIZE = 200 @@ -58,7 +58,7 @@ async def fetch_aggregates( @abstractmethod async def get_posts( self, - pagination: int = DEFAULT_PAGE_SIZE, + page_size: int = DEFAULT_PAGE_SIZE, page: int = 1, post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, @@ -67,7 +67,7 @@ async def get_posts( """ Fetch a list of posts from the network. - :param pagination: Number of items to fetch (Default: 200) + :param page_size: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) :param post_filter: Filter to apply to the posts (Default: None) :param ignore_invalid_messages: Ignore invalid messages (Default: True) @@ -113,7 +113,7 @@ async def download_file( @abstractmethod async def get_messages( self, - pagination: int = DEFAULT_PAGE_SIZE, + page_size: int = DEFAULT_PAGE_SIZE, page: int = 1, message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, @@ -122,7 +122,7 @@ async def get_messages( """ Fetch a list of messages from the network. - :param pagination: Number of items to fetch (Default: 200) + :param page_size: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) :param message_filter: Filter to apply to the messages :param ignore_invalid_messages: Ignore invalid messages (Default: True) diff --git a/src/aleph/sdk/client/client.py b/src/aleph/sdk/client/client.py new file mode 100644 index 00000000..8e96645a --- /dev/null +++ b/src/aleph/sdk/client/client.py @@ -0,0 +1,499 @@ +import asyncio +import json +import logging +import queue +import threading +from io import BytesIO +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Type, +) + +import aiohttp +from aleph_message import parse_message +from aleph_message.models import AlephMessage, ItemHash, ItemType +from pydantic import ValidationError + +from ..conf import settings +from ..exceptions import FileTooLarge, MessageNotFoundError, MultipleMessagesError +from ..models.message import MessageFilter, MessagesResponse +from ..models.post import Post, PostFilter, PostsResponse +from ..types import GenericMessage +from ..utils import ( + Writable, + check_unix_socket_valid, + copy_async_readable_to_buffer, + get_message_type_value, +) +from .base import BaseAlephClient +from .utils import T, wrap_async + +logger = logging.getLogger(__name__) + + +async def run_async_watcher( + *args, output_queue: queue.Queue, api_server: Optional[str], **kwargs +): + async with AlephClient(api_server=api_server) as session: + async for message in session.watch_messages(*args, **kwargs): + output_queue.put(message) + + +def watcher_thread(output_queue: queue.Queue, api_server: Optional[str], args, kwargs): + asyncio.run( + run_async_watcher( + output_queue=output_queue, api_server=api_server, *args, **kwargs + ) + ) + + +class UserSessionSync: + """ + A sync version of `UserSession`, used in sync code. + + This class is returned by the context manager of `UserSession` and is + intended as a wrapper around the methods of `UserSession` and not as a public class. + The methods are fully typed to enable static type checking, but most (all) methods + should look like this (using args and kwargs for brevity, but the functions should + be fully typed): + + >>> def func(self, *args, **kwargs): + >>> return self._wrap(self.async_session.func)(*args, **kwargs) + """ + + def __init__(self, async_session: "AlephClient"): + self.async_session = async_session + + def _wrap(self, method: Callable[..., Awaitable[T]], *args, **kwargs): + return wrap_async(method)(*args, **kwargs) + + def get_messages( + self, + page_size: int = 200, + page: int = 1, + message_filter: Optional[MessageFilter] = None, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ) -> MessagesResponse: + return self._wrap( + self.async_session.get_messages, + page_size=page_size, + page=page, + message_filter=message_filter, + ignore_invalid_messages=ignore_invalid_messages, + invalid_messages_log_level=invalid_messages_log_level, + ) + + # @async_wrapper + def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + return self._wrap( + self.async_session.get_message, + item_hash=item_hash, + message_type=message_type, + channel=channel, + ) + + def fetch_aggregate( + self, + address: str, + key: str, + ) -> Dict[str, Dict]: + return self._wrap(self.async_session.fetch_aggregate, address, key) + + def fetch_aggregates( + self, + address: str, + keys: Optional[Iterable[str]] = None, + ) -> Dict[str, Dict]: + return self._wrap(self.async_session.fetch_aggregates, address, keys) + + def get_posts( + self, + page_size: int = 200, + page: int = 1, + post_filter: Optional[PostFilter] = None, + ) -> PostsResponse: + return self._wrap( + self.async_session.get_posts, + page_size=page_size, + page=page, + post_filter=post_filter, + ) + + def download_file(self, file_hash: str) -> bytes: + return self._wrap(self.async_session.download_file, file_hash=file_hash) + + def download_file_ipfs(self, file_hash: str) -> bytes: + return self._wrap( + self.async_session.download_file_ipfs, + file_hash=file_hash, + ) + + def download_file_to_buffer( + self, file_hash: str, output_buffer: Writable[bytes] + ) -> None: + return self._wrap( + self.async_session.download_file_to_buffer, + file_hash=file_hash, + output_buffer=output_buffer, + ) + + def download_file_ipfs_to_buffer( + self, file_hash: str, output_buffer: Writable[bytes] + ) -> None: + return self._wrap( + self.async_session.download_file_ipfs_to_buffer, + file_hash=file_hash, + output_buffer=output_buffer, + ) + + def watch_messages( + self, + message_filter: Optional[MessageFilter] = None, + ) -> Iterable[AlephMessage]: + """ + Iterate over current and future matching messages synchronously. + + Runs the `watch_messages` asynchronous generator in a thread. + """ + output_queue: queue.Queue[AlephMessage] = queue.Queue() + thread = threading.Thread( + target=watcher_thread, + args=( + output_queue, + self.async_session.api_server, + message_filter, + {}, + ), + ) + thread.start() + while True: + yield output_queue.get() + + +class AlephClient(BaseAlephClient): + api_server: str + http_session: aiohttp.ClientSession + + def __init__( + self, + api_server: Optional[str] = None, + api_unix_socket: Optional[str] = None, + allow_unix_sockets: bool = True, + timeout: Optional[aiohttp.ClientTimeout] = None, + ): + """AlephClient can use HTTP(S) or HTTP over Unix sockets. + Unix sockets are used when running inside a virtual machine, + and can be shared across containers in a more secure way than TCP ports. + """ + self.api_server = api_server or settings.API_HOST + if not self.api_server: + raise ValueError("Missing API host") + + unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET + if unix_socket_path and allow_unix_sockets: + check_unix_socket_valid(unix_socket_path) + connector = aiohttp.UnixConnector(path=unix_socket_path) + else: + connector = None + + # ClientSession timeout defaults to a private sentinel object and may not be None. + self.http_session = ( + aiohttp.ClientSession( + base_url=self.api_server, connector=connector, timeout=timeout + ) + if timeout + else aiohttp.ClientSession( + base_url=self.api_server, + connector=connector, + ) + ) + + def __enter__(self) -> UserSessionSync: + return UserSessionSync(async_session=self) + + def __exit__(self, exc_type, exc_val, exc_tb): + close_fut = self.http_session.close() + try: + loop = asyncio.get_running_loop() + loop.run_until_complete(close_fut) + except RuntimeError: + asyncio.run(close_fut) + + async def __aenter__(self) -> "AlephClient": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.http_session.close() + + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: + params: Dict[str, Any] = {"keys": key} + + async with self.http_session.get( + f"/api/v0/aggregates/{address}.json", params=params + ) as resp: + resp.raise_for_status() + result = await resp.json() + data = result.get("data", dict()) + return data.get(key) + + async def fetch_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None + ) -> Dict[str, Dict]: + keys_str = ",".join(keys) if keys else "" + params: Dict[str, Any] = {} + if keys_str: + params["keys"] = keys_str + + async with self.http_session.get( + f"/api/v0/aggregates/{address}.json", + params=params, + ) as resp: + resp.raise_for_status() + result = await resp.json() + data = result.get("data", dict()) + return data + + async def get_posts( + self, + page_size: int = 200, + page: int = 1, + post_filter: Optional[PostFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + ) -> PostsResponse: + ignore_invalid_messages = ( + True if ignore_invalid_messages is None else ignore_invalid_messages + ) + invalid_messages_log_level = ( + logging.NOTSET + if invalid_messages_log_level is None + else invalid_messages_log_level + ) + + if not post_filter: + post_filter = PostFilter() + params = post_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(page_size) + + async with self.http_session.get("/api/v0/posts.json", params=params) as resp: + resp.raise_for_status() + response_json = await resp.json() + posts_raw = response_json["posts"] + + posts: List[Post] = [] + for post_raw in posts_raw: + try: + posts.append(Post.parse_obj(post_raw)) + except ValidationError as e: + if not ignore_invalid_messages: + raise e + if invalid_messages_log_level: + logger.log(level=invalid_messages_log_level, msg=e) + return PostsResponse( + posts=posts, + pagination_page=response_json["pagination_page"], + pagination_total=response_json["pagination_total"], + pagination_per_page=response_json["pagination_per_page"], + pagination_item=response_json["pagination_item"], + ) + + async def download_file_to_buffer( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Download a file from the storage engine and write it to the specified output buffer. + :param file_hash: The hash of the file to retrieve. + :param output_buffer: Writable binary buffer. The file will be written to this buffer. + """ + + async with self.http_session.get( + f"/api/v0/storage/raw/{file_hash}" + ) as response: + if response.status == 200: + await copy_async_readable_to_buffer( + response.content, output_buffer, chunk_size=16 * 1024 + ) + if response.status == 413: + ipfs_hash = ItemHash(file_hash) + if ipfs_hash.item_type == ItemType.ipfs: + return await self.download_file_ipfs_to_buffer( + file_hash, output_buffer + ) + else: + raise FileTooLarge(f"The file from {file_hash} is too large") + + async def download_file_ipfs_to_buffer( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Download a file from the storage engine and write it to the specified output buffer. + + :param file_hash: The hash of the file to retrieve. + :param output_buffer: The binary output buffer to write the file data to. + """ + async with aiohttp.ClientSession() as session: + async with session.get( + f"https://ipfs.aleph.im/ipfs/{file_hash}" + ) as response: + if response.status == 200: + await copy_async_readable_to_buffer( + response.content, output_buffer, chunk_size=16 * 1024 + ) + else: + response.raise_for_status() + + async def download_file( + self, + file_hash: str, + ) -> bytes: + """ + Get a file from the storage engine as raw bytes. + + Warning: Downloading large files can be slow and memory intensive. + + :param file_hash: The hash of the file to retrieve. + """ + buffer = BytesIO() + await self.download_file_to_buffer(file_hash, output_buffer=buffer) + return buffer.getvalue() + + async def download_file_ipfs( + self, + file_hash: str, + ) -> bytes: + """ + Get a file from the ipfs storage engine as raw bytes. + + Warning: Downloading large files can be slow. + + :param file_hash: The hash of the file to retrieve. + """ + buffer = BytesIO() + await self.download_file_ipfs_to_buffer(file_hash, output_buffer=buffer) + return buffer.getvalue() + + async def get_messages( + self, + page_size: int = 200, + page: int = 1, + message_filter: Optional[MessageFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + ) -> MessagesResponse: + ignore_invalid_messages = ( + True if ignore_invalid_messages is None else ignore_invalid_messages + ) + invalid_messages_log_level = ( + logging.NOTSET + if invalid_messages_log_level is None + else invalid_messages_log_level + ) + + if not message_filter: + message_filter = MessageFilter() + params = message_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(page_size) + async with self.http_session.get( + "/api/v0/messages.json", params=params + ) as resp: + resp.raise_for_status() + response_json = await resp.json() + messages_raw = response_json["messages"] + + # All messages may not be valid according to the latest specification in + # aleph-message. This allows the user to specify how errors should be handled. + messages: List[AlephMessage] = [] + for message_raw in messages_raw: + try: + message = parse_message(message_raw) + messages.append(message) + except KeyError as e: + if not ignore_invalid_messages: + raise e + logger.log( + level=invalid_messages_log_level, + msg=f"KeyError: Field '{e.args[0]}' not found", + ) + except ValidationError as e: + if not ignore_invalid_messages: + raise e + if invalid_messages_log_level: + logger.log(level=invalid_messages_log_level, msg=e) + + return MessagesResponse( + messages=messages, + pagination_page=response_json["pagination_page"], + pagination_total=response_json["pagination_total"], + pagination_per_page=response_json["pagination_per_page"], + pagination_item=response_json["pagination_item"], + ) + + async def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + messages_response = await self.get_messages( + message_filter=MessageFilter( + hashes=[item_hash], + channels=[channel] if channel else None, + ) + ) + if len(messages_response.messages) < 1: + raise MessageNotFoundError(f"No such hash {item_hash}") + if len(messages_response.messages) != 1: + raise MultipleMessagesError( + f"Multiple messages found for the same item_hash `{item_hash}`" + ) + message: GenericMessage = messages_response.messages[0] + if message_type: + expected_type = get_message_type_value(message_type) + if message.type != expected_type: + raise TypeError( + f"The message type '{message.type}' " + f"does not match the expected type '{expected_type}'" + ) + return message + + async def watch_messages( + self, + message_filter: Optional[MessageFilter] = None, + ) -> AsyncIterable[AlephMessage]: + if not message_filter: + message_filter = MessageFilter() + params = message_filter.as_http_params() + + async with self.http_session.ws_connect( + "/api/ws0/messages", params=params + ) as ws: + logger.debug("Websocket connected") + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + if msg.data == "close cmd": + await ws.close() + break + else: + data = json.loads(msg.data) + yield parse_message(data) + elif msg.type == aiohttp.WSMsgType.ERROR: + break diff --git a/src/aleph/sdk/client/utils.py b/src/aleph/sdk/client/utils.py new file mode 100644 index 00000000..8af8cec3 --- /dev/null +++ b/src/aleph/sdk/client/utils.py @@ -0,0 +1,21 @@ +import asyncio +from typing import Awaitable, Callable, TypeVar + +T = TypeVar("T") + + +def wrap_async(func: Callable[..., Awaitable[T]]) -> Callable[..., T]: + """Wrap an asynchronous function into a synchronous one, + for easy use in synchronous code. + """ + + def func_caller(*args, **kwargs): + loop = asyncio.get_event_loop() + return loop.run_until_complete(func(*args, **kwargs)) + + # Copy wrapped function interface: + func_caller.__doc__ = func.__doc__ + func_caller.__annotations__ = func.__annotations__ + func_caller.__defaults__ = func.__defaults__ + func_caller.__kwdefaults__ = func.__kwdefaults__ + return func_caller diff --git a/src/aleph/sdk/models/common.py b/src/aleph/sdk/models/common.py index c7e0dc30..6892d881 100644 --- a/src/aleph/sdk/models/common.py +++ b/src/aleph/sdk/models/common.py @@ -1,6 +1,3 @@ -from datetime import datetime -from typing import Iterable, Optional, Union - from pydantic import BaseModel @@ -9,21 +6,3 @@ class PaginationResponse(BaseModel): pagination_total: int pagination_per_page: int pagination_item: str - - -def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]: - if values: - return ",".join(values) - else: - return None - - -def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]: - if date is None: - return None - elif isinstance(date, float): - return date - elif hasattr(date, "timestamp"): - return date.timestamp() - else: - raise TypeError(f"Invalid type: `{type(date)}`") diff --git a/src/aleph/sdk/models/message.py b/src/aleph/sdk/models/message.py index 4ba6a1b2..4fdb295d 100644 --- a/src/aleph/sdk/models/message.py +++ b/src/aleph/sdk/models/message.py @@ -3,7 +3,8 @@ from aleph_message.models import AlephMessage, MessageType -from .common import PaginationResponse, _date_field_to_float, serialize_list +from .common import PaginationResponse +from .utils import _date_field_to_float, serialize_list class MessagesResponse(PaginationResponse): diff --git a/src/aleph/sdk/models/post.py b/src/aleph/sdk/models/post.py index 09a301c2..16669faf 100644 --- a/src/aleph/sdk/models/post.py +++ b/src/aleph/sdk/models/post.py @@ -4,7 +4,8 @@ from aleph_message.models import Chain, ItemHash, ItemType, MessageConfirmation from pydantic import BaseModel, Field -from .common import PaginationResponse, _date_field_to_float, serialize_list +from .common import PaginationResponse +from .utils import _date_field_to_float, serialize_list class Post(BaseModel): diff --git a/src/aleph/sdk/models/utils.py b/src/aleph/sdk/models/utils.py new file mode 100644 index 00000000..818e4c70 --- /dev/null +++ b/src/aleph/sdk/models/utils.py @@ -0,0 +1,20 @@ +from datetime import datetime +from typing import Iterable, Optional, Union + + +def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]: + if values: + return ",".join(values) + else: + return None + + +def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]: + if date is None: + return None + elif isinstance(date, float): + return date + elif hasattr(date, "timestamp"): + return date.timestamp() + else: + raise TypeError(f"Invalid type: `{type(date)}`") diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index 2db88ef3..7773f9b2 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -5,7 +5,7 @@ import pytest from aleph_message.models import MessagesResponse, MessageType -from aleph.sdk.client import AlephClient +from aleph.sdk import AlephClient from aleph.sdk.conf import settings from aleph.sdk.models.message import MessageFilter from aleph.sdk.models.post import PostFilter, PostsResponse @@ -72,7 +72,7 @@ async def test_fetch_aggregates(): async def test_get_posts(): async with AlephClient(api_server=settings.API_HOST) as session: response: PostsResponse = await session.get_posts( - pagination=2, + page_size=2, post_filter=PostFilter( channels=["TEST"], ), @@ -86,7 +86,7 @@ async def test_get_posts(): async def test_get_messages(): async with AlephClient(api_server=settings.API_HOST) as session: response: MessagesResponse = await session.get_messages( - pagination=2, + page_size=2, message_filter=MessageFilter( message_types=[MessageType.post], ), diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py index 0788a1ab..e3b8c0ed 100644 --- a/tests/unit/test_synchronous_get.py +++ b/tests/unit/test_synchronous_get.py @@ -1,6 +1,6 @@ from aleph_message.models import MessagesResponse, MessageType -from aleph.sdk.client import AlephClient +from aleph.sdk import AlephClient from aleph.sdk.conf import settings from aleph.sdk.models.message import MessageFilter @@ -8,7 +8,7 @@ def test_get_post_messages(): with AlephClient(api_server=settings.API_HOST) as session: response: MessagesResponse = session.get_messages( - pagination=2, + page_size=2, message_filter=MessageFilter( message_types=[MessageType.post], ), From bc12555e5680e6b4d424ac403b18d6bfc44e7151 Mon Sep 17 00:00:00 2001 From: mhh Date: Fri, 6 Oct 2023 19:19:38 +0200 Subject: [PATCH 4/6] Optimization: Unnecessary class instantiation and method calls could slow down performance Solution: Prevent unnecessary .as_http_params() calls --- src/aleph/sdk/client/client.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/aleph/sdk/client/client.py b/src/aleph/sdk/client/client.py index 8e96645a..ce5c43db 100644 --- a/src/aleph/sdk/client/client.py +++ b/src/aleph/sdk/client/client.py @@ -284,10 +284,14 @@ async def get_posts( ) if not post_filter: - post_filter = PostFilter() - params = post_filter.as_http_params() - params["page"] = str(page) - params["pagination"] = str(page_size) + params = { + "page": str(page), + "pagination": str(page_size), + } + else: + params = post_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(page_size) async with self.http_session.get("/api/v0/posts.json", params=params) as resp: resp.raise_for_status() @@ -408,10 +412,15 @@ async def get_messages( ) if not message_filter: - message_filter = MessageFilter() - params = message_filter.as_http_params() - params["page"] = str(page) - params["pagination"] = str(page_size) + params = { + "page": str(page), + "pagination": str(page_size), + } + else: + params = message_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(page_size) + async with self.http_session.get( "/api/v0/messages.json", params=params ) as resp: @@ -479,8 +488,7 @@ async def watch_messages( self, message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: - if not message_filter: - message_filter = MessageFilter() + message_filter = message_filter or MessageFilter() params = message_filter.as_http_params() async with self.http_session.ws_connect( From a81e1516ff61d4a48b12babbd787236bb32d0981 Mon Sep 17 00:00:00 2001 From: Hugo Herter Date: Tue, 10 Oct 2023 17:38:56 +0200 Subject: [PATCH 5/6] Remove/Refactor: Hacky code for sync clients is hard to maintain; models module is too broad of a name; client module classes had unclear naming Solution: Restructure and remove the sync wrapper; Move models' functions into new query module and existing files; Rename client classes --- examples/httpgateway.py | 4 +- examples/metrics.py | 6 +- examples/mqtt.py | 6 +- examples/store.py | 4 +- src/aleph/sdk/__init__.py | 4 +- src/aleph/sdk/client/__init__.py | 12 +- src/aleph/sdk/client/{base.py => abstract.py} | 63 ++++- ...authenticated.py => authenticated_http.py} | 252 ++++-------------- src/aleph/sdk/client/{client.py => http.py} | 182 +------------ src/aleph/sdk/client/utils.py | 21 -- src/aleph/sdk/models/__init__.py | 0 src/aleph/sdk/models/common.py | 8 - src/aleph/sdk/models/post.py | 123 --------- src/aleph/sdk/models/utils.py | 20 -- .../{models/message.py => query/filters.py} | 81 +++++- src/aleph/sdk/query/responses.py | 74 +++++ src/aleph/sdk/utils.py | 21 +- tests/integration/itest_aggregates.py | 6 +- tests/integration/itest_forget.py | 16 +- tests/integration/itest_posts.py | 8 +- tests/integration/toolkit.py | 2 +- tests/unit/test_asynchronous.py | 6 +- tests/unit/test_asynchronous_get.py | 14 +- tests/unit/test_download.py | 6 +- tests/unit/test_synchronous_get.py | 20 -- 25 files changed, 320 insertions(+), 639 deletions(-) rename src/aleph/sdk/client/{base.py => abstract.py} (87%) rename src/aleph/sdk/client/{authenticated.py => authenticated_http.py} (71%) rename src/aleph/sdk/client/{client.py => http.py} (69%) delete mode 100644 src/aleph/sdk/client/utils.py delete mode 100644 src/aleph/sdk/models/__init__.py delete mode 100644 src/aleph/sdk/models/common.py delete mode 100644 src/aleph/sdk/models/post.py delete mode 100644 src/aleph/sdk/models/utils.py rename src/aleph/sdk/{models/message.py => query/filters.py} (59%) create mode 100644 src/aleph/sdk/query/responses.py delete mode 100644 tests/unit/test_synchronous_get.py diff --git a/examples/httpgateway.py b/examples/httpgateway.py index c2cb3bdf..24ed6ba1 100644 --- a/examples/httpgateway.py +++ b/examples/httpgateway.py @@ -7,7 +7,7 @@ from aleph.sdk.chains.common import get_fallback_private_key from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.client import AuthenticatedAlephHttpClient app = web.Application() routes = web.RouteTableDef() @@ -32,7 +32,7 @@ async def source_post(request): return web.json_response( {"status": "error", "message": "unauthorized secret"} ) - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=app["account"], api_server="https://api2.aleph.im" ) as session: message, _status = await session.create_post( diff --git a/examples/metrics.py b/examples/metrics.py index 381db6be..c7beb5d2 100644 --- a/examples/metrics.py +++ b/examples/metrics.py @@ -12,7 +12,7 @@ from aleph_message.status import MessageStatus from aleph.sdk.chains.ethereum import get_fallback_account -from aleph.sdk.client import AuthenticatedAlephClient, AuthenticatedUserSessionSync +from aleph.sdk.client import AuthenticatedAlephClientSync, AuthenticatedAlephHttpClient from aleph.sdk.conf import settings @@ -54,7 +54,7 @@ def get_cpu_cores(): def send_metrics( - session: AuthenticatedUserSessionSync, metrics + session: AuthenticatedAlephClientSync, metrics ) -> Tuple[AlephMessage, MessageStatus]: return session.create_aggregate(key="metrics", content=metrics, channel="SYSINFO") @@ -70,7 +70,7 @@ def collect_metrics(): def main(): account = get_fallback_account() - with AuthenticatedAlephClient( + with AuthenticatedAlephHttpClient( account=account, api_server=settings.API_HOST ) as session: while True: diff --git a/examples/mqtt.py b/examples/mqtt.py index eff32121..e09b2c6f 100644 --- a/examples/mqtt.py +++ b/examples/mqtt.py @@ -10,7 +10,7 @@ from aleph.sdk.chains.common import get_fallback_private_key from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.conf import settings @@ -27,7 +27,7 @@ def get_input_data(value): def send_metrics(account, metrics): - with AuthenticatedAlephClient( + with AuthenticatedAlephHttpClient( account=account, api_server=settings.API_HOST ) as session: return session.create_aggregate( @@ -100,7 +100,7 @@ async def gateway( if not userdata["received"]: await client.reconnect() - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=settings.API_HOST ) as session: for key, value in state.items(): diff --git a/examples/store.py b/examples/store.py index 6ce5662c..b6c7a862 100644 --- a/examples/store.py +++ b/examples/store.py @@ -6,7 +6,7 @@ from aleph.sdk.chains.common import get_fallback_private_key from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.conf import settings DEFAULT_SERVER = "https://api2.aleph.im" @@ -23,7 +23,7 @@ async def print_output_hash(message: StoreMessage, status: MessageStatus): async def do_upload(account, engine, channel, filename=None, file_hash=None): - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=settings.API_HOST ) as session: print(filename, account.get_address()) diff --git a/src/aleph/sdk/__init__.py b/src/aleph/sdk/__init__.py index c66fe9d6..c14b64f6 100644 --- a/src/aleph/sdk/__init__.py +++ b/src/aleph/sdk/__init__.py @@ -1,6 +1,6 @@ from pkg_resources import DistributionNotFound, get_distribution -from aleph.sdk.client import AlephClient, AuthenticatedAlephClient +from aleph.sdk.client import AlephHttpClient, AuthenticatedAlephHttpClient try: # Change here if project is renamed and does not equal the package name @@ -11,4 +11,4 @@ finally: del get_distribution, DistributionNotFound -__all__ = ["AlephClient", "AuthenticatedAlephClient"] +__all__ = ["AlephHttpClient", "AuthenticatedAlephHttpClient"] diff --git a/src/aleph/sdk/client/__init__.py b/src/aleph/sdk/client/__init__.py index 8b0db873..9ee25dd9 100644 --- a/src/aleph/sdk/client/__init__.py +++ b/src/aleph/sdk/client/__init__.py @@ -1,12 +1,10 @@ -from .authenticated import AuthenticatedAlephClient, AuthenticatedUserSessionSync -from .base import BaseAlephClient, BaseAuthenticatedAlephClient -from .client import AlephClient, UserSessionSync +from .abstract import AlephClient, AuthenticatedAlephClient +from .authenticated_http import AuthenticatedAlephHttpClient +from .http import AlephHttpClient __all__ = [ - "BaseAlephClient", - "BaseAuthenticatedAlephClient", "AlephClient", "AuthenticatedAlephClient", - "UserSessionSync", - "AuthenticatedUserSessionSync", + "AlephHttpClient", + "AuthenticatedAlephHttpClient", ] diff --git a/src/aleph/sdk/client/base.py b/src/aleph/sdk/client/abstract.py similarity index 87% rename from src/aleph/sdk/client/base.py rename to src/aleph/sdk/client/abstract.py index e5cb9c0a..26a51221 100644 --- a/src/aleph/sdk/client/base.py +++ b/src/aleph/sdk/client/abstract.py @@ -25,14 +25,15 @@ from aleph_message.models.execution.program import Encoding from aleph_message.status import MessageStatus -from ..models.message import MessageFilter -from ..models.post import PostFilter, PostsResponse +from ..query.filters import MessageFilter, PostFilter +from ..query.responses import PostsResponse from ..types import GenericMessage, StorageEnum +from ..utils import Writable DEFAULT_PAGE_SIZE = 200 -class BaseAlephClient(ABC): +class AlephClient(ABC): @abstractmethod async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: """ @@ -110,6 +111,44 @@ async def download_file( """ pass + async def download_file_ipfs( + self, + file_hash: str, + ) -> bytes: + """ + Get a file from the ipfs storage engine as raw bytes. + + Warning: Downloading large files can be slow. + + :param file_hash: The hash of the file to retrieve. + """ + raise NotImplementedError() + + async def download_file_ipfs_to_buffer( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Download a file from the storage engine and write it to the specified output buffer. + + :param file_hash: The hash of the file to retrieve. + :param output_buffer: The binary output buffer to write the file data to. + """ + raise NotImplementedError() + + async def download_file_to_buffer( + self, + file_hash: str, + output_buffer: Writable[bytes], + ) -> None: + """ + Download a file from the storage engine and write it to the specified output buffer. + :param file_hash: The hash of the file to retrieve. + :param output_buffer: Writable binary buffer. The file will be written to this buffer. + """ + raise NotImplementedError() + @abstractmethod async def get_messages( self, @@ -180,7 +219,7 @@ def watch_messages( pass -class BaseAuthenticatedAlephClient(BaseAlephClient): +class AuthenticatedAlephClient(AlephClient): @abstractmethod async def create_post( self, @@ -350,3 +389,19 @@ async def submit( :param sync: If true, waits for the message to be processed by the API server (Default: False) """ pass + + async def ipfs_push(self, content: Mapping) -> str: + """ + Push a file to IPFS. + + :param content: Content of the file to push + """ + raise NotImplementedError() + + async def storage_push(self, content: Mapping) -> str: + """ + Push arbitrary content as JSON to the storage service. + + :param content: The dict-like content to upload + """ + raise NotImplementedError() diff --git a/src/aleph/sdk/client/authenticated.py b/src/aleph/sdk/client/authenticated_http.py similarity index 71% rename from src/aleph/sdk/client/authenticated.py rename to src/aleph/sdk/client/authenticated_http.py index 093dbe76..6291467a 100644 --- a/src/aleph/sdk/client/authenticated.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -23,14 +23,20 @@ StoreMessage, ) from aleph_message.models.execution.base import Encoding +from aleph_message.models.execution.environment import ( + FunctionEnvironment, + MachineResources, +) +from aleph_message.models.execution.program import CodeContent, FunctionRuntime +from aleph_message.models.execution.volume import MachineVolume from aleph_message.status import MessageStatus from pydantic.json import pydantic_encoder from ..conf import settings from ..exceptions import BroadcastError, InvalidMessageError from ..types import Account, StorageEnum -from .base import BaseAuthenticatedAlephClient -from .client import AlephClient, UserSessionSync +from .abstract import AuthenticatedAlephClient +from .http import AlephHttpClient logger = logging.getLogger(__name__) @@ -41,174 +47,7 @@ magic = None # type:ignore -class AuthenticatedUserSessionSync(UserSessionSync): - async_session: "AuthenticatedAlephClient" - - def __init__(self, async_session: "AuthenticatedAlephClient"): - super().__init__(async_session=async_session) - - def ipfs_push(self, content: Mapping) -> str: - return self._wrap(self.async_session.ipfs_push, content=content) - - def storage_push(self, content: Mapping) -> str: - return self._wrap(self.async_session.storage_push, content=content) - - def ipfs_push_file(self, file_content: Union[str, bytes]) -> str: - return self._wrap(self.async_session.ipfs_push_file, file_content=file_content) - - def storage_push_file(self, file_content: Union[str, bytes]) -> str: - return self._wrap( - self.async_session.storage_push_file, file_content=file_content - ) - - def create_post( - self, - post_content, - post_type: str, - ref: Optional[str] = None, - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, - sync: bool = False, - ) -> Tuple[PostMessage, MessageStatus]: - return self._wrap( - self.async_session.create_post, - post_content=post_content, - post_type=post_type, - ref=ref, - address=address, - channel=channel, - inline=inline, - storage_engine=storage_engine, - sync=sync, - ) - - def create_aggregate( - self, - key: str, - content: Mapping[str, Any], - address: Optional[str] = None, - channel: Optional[str] = None, - inline: bool = True, - sync: bool = False, - ) -> Tuple[AggregateMessage, MessageStatus]: - return self._wrap( - self.async_session.create_aggregate, - key=key, - content=content, - address=address, - channel=channel, - inline=inline, - sync=sync, - ) - - def create_store( - self, - address: Optional[str] = None, - file_content: Optional[bytes] = None, - file_path: Optional[Union[str, Path]] = None, - file_hash: Optional[str] = None, - guess_mime_type: bool = False, - ref: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - extra_fields: Optional[dict] = None, - channel: Optional[str] = None, - sync: bool = False, - ) -> Tuple[StoreMessage, MessageStatus]: - return self._wrap( - self.async_session.create_store, - address=address, - file_content=file_content, - file_path=file_path, - file_hash=file_hash, - guess_mime_type=guess_mime_type, - ref=ref, - storage_engine=storage_engine, - extra_fields=extra_fields, - channel=channel, - sync=sync, - ) - - def create_program( - self, - program_ref: str, - entrypoint: str, - runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, - vcpus: Optional[int] = None, - timeout_seconds: Optional[float] = None, - persistent: bool = False, - encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, - metadata: Optional[Mapping[str, Any]] = None, - ) -> Tuple[ProgramMessage, MessageStatus]: - return self._wrap( - self.async_session.create_program, - program_ref=program_ref, - entrypoint=entrypoint, - runtime=runtime, - environment_variables=environment_variables, - storage_engine=storage_engine, - channel=channel, - address=address, - sync=sync, - memory=memory, - vcpus=vcpus, - timeout_seconds=timeout_seconds, - persistent=persistent, - encoding=encoding, - volumes=volumes, - subscriptions=subscriptions, - metadata=metadata, - ) - - def forget( - self, - hashes: List[str], - reason: Optional[str], - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, - address: Optional[str] = None, - sync: bool = False, - ) -> Tuple[ForgetMessage, MessageStatus]: - return self._wrap( - self.async_session.forget, - hashes=hashes, - reason=reason, - storage_engine=storage_engine, - channel=channel, - address=address, - sync=sync, - ) - - def submit( - self, - content: Dict[str, Any], - message_type: MessageType, - channel: Optional[str] = None, - storage_engine: StorageEnum = StorageEnum.storage, - allow_inlining: bool = True, - sync: bool = False, - ) -> Tuple[AlephMessage, MessageStatus]: - return self._wrap( - self.async_session.submit, - content=content, - message_type=message_type, - channel=channel, - storage_engine=storage_engine, - allow_inlining=allow_inlining, - sync=sync, - ) - - -class AuthenticatedAlephClient(AlephClient, BaseAuthenticatedAlephClient): +class AuthenticatedAlephHttpClient(AlephHttpClient, AuthenticatedAlephClient): account: Account BROADCAST_MESSAGE_FIELDS = { @@ -239,10 +78,7 @@ def __init__( ) self.account = account - def __enter__(self) -> "AuthenticatedUserSessionSync": - return AuthenticatedUserSessionSync(async_session=self) - - async def __aenter__(self) -> "AuthenticatedAlephClient": + async def __aenter__(self) -> "AuthenticatedAlephHttpClient": return self async def ipfs_push(self, content: Mapping) -> str: @@ -580,40 +416,42 @@ async def create_program( # Trigger on HTTP calls. triggers = {"http": True, "persistent": persistent} + volumes: List[MachineVolume] = [ + MachineVolume.parse_obj(volume) for volume in volumes + ] + content = ProgramContent( - **{ - "type": "vm-function", - "address": address, - "allow_amend": False, - "code": { - "encoding": encoding, - "entrypoint": entrypoint, - "ref": program_ref, - "use_latest": True, - }, - "on": triggers, - "environment": { - "reproducible": False, - "internet": True, - "aleph_api": True, - }, - "variables": environment_variables, - "resources": { - "vcpus": vcpus, - "memory": memory, - "seconds": timeout_seconds, - }, - "runtime": { - "ref": runtime, - "use_latest": True, - "comment": "Official aleph.im runtime" - if runtime == settings.DEFAULT_RUNTIME_ID - else "", - }, - "volumes": volumes, - "time": time.time(), - "metadata": metadata, - } + type="vm-function", + address=address, + allow_amend=False, + code=CodeContent( + encoding=encoding, + entrypoint=entrypoint, + ref=program_ref, + use_latest=True, + ), + on=triggers, + environment=FunctionEnvironment( + reproducible=False, + internet=True, + aleph_api=True, + ), + variables=environment_variables, + resources=MachineResources( + vcpus=vcpus, + memory=memory, + seconds=timeout_seconds, + ), + runtime=FunctionRuntime( + ref=runtime, + use_latest=True, + comment="Official aleph.im runtime" + if runtime == settings.DEFAULT_RUNTIME_ID + else "", + ), + volumes=volumes, + time=time.time(), + metadata=metadata, ) # Ensure that the version of aleph-message used supports the field. diff --git a/src/aleph/sdk/client/client.py b/src/aleph/sdk/client/http.py similarity index 69% rename from src/aleph/sdk/client/client.py rename to src/aleph/sdk/client/http.py index ce5c43db..93cbe837 100644 --- a/src/aleph/sdk/client/client.py +++ b/src/aleph/sdk/client/http.py @@ -1,20 +1,7 @@ -import asyncio import json import logging -import queue -import threading from io import BytesIO -from typing import ( - Any, - AsyncIterable, - Awaitable, - Callable, - Dict, - Iterable, - List, - Optional, - Type, -) +from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Type import aiohttp from aleph_message import parse_message @@ -23,8 +10,8 @@ from ..conf import settings from ..exceptions import FileTooLarge, MessageNotFoundError, MultipleMessagesError -from ..models.message import MessageFilter, MessagesResponse -from ..models.post import Post, PostFilter, PostsResponse +from ..query.filters import MessageFilter, PostFilter +from ..query.responses import MessagesResponse, Post, PostsResponse from ..types import GenericMessage from ..utils import ( Writable, @@ -32,158 +19,12 @@ copy_async_readable_to_buffer, get_message_type_value, ) -from .base import BaseAlephClient -from .utils import T, wrap_async +from .abstract import AlephClient logger = logging.getLogger(__name__) -async def run_async_watcher( - *args, output_queue: queue.Queue, api_server: Optional[str], **kwargs -): - async with AlephClient(api_server=api_server) as session: - async for message in session.watch_messages(*args, **kwargs): - output_queue.put(message) - - -def watcher_thread(output_queue: queue.Queue, api_server: Optional[str], args, kwargs): - asyncio.run( - run_async_watcher( - output_queue=output_queue, api_server=api_server, *args, **kwargs - ) - ) - - -class UserSessionSync: - """ - A sync version of `UserSession`, used in sync code. - - This class is returned by the context manager of `UserSession` and is - intended as a wrapper around the methods of `UserSession` and not as a public class. - The methods are fully typed to enable static type checking, but most (all) methods - should look like this (using args and kwargs for brevity, but the functions should - be fully typed): - - >>> def func(self, *args, **kwargs): - >>> return self._wrap(self.async_session.func)(*args, **kwargs) - """ - - def __init__(self, async_session: "AlephClient"): - self.async_session = async_session - - def _wrap(self, method: Callable[..., Awaitable[T]], *args, **kwargs): - return wrap_async(method)(*args, **kwargs) - - def get_messages( - self, - page_size: int = 200, - page: int = 1, - message_filter: Optional[MessageFilter] = None, - ignore_invalid_messages: bool = True, - invalid_messages_log_level: int = logging.NOTSET, - ) -> MessagesResponse: - return self._wrap( - self.async_session.get_messages, - page_size=page_size, - page=page, - message_filter=message_filter, - ignore_invalid_messages=ignore_invalid_messages, - invalid_messages_log_level=invalid_messages_log_level, - ) - - # @async_wrapper - def get_message( - self, - item_hash: str, - message_type: Optional[Type[GenericMessage]] = None, - channel: Optional[str] = None, - ) -> GenericMessage: - return self._wrap( - self.async_session.get_message, - item_hash=item_hash, - message_type=message_type, - channel=channel, - ) - - def fetch_aggregate( - self, - address: str, - key: str, - ) -> Dict[str, Dict]: - return self._wrap(self.async_session.fetch_aggregate, address, key) - - def fetch_aggregates( - self, - address: str, - keys: Optional[Iterable[str]] = None, - ) -> Dict[str, Dict]: - return self._wrap(self.async_session.fetch_aggregates, address, keys) - - def get_posts( - self, - page_size: int = 200, - page: int = 1, - post_filter: Optional[PostFilter] = None, - ) -> PostsResponse: - return self._wrap( - self.async_session.get_posts, - page_size=page_size, - page=page, - post_filter=post_filter, - ) - - def download_file(self, file_hash: str) -> bytes: - return self._wrap(self.async_session.download_file, file_hash=file_hash) - - def download_file_ipfs(self, file_hash: str) -> bytes: - return self._wrap( - self.async_session.download_file_ipfs, - file_hash=file_hash, - ) - - def download_file_to_buffer( - self, file_hash: str, output_buffer: Writable[bytes] - ) -> None: - return self._wrap( - self.async_session.download_file_to_buffer, - file_hash=file_hash, - output_buffer=output_buffer, - ) - - def download_file_ipfs_to_buffer( - self, file_hash: str, output_buffer: Writable[bytes] - ) -> None: - return self._wrap( - self.async_session.download_file_ipfs_to_buffer, - file_hash=file_hash, - output_buffer=output_buffer, - ) - - def watch_messages( - self, - message_filter: Optional[MessageFilter] = None, - ) -> Iterable[AlephMessage]: - """ - Iterate over current and future matching messages synchronously. - - Runs the `watch_messages` asynchronous generator in a thread. - """ - output_queue: queue.Queue[AlephMessage] = queue.Queue() - thread = threading.Thread( - target=watcher_thread, - args=( - output_queue, - self.async_session.api_server, - message_filter, - {}, - ), - ) - thread.start() - while True: - yield output_queue.get() - - -class AlephClient(BaseAlephClient): +class AlephHttpClient(AlephClient): api_server: str http_session: aiohttp.ClientSession @@ -221,18 +62,7 @@ def __init__( ) ) - def __enter__(self) -> UserSessionSync: - return UserSessionSync(async_session=self) - - def __exit__(self, exc_type, exc_val, exc_tb): - close_fut = self.http_session.close() - try: - loop = asyncio.get_running_loop() - loop.run_until_complete(close_fut) - except RuntimeError: - asyncio.run(close_fut) - - async def __aenter__(self) -> "AlephClient": + async def __aenter__(self) -> "AlephHttpClient": return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/src/aleph/sdk/client/utils.py b/src/aleph/sdk/client/utils.py deleted file mode 100644 index 8af8cec3..00000000 --- a/src/aleph/sdk/client/utils.py +++ /dev/null @@ -1,21 +0,0 @@ -import asyncio -from typing import Awaitable, Callable, TypeVar - -T = TypeVar("T") - - -def wrap_async(func: Callable[..., Awaitable[T]]) -> Callable[..., T]: - """Wrap an asynchronous function into a synchronous one, - for easy use in synchronous code. - """ - - def func_caller(*args, **kwargs): - loop = asyncio.get_event_loop() - return loop.run_until_complete(func(*args, **kwargs)) - - # Copy wrapped function interface: - func_caller.__doc__ = func.__doc__ - func_caller.__annotations__ = func.__annotations__ - func_caller.__defaults__ = func.__defaults__ - func_caller.__kwdefaults__ = func.__kwdefaults__ - return func_caller diff --git a/src/aleph/sdk/models/__init__.py b/src/aleph/sdk/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/aleph/sdk/models/common.py b/src/aleph/sdk/models/common.py deleted file mode 100644 index 6892d881..00000000 --- a/src/aleph/sdk/models/common.py +++ /dev/null @@ -1,8 +0,0 @@ -from pydantic import BaseModel - - -class PaginationResponse(BaseModel): - pagination_page: int - pagination_total: int - pagination_per_page: int - pagination_item: str diff --git a/src/aleph/sdk/models/post.py b/src/aleph/sdk/models/post.py deleted file mode 100644 index 16669faf..00000000 --- a/src/aleph/sdk/models/post.py +++ /dev/null @@ -1,123 +0,0 @@ -from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Union - -from aleph_message.models import Chain, ItemHash, ItemType, MessageConfirmation -from pydantic import BaseModel, Field - -from .common import PaginationResponse -from .utils import _date_field_to_float, serialize_list - - -class Post(BaseModel): - """ - A post is a type of message that can be updated. Over the get_posts API - we get the latest version of a post. - """ - - chain: Chain = Field(description="Blockchain this post is associated with") - item_hash: ItemHash = Field(description="Unique hash for this post") - sender: str = Field(description="Address of the sender") - type: str = Field(description="Type of the POST message") - channel: Optional[str] = Field(description="Channel this post is associated with") - confirmed: bool = Field(description="Whether the post is confirmed or not") - content: Dict[str, Any] = Field(description="The content of the POST message") - item_content: Optional[str] = Field( - description="The POSTs content field as serialized JSON, if of type inline" - ) - item_type: ItemType = Field( - description="Type of the item content, usually 'inline' or 'storage' for POSTs" - ) - signature: Optional[str] = Field( - description="Cryptographic signature of the message by the sender" - ) - size: int = Field(description="Size of the post") - time: float = Field(description="Timestamp of the post") - confirmations: List[MessageConfirmation] = Field( - description="Number of confirmations" - ) - original_item_hash: ItemHash = Field(description="Hash of the original content") - original_signature: Optional[str] = Field( - description="Cryptographic signature of the original message" - ) - original_type: str = Field(description="The original type of the message") - hash: ItemHash = Field(description="Hash of the original item") - ref: Optional[Union[str, Any]] = Field( - description="Other message referenced by this one" - ) - - class Config: - allow_extra = False - - -class PostsResponse(PaginationResponse): - """Response from an aleph.im node API on the path /api/v0/posts.json""" - - posts: List[Post] - pagination_item = "posts" - - -class PostFilter: - """ - A collection of filters that can be applied on post queries. - - """ - - types: Optional[Iterable[str]] - refs: Optional[Iterable[str]] - addresses: Optional[Iterable[str]] - tags: Optional[Iterable[str]] - hashes: Optional[Iterable[str]] - channels: Optional[Iterable[str]] - chains: Optional[Iterable[str]] - start_date: Optional[Union[datetime, float]] - end_date: Optional[Union[datetime, float]] - - def __init__( - self, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, - ): - self.types = types - self.refs = refs - self.addresses = addresses - self.tags = tags - self.hashes = hashes - self.channels = channels - self.chains = chains - self.start_date = start_date - self.end_date = end_date - - def as_http_params(self) -> Dict[str, str]: - """Convert the filters into a dict that can be used by an `aiohttp` client - as `params` to build the HTTP query string. - """ - - partial_result = { - "types": serialize_list(self.types), - "refs": serialize_list(self.refs), - "addresses": serialize_list(self.addresses), - "tags": serialize_list(self.tags), - "hashes": serialize_list(self.hashes), - "channels": serialize_list(self.channels), - "chains": serialize_list(self.chains), - "startDate": _date_field_to_float(self.start_date), - "endDate": _date_field_to_float(self.end_date), - } - - # Ensure all values are strings. - result: Dict[str, str] = {} - - # Drop empty values - for key, value in partial_result.items(): - if value: - assert isinstance(value, str), f"Value must be a string: `{value}`" - result[key] = value - - return result diff --git a/src/aleph/sdk/models/utils.py b/src/aleph/sdk/models/utils.py deleted file mode 100644 index 818e4c70..00000000 --- a/src/aleph/sdk/models/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from datetime import datetime -from typing import Iterable, Optional, Union - - -def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]: - if values: - return ",".join(values) - else: - return None - - -def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]: - if date is None: - return None - elif isinstance(date, float): - return date - elif hasattr(date, "timestamp"): - return date.timestamp() - else: - raise TypeError(f"Invalid type: `{type(date)}`") diff --git a/src/aleph/sdk/models/message.py b/src/aleph/sdk/query/filters.py similarity index 59% rename from src/aleph/sdk/models/message.py rename to src/aleph/sdk/query/filters.py index 4fdb295d..346f3a24 100644 --- a/src/aleph/sdk/models/message.py +++ b/src/aleph/sdk/query/filters.py @@ -1,17 +1,9 @@ from datetime import datetime -from typing import Dict, Iterable, List, Optional, Union +from typing import Dict, Iterable, Optional, Union -from aleph_message.models import AlephMessage, MessageType +from aleph_message.models import MessageType -from .common import PaginationResponse -from .utils import _date_field_to_float, serialize_list - - -class MessagesResponse(PaginationResponse): - """Response from an aleph.im node API on the path /api/v0/messages.json""" - - messages: List[AlephMessage] - pagination_item = "messages" +from ..utils import _date_field_to_float, serialize_list class MessageFilter: @@ -101,3 +93,70 @@ def as_http_params(self) -> Dict[str, str]: result[key] = value return result + + +class PostFilter: + """ + A collection of filters that can be applied on post queries. + + """ + + types: Optional[Iterable[str]] + refs: Optional[Iterable[str]] + addresses: Optional[Iterable[str]] + tags: Optional[Iterable[str]] + hashes: Optional[Iterable[str]] + channels: Optional[Iterable[str]] + chains: Optional[Iterable[str]] + start_date: Optional[Union[datetime, float]] + end_date: Optional[Union[datetime, float]] + + def __init__( + self, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + self.types = types + self.refs = refs + self.addresses = addresses + self.tags = tags + self.hashes = hashes + self.channels = channels + self.chains = chains + self.start_date = start_date + self.end_date = end_date + + def as_http_params(self) -> Dict[str, str]: + """Convert the filters into a dict that can be used by an `aiohttp` client + as `params` to build the HTTP query string. + """ + + partial_result = { + "types": serialize_list(self.types), + "refs": serialize_list(self.refs), + "addresses": serialize_list(self.addresses), + "tags": serialize_list(self.tags), + "hashes": serialize_list(self.hashes), + "channels": serialize_list(self.channels), + "chains": serialize_list(self.chains), + "startDate": _date_field_to_float(self.start_date), + "endDate": _date_field_to_float(self.end_date), + } + + # Ensure all values are strings. + result: Dict[str, str] = {} + + # Drop empty values + for key, value in partial_result.items(): + if value: + assert isinstance(value, str), f"Value must be a string: `{value}`" + result[key] = value + + return result diff --git a/src/aleph/sdk/query/responses.py b/src/aleph/sdk/query/responses.py new file mode 100644 index 00000000..5fb91804 --- /dev/null +++ b/src/aleph/sdk/query/responses.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union + +from aleph_message.models import ( + AlephMessage, + Chain, + ItemHash, + ItemType, + MessageConfirmation, +) +from pydantic import BaseModel, Field + + +class Post(BaseModel): + """ + A post is a type of message that can be updated. Over the get_posts API + we get the latest version of a post. + """ + + chain: Chain = Field(description="Blockchain this post is associated with") + item_hash: ItemHash = Field(description="Unique hash for this post") + sender: str = Field(description="Address of the sender") + type: str = Field(description="Type of the POST message") + channel: Optional[str] = Field(description="Channel this post is associated with") + confirmed: bool = Field(description="Whether the post is confirmed or not") + content: Dict[str, Any] = Field(description="The content of the POST message") + item_content: Optional[str] = Field( + description="The POSTs content field as serialized JSON, if of type inline" + ) + item_type: ItemType = Field( + description="Type of the item content, usually 'inline' or 'storage' for POSTs" + ) + signature: Optional[str] = Field( + description="Cryptographic signature of the message by the sender" + ) + size: int = Field(description="Size of the post") + time: float = Field(description="Timestamp of the post") + confirmations: List[MessageConfirmation] = Field( + description="Number of confirmations" + ) + original_item_hash: ItemHash = Field(description="Hash of the original content") + original_signature: Optional[str] = Field( + description="Cryptographic signature of the original message" + ) + original_type: str = Field(description="The original type of the message") + hash: ItemHash = Field(description="Hash of the original item") + ref: Optional[Union[str, Any]] = Field( + description="Other message referenced by this one" + ) + + class Config: + allow_extra = False + + +class PaginationResponse(BaseModel): + pagination_page: int + pagination_total: int + pagination_per_page: int + pagination_item: str + + +class PostsResponse(PaginationResponse): + """Response from an aleph.im node API on the path /api/v0/posts.json""" + + posts: List[Post] + pagination_item = "posts" + + +class MessagesResponse(PaginationResponse): + """Response from an aleph.im node API on the path /api/v0/messages.json""" + + messages: List[AlephMessage] + pagination_item = "messages" diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index be56cc2c..810d7326 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,10 +1,11 @@ import errno import logging import os +from datetime import datetime from enum import Enum from pathlib import Path from shutil import make_archive -from typing import Protocol, Tuple, Type, TypeVar, Union +from typing import Iterable, Optional, Protocol, Tuple, Type, TypeVar, Union from zipfile import BadZipFile, ZipFile from aleph_message.models import MessageType @@ -116,3 +117,21 @@ def enum_as_str(obj: Union[str, Enum]) -> str: return obj.value return obj + + +def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]: + if values: + return ",".join(values) + else: + return None + + +def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]: + if date is None: + return None + elif isinstance(date, float): + return date + elif hasattr(date, "timestamp"): + return date.timestamp() + else: + raise TypeError(f"Invalid type: `{type(date)}`") diff --git a/tests/integration/itest_aggregates.py b/tests/integration/itest_aggregates.py index 5c5d4648..31f5c6cc 100644 --- a/tests/integration/itest_aggregates.py +++ b/tests/integration/itest_aggregates.py @@ -3,7 +3,7 @@ import pytest -from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.types import Account from tests.integration.toolkit import try_until @@ -18,7 +18,7 @@ async def create_aggregate_on_target( receiver_node: str, channel="INTEGRATION_TESTS", ): - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=emitter_node ) as tx_session: aggregate_message, message_status = await tx_session.create_aggregate( @@ -38,7 +38,7 @@ async def create_aggregate_on_target( assert aggregate_message.content.address == account.get_address() assert aggregate_message.content.content == content - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=receiver_node ) as rx_session: aggregate_from_receiver = await try_until( diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py index a457bdda..a6cc141c 100644 --- a/tests/integration/itest_forget.py +++ b/tests/integration/itest_forget.py @@ -4,8 +4,8 @@ import pytest from aleph_message.models import ItemHash -from aleph.sdk.client import AuthenticatedAlephClient -from aleph.sdk.models.message import MessageFilter +from aleph.sdk.client import AuthenticatedAlephHttpClient +from aleph.sdk.query.filters import MessageFilter from aleph.sdk.types import Account from .config import REFERENCE_NODE, TARGET_NODE, TEST_CHANNEL @@ -15,7 +15,7 @@ async def create_and_forget_post( account: Account, emitter_node: str, receiver_node: str, channel=TEST_CHANNEL ) -> Tuple[ItemHash, ItemHash]: - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=emitter_node ) as tx_session: post_message, message_status = await tx_session.create_post( @@ -24,7 +24,7 @@ async def create_and_forget_post( channel="INTEGRATION_TESTS", ) - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=receiver_node ) as rx_session: await try_until( @@ -38,7 +38,7 @@ async def create_and_forget_post( post_hash = post_message.item_hash reason = "This well thought-out content offends me!" - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=emitter_node ) as tx_session: forget_message, forget_status = await tx_session.forget( @@ -52,7 +52,7 @@ async def create_and_forget_post( forget_hash = forget_message.item_hash # Wait until the message is forgotten - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=receiver_node ) as rx_session: await try_until( @@ -104,7 +104,7 @@ async def test_forget_a_forget_message(fixture_account): post_hash, forget_hash = await create_and_forget_post( fixture_account, TARGET_NODE, REFERENCE_NODE ) - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=fixture_account, api_server=TARGET_NODE ) as tx_session: forget_message, forget_status = await tx_session.forget( @@ -118,7 +118,7 @@ async def test_forget_a_forget_message(fixture_account): # wait 5 seconds await asyncio.sleep(5) - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=fixture_account, api_server=REFERENCE_NODE ) as rx_session: get_forget_message_response = await try_until( diff --git a/tests/integration/itest_posts.py b/tests/integration/itest_posts.py index 59b96b1b..77b87b7f 100644 --- a/tests/integration/itest_posts.py +++ b/tests/integration/itest_posts.py @@ -1,7 +1,7 @@ import pytest -from aleph.sdk.client import AuthenticatedAlephClient -from aleph.sdk.models.message import MessageFilter +from aleph.sdk.client import AuthenticatedAlephHttpClient +from aleph.sdk.query.filters import MessageFilter from tests.integration.toolkit import has_messages, try_until from .config import REFERENCE_NODE, TARGET_NODE @@ -11,7 +11,7 @@ async def create_message_on_target(account, emitter_node: str, receiver_node: st """ Create a POST message on the target node, then fetch it from the reference node. """ - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=emitter_node ) as tx_session: post_message, message_status = await tx_session.create_post( @@ -20,7 +20,7 @@ async def create_message_on_target(account, emitter_node: str, receiver_node: st channel="INTEGRATION_TESTS", ) - async with AuthenticatedAlephClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=receiver_node ) as rx_session: responses = await try_until( diff --git a/tests/integration/toolkit.py b/tests/integration/toolkit.py index 62a5f841..a72f9d6f 100644 --- a/tests/integration/toolkit.py +++ b/tests/integration/toolkit.py @@ -2,7 +2,7 @@ import time from typing import Awaitable, Callable, TypeVar -from aleph.sdk.models.message import MessagesResponse +from aleph.sdk.query.responses import MessagesResponse T = TypeVar("T") diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py index 8973263b..dbccbaa6 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -11,14 +11,14 @@ ) from aleph_message.status import MessageStatus -from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.types import Account, StorageEnum @pytest.fixture def mock_session_with_post_success( ethereum_account: Account, -) -> AuthenticatedAlephClient: +) -> AuthenticatedAlephHttpClient: class MockResponse: def __init__(self, sync: bool): self.sync = sync @@ -49,7 +49,7 @@ async def text(self): sync=kwargs.get("sync", False) ) - client = AuthenticatedAlephClient( + client = AuthenticatedAlephHttpClient( account=ethereum_account, api_server="http://localhost" ) client.http_session = http_session diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index 7773f9b2..f5e0c800 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -5,13 +5,13 @@ import pytest from aleph_message.models import MessagesResponse, MessageType -from aleph.sdk import AlephClient +from aleph.sdk import AlephHttpClient from aleph.sdk.conf import settings -from aleph.sdk.models.message import MessageFilter -from aleph.sdk.models.post import PostFilter, PostsResponse +from aleph.sdk.query.filters import MessageFilter, PostFilter +from aleph.sdk.query.responses import PostsResponse -def make_mock_session(get_return_value: Dict[str, Any]) -> AlephClient: +def make_mock_session(get_return_value: Dict[str, Any]) -> AlephHttpClient: class MockResponse: async def __aenter__(self): return self @@ -35,7 +35,7 @@ def get(self, *_args, **_kwargs): http_session = MockHttpSession() - client = AlephClient(api_server="http://localhost") + client = AlephHttpClient(api_server="http://localhost") client.http_session = http_session return client @@ -70,7 +70,7 @@ async def test_fetch_aggregates(): @pytest.mark.asyncio async def test_get_posts(): - async with AlephClient(api_server=settings.API_HOST) as session: + async with AlephHttpClient(api_server=settings.API_HOST) as session: response: PostsResponse = await session.get_posts( page_size=2, post_filter=PostFilter( @@ -84,7 +84,7 @@ async def test_get_posts(): @pytest.mark.asyncio async def test_get_messages(): - async with AlephClient(api_server=settings.API_HOST) as session: + async with AlephHttpClient(api_server=settings.API_HOST) as session: response: MessagesResponse = await session.get_messages( page_size=2, message_filter=MessageFilter( diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py index b16e0d75..377e6d41 100644 --- a/tests/unit/test_download.py +++ b/tests/unit/test_download.py @@ -1,6 +1,6 @@ import pytest -from aleph.sdk import AlephClient +from aleph.sdk import AlephHttpClient from aleph.sdk.conf import settings as sdk_settings @@ -13,7 +13,7 @@ ) @pytest.mark.asyncio async def test_download(file_hash: str, expected_size: int): - async with AlephClient(api_server=sdk_settings.API_HOST) as client: + async with AlephHttpClient(api_server=sdk_settings.API_HOST) as client: file_content = await client.download_file(file_hash) # File is 5B file_size = len(file_content) assert file_size == expected_size @@ -28,7 +28,7 @@ async def test_download(file_hash: str, expected_size: int): ) @pytest.mark.asyncio async def test_download_ipfs(file_hash: str, expected_size: int): - async with AlephClient(api_server=sdk_settings.API_HOST) as client: + async with AlephHttpClient(api_server=sdk_settings.API_HOST) as client: file_content = await client.download_file_ipfs(file_hash) # 5817703 B FILE file_size = len(file_content) assert file_size == expected_size diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py deleted file mode 100644 index e3b8c0ed..00000000 --- a/tests/unit/test_synchronous_get.py +++ /dev/null @@ -1,20 +0,0 @@ -from aleph_message.models import MessagesResponse, MessageType - -from aleph.sdk import AlephClient -from aleph.sdk.conf import settings -from aleph.sdk.models.message import MessageFilter - - -def test_get_post_messages(): - with AlephClient(api_server=settings.API_HOST) as session: - response: MessagesResponse = session.get_messages( - page_size=2, - message_filter=MessageFilter( - message_types=[MessageType.post], - ), - ) - - messages = response.messages - assert len(messages) > 1 - for message in messages: - assert message.type == MessageType.post From df269469a708aca642a0bdd2c06f5c21a81378f4 Mon Sep 17 00:00:00 2001 From: mhh Date: Tue, 10 Oct 2023 18:08:46 +0200 Subject: [PATCH 6/6] Fix: /examples directory was unmaintained and did not work correctly Solution: Debug examples and only use async client --- examples/metrics.py | 23 +++++++++++++---------- examples/mqtt.py | 4 ++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/examples/metrics.py b/examples/metrics.py index c7beb5d2..d8f8a0cc 100644 --- a/examples/metrics.py +++ b/examples/metrics.py @@ -1,7 +1,6 @@ """ Server metrics upload. """ -# -*- coding: utf-8 -*- - +import asyncio import os import platform import time @@ -12,9 +11,11 @@ from aleph_message.status import MessageStatus from aleph.sdk.chains.ethereum import get_fallback_account -from aleph.sdk.client import AuthenticatedAlephClientSync, AuthenticatedAlephHttpClient +from aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.conf import settings +# -*- coding: utf-8 -*- + def get_sysinfo(): uptime = int(time.time() - psutil.boot_time()) @@ -53,10 +54,12 @@ def get_cpu_cores(): return [c._asdict() for c in psutil.cpu_times_percent(0, percpu=True)] -def send_metrics( - session: AuthenticatedAlephClientSync, metrics +async def send_metrics( + session: AuthenticatedAlephHttpClient, metrics ) -> Tuple[AlephMessage, MessageStatus]: - return session.create_aggregate(key="metrics", content=metrics, channel="SYSINFO") + return await session.create_aggregate( + key="metrics", content=metrics, channel="SYSINFO" + ) def collect_metrics(): @@ -68,17 +71,17 @@ def collect_metrics(): } -def main(): +async def main(): account = get_fallback_account() - with AuthenticatedAlephHttpClient( + async with AuthenticatedAlephHttpClient( account=account, api_server=settings.API_HOST ) as session: while True: metrics = collect_metrics() - message, status = send_metrics(session, metrics) + message, status = await send_metrics(session, metrics) print("sent", message.item_hash) time.sleep(10) if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/examples/mqtt.py b/examples/mqtt.py index e09b2c6f..b08538f9 100644 --- a/examples/mqtt.py +++ b/examples/mqtt.py @@ -26,8 +26,8 @@ def get_input_data(value): return value.decode("utf-8") -def send_metrics(account, metrics): - with AuthenticatedAlephHttpClient( +async def send_metrics(account, metrics): + async with AuthenticatedAlephHttpClient( account=account, api_server=settings.API_HOST ) as session: return session.create_aggregate(