Skip to content

Commit 4137eba

Browse files
committed
tweak rec algo
1 parent 78c93f9 commit 4137eba

File tree

3 files changed

+59
-36
lines changed

3 files changed

+59
-36
lines changed

src/recommendations/candidates.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,9 +598,9 @@ async def get_fast_dopamine(
598598
ON M.id = E.meme_id
599599
WHERE 1=1
600600
AND reaction_id = 1
601-
AND reacted_at - sent_at BETWEEN '0.5 second' AND '30 seconds'
601+
AND reacted_at - sent_at BETWEEN '2 seconds' AND '20 seconds'
602602
GROUP BY 1
603-
HAVING COUNT(*) >= 1
603+
HAVING COUNT(*) >= 3
604604
)
605605
606606
SELECT
@@ -624,7 +624,7 @@ async def get_fast_dopamine(
624624
ON R.meme_id = M.id
625625
AND R.user_id = {user_id}
626626
627-
LEFT JOIN MEME_SEC_TO_LIKE MSTL
627+
INNER JOIN MEME_SEC_TO_LIKE MSTL
628628
ON MSTL.meme_id = M.id
629629
630630
WHERE 1=1

src/recommendations/meme_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ async def generate_recommendations(user_id, limit):
8787

8888
r = random.random()
8989

90-
if r < 0.5:
90+
if r < 0.2:
9191
candidates = await get_fast_dopamine(
9292
user_id, limit=limit, exclude_meme_ids=meme_ids_in_queue
9393
)

tests/stats/test_meme.py

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,59 +5,83 @@
55
from sqlalchemy import delete, insert, select
66
from sqlalchemy.ext.asyncio import AsyncConnection
77

8-
from src.database import (engine, fetch_all, meme, meme_source, meme_stats, user,
9-
user_language, user_meme_reaction)
8+
from src.database import (
9+
engine,
10+
fetch_all,
11+
meme,
12+
meme_source,
13+
meme_stats,
14+
user,
15+
user_language,
16+
user_meme_reaction,
17+
)
1018
from src.stats.meme import calculate_meme_reactions_stats
1119

1220

1321
@pytest_asyncio.fixture()
1422
async def conn():
1523
async with engine.connect() as conn:
16-
1724
await conn.execute(
18-
insert(user),
19-
[{'id': 1, 'type': "user"}, {'id': 2, 'type': "user"}]
25+
insert(user), [{"id": 1, "type": "user"}, {"id": 2, "type": "user"}]
2026
)
2127
await conn.execute(
2228
insert(meme_source),
23-
{'id': 1, 'type': 'telegram', 'url': '111', 'status': 'parsing_enabled', 'created_at': datetime(2024, 1, 1)}
29+
{
30+
"id": 1,
31+
"type": "telegram",
32+
"url": "111",
33+
"status": "parsing_enabled",
34+
"created_at": datetime(2024, 1, 1),
35+
},
2436
)
2537

2638
meme_common = {
27-
'type': 'image', 'telegram_image_id': '111', 'caption': '111', 'meme_source_id': 1,
28-
'published_at': datetime(2024, 1, 1), 'status': 'ok', 'language_code': 'ru',
39+
"type": "image",
40+
"telegram_image_id": "111",
41+
"caption": "111",
42+
"meme_source_id": 1,
43+
"published_at": datetime(2024, 1, 1),
44+
"status": "ok",
45+
"language_code": "ru",
2946
}
3047
meme_ids = [1, 2, 3, 4, 5, 6]
3148
await conn.execute(
3249
insert(meme),
33-
[{'id': meme_id, 'raw_meme_id': meme_id, **meme_common} for meme_id in meme_ids]
50+
[
51+
{"id": meme_id, "raw_meme_id": meme_id, **meme_common}
52+
for meme_id in meme_ids
53+
],
3454
)
3555

36-
u_common = {'language_code': 'ru', 'created_at': datetime(2024, 1, 1)}
56+
u_common = {"language_code": "ru", "created_at": datetime(2024, 1, 1)}
3757
await conn.execute(
3858
insert(user_language),
3959
[
40-
{'user_id': 1, **u_common},
41-
{'user_id': 2, **u_common},
42-
]
60+
{"user_id": 1, **u_common},
61+
{"user_id": 2, **u_common},
62+
],
4363
)
44-
umr_common = {'recommended_by': '111', 'sent_at': datetime(2024, 1, 1), 'reacted_at': datetime(2024, 1, 1, 0, 10)}
64+
umr_common = {
65+
"recommended_by": "111",
66+
"sent_at": datetime(2024, 1, 1),
67+
"reacted_at": datetime(2024, 1, 1, 0, 10),
68+
}
4569
await conn.execute(
4670
insert(user_meme_reaction),
4771
[
48-
{'user_id': 1, 'meme_id': 1, 'reaction_id': 1, **umr_common},
49-
{'user_id': 1, 'meme_id': 2, 'reaction_id': 1, **umr_common},
50-
{'user_id': 1, 'meme_id': 3, 'reaction_id': 1, **umr_common},
51-
{'user_id': 1, 'meme_id': 4, 'reaction_id': 1, **umr_common},
52-
{'user_id': 1, 'meme_id': 5, 'reaction_id': 1, **umr_common},
53-
{'user_id': 1, 'meme_id': 6, 'reaction_id': 2, **umr_common},
54-
{'user_id': 2, 'meme_id': 1, 'reaction_id': 1, **umr_common},
55-
{'user_id': 2, 'meme_id': 2, 'reaction_id': 2, **umr_common},
56-
{'user_id': 2, 'meme_id': 3, 'reaction_id': 2, **umr_common},
57-
{'user_id': 2, 'meme_id': 4, 'reaction_id': 2, **umr_common},
58-
{'user_id': 2, 'meme_id': 5, 'reaction_id': 2, **umr_common},
59-
{'user_id': 2, 'meme_id': 6, 'reaction_id': 2, **umr_common},
60-
]
72+
{"user_id": 1, "meme_id": 1, "reaction_id": 1, **umr_common},
73+
{"user_id": 1, "meme_id": 2, "reaction_id": 1, **umr_common},
74+
{"user_id": 1, "meme_id": 3, "reaction_id": 1, **umr_common},
75+
{"user_id": 1, "meme_id": 4, "reaction_id": 1, **umr_common},
76+
{"user_id": 1, "meme_id": 5, "reaction_id": 1, **umr_common},
77+
{"user_id": 1, "meme_id": 6, "reaction_id": 2, **umr_common},
78+
{"user_id": 2, "meme_id": 1, "reaction_id": 1, **umr_common},
79+
{"user_id": 2, "meme_id": 2, "reaction_id": 2, **umr_common},
80+
{"user_id": 2, "meme_id": 3, "reaction_id": 2, **umr_common},
81+
{"user_id": 2, "meme_id": 4, "reaction_id": 2, **umr_common},
82+
{"user_id": 2, "meme_id": 5, "reaction_id": 2, **umr_common},
83+
{"user_id": 2, "meme_id": 6, "reaction_id": 2, **umr_common},
84+
],
6185
)
6286

6387
await conn.commit()
@@ -72,7 +96,6 @@ async def conn():
7296
await conn.commit()
7397

7498

75-
7699
@pytest.mark.asyncio
77100
async def test_calculate_meme_reactions_stats(conn: AsyncConnection):
78101
await calculate_meme_reactions_stats(min_meme_reactions=0, min_user_reactions=0)
@@ -84,7 +107,7 @@ async def test_calculate_meme_reactions_stats(conn: AsyncConnection):
84107

85108
eps = 1e-3
86109
for row in res:
87-
if row['meme_id'] == 1:
88-
assert abs(row['lr_smoothed'] - 1) < eps
89-
if row['meme_id'] == 2:
90-
assert abs(row['lr_smoothed']) < eps
110+
if row["meme_id"] == 1:
111+
assert abs(row["lr_smoothed"] - 1) < eps
112+
if row["meme_id"] == 2:
113+
assert abs(row["lr_smoothed"]) < eps

0 commit comments

Comments
 (0)