Skip to content

Commit 603afef

Browse files
committed
add posts table for caching posts; handle amend messages; refactor node.py as a package
1 parent 30f505e commit 603afef

File tree

4 files changed

+372
-209
lines changed

4 files changed

+372
-209
lines changed

src/aleph/sdk/node.py src/aleph/sdk/node/__init__.py

+76-209
Original file line numberDiff line numberDiff line change
@@ -1,233 +1,57 @@
11
import asyncio
2-
import json
32
import logging
43
import typing
54
from datetime import datetime
6-
from functools import partial
75
from pathlib import Path
86
from typing import (
97
Any,
108
AsyncIterable,
119
Coroutine,
1210
Dict,
13-
Generic,
1411
Iterable,
1512
Iterator,
1613
List,
1714
Mapping,
1815
Optional,
1916
Tuple,
2017
Type,
21-
TypeVar,
2218
Union,
2319
)
2420

25-
from aleph_message import MessagesResponse, parse_message
26-
from aleph_message.models import (
27-
AlephMessage,
28-
Chain,
29-
ItemHash,
30-
MessageConfirmation,
31-
MessageType,
32-
)
21+
from aleph_message import MessagesResponse
22+
from aleph_message.models import AlephMessage, Chain, ItemHash, MessageType, PostMessage
3323
from aleph_message.models.execution.base import Encoding
3424
from aleph_message.status import MessageStatus
35-
from peewee import (
36-
BooleanField,
37-
CharField,
38-
FloatField,
39-
IntegerField,
40-
Model,
41-
SqliteDatabase,
42-
)
43-
from playhouse.shortcuts import model_to_dict
44-
from playhouse.sqlite_ext import JSONField
45-
from pydantic import BaseModel
46-
47-
from aleph.sdk import AuthenticatedAlephClient
48-
from aleph.sdk.base import AlephClientBase, AuthenticatedAlephClientBase
49-
from aleph.sdk.conf import settings
50-
from aleph.sdk.exceptions import MessageNotFoundError
51-
from aleph.sdk.models import PostsResponse
52-
from aleph.sdk.types import GenericMessage, StorageEnum
53-
54-
db = SqliteDatabase(settings.CACHE_DATABASE_PATH)
55-
T = TypeVar("T", bound=BaseModel)
56-
57-
58-
class JSONDictEncoder(json.JSONEncoder):
    """
    JSON encoder that serialises pydantic models through their ``dict()`` form.
    """

    def default(self, obj):
        """
        Return ``obj.dict()`` for pydantic models and defer to the stock encoder for anything else.
        """
        if not isinstance(obj, BaseModel):
            # Unknown types keep the base behaviour (the stock encoder raises TypeError).
            return json.JSONEncoder.default(self, obj)
        return obj.dict()
63-
64-
65-
pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder)
66-
67-
68-
class PydanticField(JSONField, Generic[T]):
69-
"""
70-
A field for storing pydantic model types as JSON in a database. Uses json for serialization.
71-
"""
72-
73-
type: T
7425

75-
def __init__(self, *args, **kwargs):
76-
self.type = kwargs.pop("type")
77-
super().__init__(*args, **kwargs)
26+
from ..base import BaseAlephClient, BaseAuthenticatedAlephClient
27+
from ..client import AuthenticatedAlephClient
28+
from ..conf import settings
29+
from ..exceptions import MessageNotFoundError
30+
from ..models import PostsResponse
31+
from ..types import GenericMessage, StorageEnum
32+
from .common import db
33+
from .message import MessageModel, get_message_query, message_to_model, model_to_message
34+
from .post import PostModel, get_post_query, message_to_post, model_to_post
7835

79-
def db_value(self, value: Optional[T]) -> Optional[str]:
80-
if value is None:
81-
return None
82-
return value.json()
83-
84-
def python_value(self, value: Optional[str]) -> Optional[T]:
85-
if value is None:
86-
return None
87-
return self.type.parse_raw(value)
88-
89-
90-
class MessageModel(Model):
    """
    Peewee table that stores one AlephMessage per row, keyed by its item hash.
    """

    # --- message envelope ---
    item_hash = CharField(primary_key=True)
    chain = CharField(max_length=5)
    type = CharField(max_length=9)
    sender = CharField()
    channel = CharField(null=True)
    # A single confirmation, stored as JSON via PydanticField.
    confirmations: PydanticField[MessageConfirmation] = PydanticField(
        type=MessageConfirmation,
        null=True,
    )
    confirmed = BooleanField(null=True)
    signature = CharField(null=True)
    size = IntegerField(null=True)
    time = FloatField()
    item_type = CharField(max_length=7)
    item_content = CharField(null=True)
    hash_type = CharField(max_length=6, null=True)
    # Full message content as JSON; pydantic values are dumped through pydantic_json_dumps.
    content = JSONField(json_dumps=pydantic_json_dumps)
    forgotten_by = CharField(null=True)

    # --- values copied out of `content` so that they can be filtered on directly ---
    tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
    key = CharField(null=True)
    ref = CharField(null=True)
    content_type = CharField(null=True)

    class Meta:
        database = db
119-
120-
121-
def message_to_model(message: AlephMessage) -> Dict:
    """
    Flatten an AlephMessage into the column -> value mapping used to insert a MessageModel row.

    Only the first confirmation and the first `forgotten_by` entry are kept. `tags`, `key`, `ref`
    and `content_type` are copied out of the message content when the content has them, and are
    None otherwise.
    """
    content = message.content

    if message.confirmations:
        first_confirmation = message.confirmations[0]
    else:
        first_confirmation = None

    if message.forgotten_by:
        first_forgetter = message.forgotten_by[0]
    else:
        first_forgetter = None

    # Tags live inside the nested `content.content` mapping, when there is one.
    if hasattr(content, "content"):
        tags = content.content.get("tags", None)
    else:
        tags = None

    return {
        "item_hash": str(message.item_hash),
        "chain": message.chain,
        "type": message.type,
        "sender": message.sender,
        "channel": message.channel,
        "confirmations": first_confirmation,
        "confirmed": message.confirmed,
        "signature": message.signature,
        "size": message.size,
        "time": message.time,
        "item_type": message.item_type,
        "item_content": message.item_content,
        "hash_type": message.hash_type,
        "content": content,
        "forgotten_by": first_forgetter,
        "tags": tags,
        "key": getattr(content, "key", None),
        "ref": getattr(content, "ref", None),
        "content_type": getattr(content, "type", None),
    }
147-
148-
149-
def model_to_message(item: Any) -> AlephMessage:
    """
    Rebuild an AlephMessage from a MessageModel row.

    Note that `item` is modified in place: its single-valued `confirmations` and `forgotten_by`
    columns are wrapped back into lists before the row is converted to a dict and parsed.
    """
    if item.confirmations:
        item.confirmations = [item.confirmations]
    else:
        item.confirmations = []

    if item.forgotten_by:
        item.forgotten_by = [item.forgotten_by]
    else:
        item.forgotten_by = None

    # The lookup-only columns are not part of the message itself, so leave them out of the dict.
    as_dict = model_to_dict(
        item,
        exclude=[
            MessageModel.tags,
            MessageModel.ref,
            MessageModel.key,
            MessageModel.content_type,
        ],
    )
    return parse_message(as_dict)
162-
163-
164-
def query_field(field_name, field_values: Iterable[str]):
    """
    Build a peewee condition matching the MessageModel column `field_name` against `field_values`.

    Exactly one value yields an equality test; any other number of values yields an IN clause.
    """
    column = getattr(MessageModel, field_name)
    candidates = list(field_values)

    if len(candidates) != 1:
        return column.in_(candidates)
    (only_value,) = candidates
    return column == only_value
171-
172-
173-
def get_message_query(
    message_type: Optional[MessageType] = None,
    content_keys: Optional[Iterable[str]] = None,
    content_types: Optional[Iterable[str]] = None,
    refs: Optional[Iterable[str]] = None,
    addresses: Optional[Iterable[str]] = None,
    tags: Optional[Iterable[str]] = None,
    hashes: Optional[Iterable[str]] = None,
    channels: Optional[Iterable[str]] = None,
    chains: Optional[Iterable[str]] = None,
    start_date: Optional[Union[datetime, float]] = None,
    end_date: Optional[Union[datetime, float]] = None,
):
    """
    Build a select query over MessageModel, newest message first, restricted by the given filters.

    Every filter is optional and all supplied filters are combined with AND. For `tags`, a row
    has to contain every listed tag.

    :param message_type: Only messages of this type.
    :param content_keys: Allowed values of the `key` column.
    :param content_types: Allowed values of the `content_type` column.
    :param refs: Allowed values of the `ref` column.
    :param addresses: Allowed senders.
    :param tags: Tags that must all be present.
    :param hashes: Allowed item hashes.
    :param channels: Allowed channels.
    :param chains: Allowed chains.
    :param start_date: Inclusive lower bound on the message time (datetime or POSIX timestamp).
    :param end_date: Inclusive upper bound on the message time (datetime or POSIX timestamp).
    :return: The peewee query; it has not been executed yet.
    """
    conditions = []
    if message_type:
        conditions.append(query_field("type", [message_type.value]))

    for column, values in (
        ("key", content_keys),
        ("content_type", content_types),
        ("ref", refs),
        ("sender", addresses),
    ):
        if values:
            conditions.append(query_field(column, values))

    if tags:
        conditions.extend(MessageModel.tags.contains(tag) for tag in tags)

    for column, values in (
        ("item_hash", hashes),
        ("channel", channels),
        ("chain", chains),
    ):
        if values:
            conditions.append(query_field(column, values))

    # `time` is a float column: datetimes must be converted before comparing, and a bound of
    # 0.0 is a legitimate timestamp, so test against None rather than truthiness.
    if start_date is not None:
        conditions.append(MessageModel.time >= _as_timestamp(start_date))
    if end_date is not None:
        conditions.append(MessageModel.time <= _as_timestamp(end_date))

    query = MessageModel.select().order_by(MessageModel.time.desc())
    if conditions:
        query = query.where(*conditions)
    return query


def _as_timestamp(moment: Union[datetime, float]) -> float:
    """Return `moment` as a POSIX timestamp; plain numbers are passed through unchanged."""
    if isinstance(moment, datetime):
        return moment.timestamp()
    return moment
215-
216-
217-
class MessageCache(AlephClientBase):
37+
class MessageCache(BaseAlephClient):
21838
"""
21939
A wrapper around a sqlite3 database for caching AlephMessage objects.
22040
22141
It can be used independently of a DomainNode to implement any kind of caching strategy.
22242
"""
22343

22444
_instance_count = 0 # Class-level counter for active instances
45+
missing_posts: Dict[ItemHash, PostMessage] = {}
46+
"""A dict of all posts by item_hash and their amend messages that are missing from the cache."""
22547

22648
def __init__(self):
22749
if db.is_closed():
22850
db.connect()
22951
if not MessageModel.table_exists():
23052
db.create_tables([MessageModel])
53+
if not PostModel.table_exists():
54+
db.create_tables([PostModel])
23155

23256
MessageCache._instance_count += 1
23357

@@ -270,17 +94,57 @@ def __repr__(self) -> str:
27094
def __str__(self) -> str:
    """Return the same text as ``repr()`` for this object."""
    return f"{self!r}"
27296

273-
@staticmethod
274-
def add(messages: Union[AlephMessage, Iterable[AlephMessage]]):
97+
def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]):
    """
    Store one or more messages in the cache and keep the posts table in sync.

    Every message is upserted into the message table. POST messages are also written to the
    posts table, and posts of type "amend" are applied onto the post they reference. An amend
    whose target post is not cached yet is parked in `self.missing_posts` (newest amend per post)
    and applied as soon as that post is added.

    :param messages: A single message or any iterable of messages, including one-shot iterators.
    """
    if isinstance(messages, typing.get_args(AlephMessage)):
        messages = [messages]
    # The batch is walked twice below; without materialising it, a generator would be
    # exhausted by the first insert and no post would ever be processed.
    messages = list(messages)

    message_data = (message_to_model(message) for message in messages)
    MessageModel.insert_many(message_data).on_conflict_replace().execute()

    # Add posts and their amends to the PostModel
    post_data = []
    amend_messages = []
    for message in messages:
        # `type` is the message type. `item_type` is the storage kind of the content and
        # never equals MessageType.post, so testing it would skip every message.
        if message.type != MessageType.post:
            continue
        if message.content.type == "amend":
            amend_messages.append(message)
            continue
        post_data.append(message_to_post(message).dict())
        # Check if an amend was parked while waiting for this post. The dict holds a single
        # message per post, so it is appended rather than concatenated.
        pending_amend = self.missing_posts.pop(message.item_hash, None)
        if pending_amend is not None:
            amend_messages.append(pending_amend)

    if post_data:
        PostModel.insert_many(post_data).on_conflict_replace().execute()

    # Handle amends in a second step so that posts from the same batch are already stored.
    # Oldest first: when several amends target one post, the newest is written last and wins.
    amend_messages.sort(key=lambda amend: amend.time)
    amended_rows = []
    for message in amend_messages:
        # get_or_none() yields None for an unknown post, where get() would raise DoesNotExist.
        # NOTE(review): matching on the current *or* the original hash assumes PostModel has
        # both columns, as the assignments below suggest; confirm against PostModel.
        original_post = PostModel.get_or_none(
            (PostModel.item_hash == message.content.ref)
            | (PostModel.original_item_hash == message.content.ref)
        )
        if original_post is None:
            # Park the amend until its post arrives, keeping only the most recent one.
            ref = ItemHash(message.content.ref)
            latest_amend = self.missing_posts.get(ref)
            if latest_amend is None or message.time > latest_amend.time:
                self.missing_posts[ref] = message
            continue
        if datetime.fromtimestamp(message.time) < original_post.last_updated:
            # The cached post already reflects a newer amend.
            continue
        original_post.item_hash = message.item_hash
        original_post.content = message.content.content
        original_post.original_item_hash = message.content.ref
        original_post.original_type = message.content.type
        original_post.address = message.sender
        original_post.channel = message.channel
        original_post.last_updated = datetime.fromtimestamp(message.time)
        # insert_many() takes row dicts or tuples, not model instances.
        amended_rows.append(dict(original_post.__data__))

    if amended_rows:
        PostModel.insert_many(amended_rows).on_conflict_replace().execute()
280145

281-
@staticmethod
282146
def get(
283-
item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]]
147+
self, item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]]
284148
) -> List[AlephMessage]:
285149
"""
286150
Get many messages from the cache by their item hash.
@@ -347,12 +211,11 @@ async def get_posts(
347211
chains: Optional[Iterable[str]] = None,
348212
start_date: Optional[Union[datetime, float]] = None,
349213
end_date: Optional[Union[datetime, float]] = None,
350-
ignore_invalid_messages: bool = True,
351-
invalid_messages_log_level: int = logging.NOTSET,
214+
ignore_invalid_messages: Optional[bool] = True,
215+
invalid_messages_log_level: Optional[int] = logging.NOTSET,
352216
) -> PostsResponse:
353-
query = get_message_query(
354-
message_type=MessageType.post,
355-
content_types=types,
217+
query = get_post_query(
218+
types=types,
356219
refs=refs,
357220
addresses=addresses,
358221
tags=tags,
@@ -365,7 +228,7 @@ async def get_posts(
365228

366229
query = query.paginate(page, pagination)
367230

368-
posts = [model_to_message(item) for item in list(query)]
231+
posts = [model_to_post(item) for item in list(query)]
369232

370233
return PostsResponse(
371234
posts=posts,
@@ -383,6 +246,7 @@ async def get_messages(
383246
pagination: int = 200,
384247
page: int = 1,
385248
message_type: Optional[MessageType] = None,
249+
message_types: Optional[Iterable[MessageType]] = None,
386250
content_types: Optional[Iterable[str]] = None,
387251
content_keys: Optional[Iterable[str]] = None,
388252
refs: Optional[Iterable[str]] = None,
@@ -393,14 +257,15 @@ async def get_messages(
393257
chains: Optional[Iterable[str]] = None,
394258
start_date: Optional[Union[datetime, float]] = None,
395259
end_date: Optional[Union[datetime, float]] = None,
396-
ignore_invalid_messages: bool = True,
397-
invalid_messages_log_level: int = logging.NOTSET,
260+
ignore_invalid_messages: Optional[bool] = True,
261+
invalid_messages_log_level: Optional[int] = logging.NOTSET,
398262
) -> MessagesResponse:
399263
"""
400264
Get many messages from the cache.
401265
"""
266+
message_types = message_types or [message_type] if message_type else None
402267
query = get_message_query(
403-
message_type=message_type,
268+
message_types=message_types,
404269
content_keys=content_keys,
405270
content_types=content_types,
406271
refs=refs,
@@ -451,6 +316,7 @@ async def get_message(
451316
async def watch_messages(
452317
self,
453318
message_type: Optional[MessageType] = None,
319+
message_types: Optional[Iterable[MessageType]] = None,
454320
content_types: Optional[Iterable[str]] = None,
455321
content_keys: Optional[Iterable[str]] = None,
456322
refs: Optional[Iterable[str]] = None,
@@ -465,8 +331,9 @@ async def watch_messages(
465331
"""
466332
Watch messages from the cache.
467333
"""
334+
message_types = message_types or [message_type] if message_type else None
468335
query = get_message_query(
469-
message_type=message_type,
336+
message_types=message_types,
470337
content_keys=content_keys,
471338
content_types=content_types,
472339
refs=refs,
@@ -483,7 +350,7 @@ async def watch_messages(
483350
yield model_to_message(item)
484351

485352

486-
class DomainNode(MessageCache, AuthenticatedAlephClientBase):
353+
class DomainNode(MessageCache, BaseAuthenticatedAlephClient):
487354
"""
488355
A Domain Node is a queryable proxy for Aleph Messages that are stored in a database cache and/or in the Aleph
489356
network.

0 commit comments

Comments
 (0)