Skip to content

Commit fcd8efe

Browse files
authored
fix/test: ensure polls are being correctly processed (#1714)
* fix: use answer_id from data, not options * fix: correctly deserialize question for polls * ci: add tests for polls * test: make poll dict test more resilient
1 parent 42df28b commit fcd8efe

File tree

3 files changed

+95
-4
lines changed

3 files changed

+95
-4
lines changed

interactions/api/events/processors/message_events.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def _on_raw_message_poll_vote_add(self, event: "RawGatewayEvent") -> None:
9999
event.data["channel_id"],
100100
event.data["message_id"],
101101
event.data["user_id"],
102-
event.data["option"],
102+
event.data["answer_id"],
103103
)
104104
)
105105

@@ -118,6 +118,6 @@ async def _on_raw_message_poll_vote_remove(self, event: "RawGatewayEvent") -> No
118118
event.data["channel_id"],
119119
event.data["message_id"],
120120
event.data["user_id"],
121-
event.data["option"],
121+
event.data["answer_id"],
122122
)
123123
)

interactions/models/discord/poll.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class PollResults(DictSerializationMixin):
8686

8787
@attrs.define(eq=False, order=False, hash=False, kw_only=True)
8888
class Poll(DictSerializationMixin):
89-
question: PollMedia = attrs.field(repr=False)
89+
question: PollMedia = attrs.field(repr=False, converter=PollMedia.from_dict)
9090
"""The question of the poll. Only text media is supported."""
9191
answers: list[PollAnswer] = attrs.field(repr=False, factory=list, converter=PollAnswer.from_list)
9292
"""Each of the answers available in the poll, up to 10."""

tests/test_bot.py

+92-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from asyncio import AbstractEventLoop
55
from contextlib import suppress
6-
from datetime import datetime
6+
from datetime import datetime, timedelta
77

88
import pytest
99
import pytest_asyncio
@@ -33,6 +33,8 @@
3333
ParagraphText,
3434
Message,
3535
GuildVoice,
36+
Poll,
37+
PollMedia,
3638
)
3739
from interactions.models.discord.asset import Asset
3840
from interactions.models.discord.components import ActionRow, Button, StringSelectMenu
@@ -432,6 +434,95 @@ async def test_components(bot: Client, channel: GuildText) -> None:
432434
await thread.delete()
433435

434436

437+
@pytest.mark.asyncio
438+
async def test_polls(bot: Client, channel: GuildText) -> None:
439+
msg = await channel.send("Polls Tests")
440+
thread = await msg.create_thread("Test Thread")
441+
442+
try:
443+
poll_1 = Poll.create("Test Poll", duration=1, answers=["Answer 1", "Answer 2"])
444+
test_data_1 = {
445+
"question": {"text": "Test Poll"},
446+
"layout_type": 1,
447+
"duration": 1,
448+
"allow_multiselect": False,
449+
"answers": [{"poll_media": {"text": "Answer 1"}}, {"poll_media": {"text": "Answer 2"}}],
450+
}
451+
poll_1_dict = poll_1.to_dict()
452+
for key in poll_1_dict.keys():
453+
assert poll_1_dict[key] == test_data_1[key]
454+
455+
msg_1 = await thread.send(poll=poll_1)
456+
457+
assert msg_1.poll is not None
458+
assert msg_1.poll.question.to_dict() == PollMedia(text="Test Poll").to_dict()
459+
assert msg_1.poll.expiry <= msg_1.created_at + timedelta(hours=1, minutes=1)
460+
poll_1_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_1.poll.answers]
461+
assert poll_1_answer_medias == [
462+
PollMedia.create(text="Answer 1").to_dict(),
463+
PollMedia.create(text="Answer 2").to_dict(),
464+
]
465+
466+
poll_2 = Poll.create("Test Poll 2", duration=1, allow_multiselect=True)
467+
poll_2.add_answer("Answer 1")
468+
poll_2.add_answer("Answer 2")
469+
test_data_2 = {
470+
"question": {"text": "Test Poll 2"},
471+
"layout_type": 1,
472+
"duration": 1,
473+
"allow_multiselect": True,
474+
"answers": [{"poll_media": {"text": "Answer 1"}}, {"poll_media": {"text": "Answer 2"}}],
475+
}
476+
poll_2_dict = poll_2.to_dict()
477+
for key in poll_2_dict.keys():
478+
assert poll_2_dict[key] == test_data_2[key]
479+
msg_2 = await thread.send(poll=poll_2)
480+
481+
assert msg_2.poll is not None
482+
assert msg_2.poll.question.to_dict() == PollMedia(text="Test Poll 2").to_dict()
483+
assert msg_2.poll.expiry <= msg_2.created_at + timedelta(hours=1, minutes=1)
484+
assert msg_2.poll.allow_multiselect
485+
poll_2_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_2.poll.answers]
486+
assert poll_2_answer_medias == [
487+
PollMedia.create(text="Answer 1").to_dict(),
488+
PollMedia.create(text="Answer 2").to_dict(),
489+
]
490+
491+
poll_3 = Poll.create(
492+
"Test Poll 3",
493+
duration=1,
494+
answers=[PollMedia.create(text="One", emoji="1️⃣"), PollMedia.create(text="Two", emoji="2️⃣")],
495+
)
496+
test_data_3 = {
497+
"question": {"text": "Test Poll 3"},
498+
"layout_type": 1,
499+
"duration": 1,
500+
"allow_multiselect": False,
501+
"answers": [
502+
{"poll_media": {"text": "One", "emoji": {"name": "1️⃣", "animated": False}}},
503+
{"poll_media": {"text": "Two", "emoji": {"name": "2️⃣", "animated": False}}},
504+
],
505+
}
506+
poll_3_dict = poll_3.to_dict()
507+
for key in poll_3_dict.keys():
508+
assert poll_3_dict[key] == test_data_3[key]
509+
510+
msg_3 = await thread.send(poll=poll_3)
511+
512+
assert msg_3.poll is not None
513+
assert msg_3.poll.question.to_dict() == PollMedia(text="Test Poll 3").to_dict()
514+
assert msg_3.poll.expiry <= msg_3.created_at + timedelta(hours=1, minutes=1)
515+
poll_3_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_3.poll.answers]
516+
assert poll_3_answer_medias == [
517+
PollMedia.create(text="One", emoji="1️⃣").to_dict(),
518+
PollMedia.create(text="Two", emoji="2️⃣").to_dict(),
519+
]
520+
521+
finally:
522+
with suppress(interactions.errors.NotFound):
523+
await thread.delete()
524+
525+
435526
@pytest.mark.asyncio
436527
async def test_webhooks(bot: Client, guild: Guild, channel: GuildText) -> None:
437528
test_thread = await channel.create_thread("Test Thread")

0 commit comments

Comments
 (0)