Skip to content

Commit f4e71f1

Browse files
committed
fix post problems
1 parent 2c2eadf commit f4e71f1

File tree

4 files changed

+21
-13
lines changed

4 files changed

+21
-13
lines changed

src/aleph/sdk/models.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,15 @@ class Post(BaseModel):
3737
)
3838
address: str = Field(description="The address of the sender of the POST message")
3939
ref: Optional[str] = Field(description="Other message referenced by this one")
40-
channel: str = Field(description="The channel where the POST message was published")
40+
channel: Optional[str] = Field(description="The channel where the POST message was published")
4141
created: datetime = Field(description="The time when the POST message was created")
4242
last_updated: datetime = Field(
4343
description="The time when the POST message was last updated"
4444
)
4545

46+
class Config:
47+
allow_extra = False
48+
4649

4750
class PostsResponse(PaginationResponse):
4851
"""Response from an Aleph node API on the path /api/v0/posts.json"""

src/aleph/sdk/node/__init__.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -98,23 +98,27 @@ def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]):
9898
if isinstance(messages, typing.get_args(AlephMessage)):
9999
messages = [messages]
100100

101+
messages = list(messages)
102+
101103
message_data = (message_to_model(message) for message in messages)
102104
MessageModel.insert_many(message_data).on_conflict_replace().execute()
103105

104106
# Add posts and their amends to the PostModel
105107
post_data = []
106108
amend_messages = []
107109
for message in messages:
108-
if message.item_type != MessageType.post:
110+
if message.type != MessageType.post.value:
109111
continue
110112
if message.content.type == "amend":
111113
amend_messages.append(message)
112-
else:
113-
post = message_to_post(message).dict()
114-
post_data.append(post)
115-
# Check if we can now add any amend messages that had missing refs
116-
if message.item_hash in self.missing_posts:
117-
amend_messages += self.missing_posts.pop(message.item_hash)
114+
continue
115+
post = message_to_post(message).dict()
116+
post["chain"] = message.chain.value
117+
post["tags"] = message.content.content.get("tags", None)
118+
post_data.append(post)
119+
# Check if we can now add any amend messages that had missing refs
120+
if message.item_hash in self.missing_posts:
121+
amend_messages += self.missing_posts.pop(message.item_hash)
118122

119123
PostModel.insert_many(post_data).on_conflict_replace().execute()
120124

src/aleph/sdk/node/post.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def message_to_post(message: PostMessage) -> Post:
5858
"ref": message.content.ref if hasattr(message.content, "ref") else None,
5959
"channel": message.channel,
6060
"created": datetime.fromtimestamp(message.time),
61-
"last_updated": datetime.fromtimestamp(message.time),
61+
"last_updated": datetime.fromtimestamp(message.time)
6262
}
6363
)
6464

tests/unit/test_node_get.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from aleph.sdk.chains.ethereum import get_fallback_account
1515
from aleph.sdk.exceptions import MessageNotFoundError
16-
from aleph.sdk.node import MessageCache
16+
from aleph.sdk.node import MessageCache, message_to_post
1717

1818

1919
@pytest.mark.asyncio
@@ -137,7 +137,7 @@ def class_teardown(self):
137137
@pytest.mark.asyncio
138138
async def test_addresses(self):
139139
items = (await self.cache.get_posts(addresses=[self.messages[1].sender])).posts
140-
assert items[0] == self.messages[1]
140+
assert items[0] == message_to_post(self.messages[1])
141141

142142
@pytest.mark.asyncio
143143
async def test_tags(self):
@@ -153,15 +153,16 @@ async def test_types(self):
153153

154154
@pytest.mark.asyncio
155155
async def test_channels(self):
156+
print(self.messages[1])
156157
assert (await self.cache.get_posts(channels=[self.messages[1].channel])).posts[
157158
0
158-
] == self.messages[1]
159+
] == message_to_post(self.messages[1])
159160

160161
@pytest.mark.asyncio
161162
async def test_chains(self):
162163
assert (await self.cache.get_posts(chains=[self.messages[1].chain])).posts[
163164
0
164-
] == self.messages[1]
165+
] == message_to_post(self.messages[1])
165166

166167

167168
@pytest.mark.asyncio

0 commit comments

Comments
 (0)