added generate_with_blender
ledovsky committed Sep 5, 2024
1 parent b51a763 commit 48a5c1e
Showing 3 changed files with 391 additions and 5 deletions.
39 changes: 39 additions & 0 deletions src/recommendations/candidates.py
@@ -643,3 +643,42 @@ async def get_fast_dopamine(
"""
res = await fetch_all(text(query))
return res


class CandidatesRetriever:
"""CandidatesRetriever class is used for unit testing"""

engine_map = {
'less_seen_meme_and_source': less_seen_meme_and_source,
'uploaded_memes': uploaded_memes,
'fast_dopamine': get_fast_dopamine,
'lr_smoothed': get_lr_smoothed,
'selected_sources': get_selected_sources,
'best_memes_from_each_source': get_best_memes_from_each_source,
}

async def get_candidates(
self,
engine: str,
user_id: int,
limit: int = 10,
exclude_mem_ids: list[int] = []
) -> list[dict[str, Any]]:
if engine not in self.engine_map:
raise ValueError(f'engine {engine} is not supported')

return await self.engine_map[engine](user_id, limit, exclude_mem_ids)

async def get_candidates_dict(
self,
engines: list[str],
user_id: int,
limit: int = 10,
exclude_mem_ids: list[int] = []
) -> dict[str, list[dict[str, Any]]]:
candidates_dict = {}
for engine in engines:
candidates_dict[engine] = await self.get_candidates(
engine, user_id, limit, exclude_mem_ids)

return candidates_dict
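
Since dispatch goes through the `engine_map` class attribute, a test can swap in fake engines and exercise the retrieval path without touching the database. A minimal sketch of how that might look (the stub engine, its return values, and the demo below are illustrative assumptions, not part of this commit):

import asyncio
from typing import Any

from src.recommendations.candidates import CandidatesRetriever


async def fake_engine(
    user_id: int, limit: int, exclude_mem_ids: list[int]
) -> list[dict[str, Any]]:
    # Stand-in for a real engine: returns predictable candidates
    return [{"id": i} for i in range(limit) if i not in exclude_mem_ids]


class StubRetriever(CandidatesRetriever):
    # Override the dispatch table so no real queries run
    engine_map = {'lr_smoothed': fake_engine}


async def demo() -> None:
    retriever = StubRetriever()
    candidates = await retriever.get_candidates('lr_smoothed', user_id=1, limit=3)
    assert [c["id"] for c in candidates] == [0, 1, 2]


asyncio.run(demo())

Because the engines live in a plain dict rather than as bound methods, overriding `engine_map` in a subclass is enough; no monkey-patching of the engine modules is needed.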
109 changes: 104 additions & 5 deletions src/recommendations/meme_queue.py
@@ -1,6 +1,8 @@
import random
from typing import Any, Optional

from src import redis
from src.recommendations.blender import blend
from src.recommendations.candidates import (
get_best_memes_from_each_source,
get_fast_dopamine,
@@ -9,6 +11,7 @@
less_seen_meme_and_source,
like_spread_and_recent_memes,
uploaded_memes,
CandidatesRetriever,
)
from src.storage.schemas import MemeData
from src.tgbot.user_info import get_user_info
@@ -76,7 +79,12 @@ async def generate_cold_start_recommendations(user_id, limit=10):
await redis.add_memes_to_queue_by_key(queue_key, candidates)


async def generate_recommendations(user_id: int, limit: int):

# A/B gate: roughly half of user ids (by modulo 100) take the blender path
if (user_id + 50) % 100 < 50:
await generate_with_blender(user_id, limit)
return

queue_key = redis.get_meme_queue_key(user_id)
memes_in_queue = await redis.get_all_memes_in_queue_by_key(queue_key)
meme_ids_in_queue = [meme["id"] for meme in memes_in_queue]
@@ -103,8 +111,7 @@ async def generate_recommendations(user_id, limit):
)

elif user_info["nmemes_sent"] < 100:
if r < 0.2:
candidates = await uploaded_memes(
user_id, limit=limit, exclude_meme_ids=meme_ids_in_queue
)
elif r < 0.4:
@@ -141,8 +148,7 @@ async def generate_recommendations(user_id, limit):

if len(candidates) == 0 and user_info["nmemes_sent"] > 1000:
candidates = await less_seen_meme_and_source(
user_id, limit=limit, exclude_meme_ids=meme_ids_in_queue
)

if len(candidates) == 0:
# TODO: fallback to some algo which will always return something
@@ -153,3 +159,96 @@ async def generate_recommendations(user_id, limit):
# select the best LIMIT memes -> save them to queue

await redis.add_memes_to_queue_by_key(queue_key, candidates)


async def generate_with_blender(
user_id: int,
limit: int,
nmemes_sent: Optional[int] = None,
retriever: Optional[CandidatesRetriever] = None,
random_seed: int = 42,
) -> list[dict[str, Any]]:
"""Uses blender to mix candidates from different engines
The function aims to keep the same logic as generate_candidates but
with blending.
Will be refactored
"""

if nmemes_sent is None:
user_info = await get_user_info(user_id)
nmemes_sent = user_info['nmemes_sent']

queue_key = redis.get_meme_queue_key(user_id)

memes_in_queue = await redis.get_all_memes_in_queue_by_key(queue_key)
meme_ids_in_queue = [meme["id"] for meme in memes_in_queue]

if retriever is None:
retriever = CandidatesRetriever()

async def get_candidates(user_id, limit):
"""A helper function to avoid copy-paste"""

# Fewer than 30 memes sent is treated as cold start: no blending
if nmemes_sent < 30:
candidates = await retriever.get_candidates(
'fast_dopamine', user_id, limit,
exclude_mem_ids=meme_ids_in_queue)

if len(candidates) == 0:
candidates = await retriever.get_candidates(
'best_memes_from_each_source', user_id, limit,
exclude_mem_ids=meme_ids_in_queue)

return candidates

if nmemes_sent < 100:
weights = {
'uploaded_memes': 0.2,
'fast_dopamine': 0.2,
'best_memes_from_each_source': 0.2,
'lr_smoothed': 0.4,
}

engines = ['uploaded_memes', 'fast_dopamine',
'best_memes_from_each_source', 'lr_smoothed']
candidates_dict = await retriever.get_candidates_dict(
engines, user_id, limit, exclude_mem_ids=meme_ids_in_queue)

fixed_pos = {0: 'lr_smoothed', 1: 'lr_smoothed'}
return blend(candidates_dict, weights, fixed_pos, limit, random_seed)

# >=100
weights = {
'uploaded_memes': 0.3,
'like_spread_and_recent_memes': 0.3,
'lr_smoothed': 0.4,
}

engines = ['uploaded_memes', 'like_spread_and_recent_memes', 'lr_smoothed']
candidates_dict = await retriever.get_candidates_dict(
engines, user_id, limit, exclude_mem_ids=meme_ids_in_queue)

fixed_pos = {0: 'lr_smoothed', 1: 'lr_smoothed'}
candidates = blend(candidates_dict, weights, fixed_pos, limit, random_seed)

if len(candidates) == 0 and nmemes_sent > 1000:
candidates = await retriever.get_candidates(
'less_seen_meme_and_source', user_id, limit,
exclude_mem_ids=meme_ids_in_queue)

if len(candidates) == 0:
candidates = await retriever.get_candidates(
'best_memes_from_each_source', user_id, limit,
exclude_mem_ids=meme_ids_in_queue)

return candidates

candidates = await get_candidates(user_id, limit)
if len(candidates) > 0:
await redis.add_memes_to_queue_by_key(queue_key, candidates)

return candidates
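
The third changed file — presumably the `src.recommendations.blender` module that provides `blend` — did not render in this view. From the call sites above, `blend` receives per-engine candidate lists, per-engine weights, a `fixed_pos` map pinning slots to an engine, a limit, and a random seed. A sketch of an interface-compatible implementation, offered as an assumption rather than the committed code:

import random
from typing import Any, Optional


def blend(
    candidates_dict: dict[str, list[dict[str, Any]]],
    weights: dict[str, float],
    fixed_pos: dict[int, str],
    limit: int,
    random_seed: int = 42,
) -> list[dict[str, Any]]:
    """Sketch: pin fixed positions to their engine, fill remaining slots
    by weighted sampling over engines, and skip duplicate meme ids."""
    rng = random.Random(random_seed)
    pools = {engine: list(memes) for engine, memes in candidates_dict.items()}
    seen_ids: set[int] = set()
    result: list[dict[str, Any]] = []

    def take(engine: str) -> Optional[dict[str, Any]]:
        # Pop candidates until we find an id that is not used yet
        pool = pools.get(engine, [])
        while pool:
            meme = pool.pop(0)
            if meme["id"] not in seen_ids:
                seen_ids.add(meme["id"])
                return meme
        return None

    for pos in range(limit):
        meme = take(fixed_pos[pos]) if pos in fixed_pos else None
        if meme is None:
            live = [e for e in pools if pools[e]]
            if not live:
                break
            # Weighted choice among engines that still have candidates
            engine = rng.choices(
                live, weights=[max(weights.get(e, 0.0), 1e-9) for e in live]
            )[0]
            meme = take(engine)
        if meme is not None:
            result.append(meme)
    return result

Seeding with `random_seed` (42 at both call sites) keeps the blend deterministic for a given set of candidates, which fits the unit-testing intent behind `CandidatesRetriever`.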
