From 48a5c1e91ad7523da2353d33517b0d8199b455d8 Mon Sep 17 00:00:00 2001 From: Alexander Ledovsky Date: Thu, 5 Sep 2024 22:17:44 +0300 Subject: [PATCH] added generate_with_blender --- src/recommendations/candidates.py | 39 ++++ src/recommendations/meme_queue.py | 109 +++++++++- tests/recommendations/test_meme_queue.py | 248 +++++++++++++++++++++++ 3 files changed, 391 insertions(+), 5 deletions(-) create mode 100644 tests/recommendations/test_meme_queue.py diff --git a/src/recommendations/candidates.py b/src/recommendations/candidates.py index 72d7fc8f..3476894c 100644 --- a/src/recommendations/candidates.py +++ b/src/recommendations/candidates.py @@ -643,3 +643,42 @@ async def get_fast_dopamine( """ res = await fetch_all(text(query)) return res + + +class CandidatesRetriever: + """CandidatesRetriever class is used for unit testing""" + + engine_map = { + 'less_seen_meme_and_source': less_seen_meme_and_source, + 'uploaded_memes': uploaded_memes, + 'fast_dopamine': get_fast_dopamine, + 'lr_smoothed': get_lr_smoothed, + 'selected_sources': get_selected_sources, + 'best_memes_from_each_source': get_best_memes_from_each_source, + } + + async def get_candidates( + self, + engine: str, + user_id: int, + limit: int = 10, + exclude_mem_ids: list[int] = [] + ) -> list[dict[str, Any]]: + if engine not in self.engine_map: + raise ValueError(f'engine {engine} is not supported') + + return await self.engine_map[engine](user_id, limit, exclude_mem_ids) + + async def get_candidates_dict( + self, + engines: list[str], + user_id: int, + limit: int = 10, + exclude_mem_ids: list[int] = [] + ) -> dict[str, list[dict[str, Any]]]: + candidates_dict = {} + for engine in engines: + candidates_dict[engine] = await self.get_candidates( + engine, user_id, limit, exclude_mem_ids) + + return candidates_dict \ No newline at end of file diff --git a/src/recommendations/meme_queue.py b/src/recommendations/meme_queue.py index 32b5711a..74463e5d 100644 --- a/src/recommendations/meme_queue.py +++ b/src/recommendations/meme_queue.py @@ -1,6 +1,8 @@ import random +from typing import Any, Optional from src import redis +from src.recommendations.blender import blend from src.recommendations.candidates import ( get_best_memes_from_each_source, get_fast_dopamine, @@ -9,6 +11,7 @@ less_seen_meme_and_source, like_spread_and_recent_memes, uploaded_memes, + CandidatesRetriever, ) from src.storage.schemas import MemeData from src.tgbot.user_info import get_user_info @@ -76,7 +79,12 @@ async def generate_cold_start_recommendations(user_id, limit=10): await redis.add_memes_to_queue_by_key(queue_key, candidates) -async def generate_recommendations(user_id, limit): +async def generate_recommendations(user_id: int, limit: int): + + if (user_id + 50) % 100 < 50: + generate_with_blender(user_id, limit) + return + queue_key = redis.get_meme_queue_key(user_id) memes_in_queue = await redis.get_all_memes_in_queue_by_key(queue_key) meme_ids_in_queue = [meme["id"] for meme in memes_in_queue] @@ -103,8 +111,7 @@ async def generate_recommendations(user_id, limit): ) elif user_info["nmemes_sent"] < 100: - if r < 0.2: - candidates = await uploaded_memes( + if r < 0.2: candidates = await uploaded_memes( user_id, limit=limit, exclude_meme_ids=meme_ids_in_queue ) elif r < 0.4: @@ -141,8 +148,7 @@ async def generate_recommendations(user_id, limit): if len(candidates) == 0 and user_info["nmemes_sent"] > 1000: candidates = await less_seen_meme_and_source( - user_id, limit=limit, exclude_meme_ids=meme_ids_in_queue - ) + user_id, limit=limit, exclude_meme_ids=meme_ids_in_queue) if len(candidates) == 0: # TODO: fallback to some algo which will always return something @@ -153,3 +159,96 @@ async def generate_recommendations(user_id, limit): # select the best LIMIT memes -> save them to queue await redis.add_memes_to_queue_by_key(queue_key, candidates) + + +async def generate_with_blender( + user_id: int, + limit: int, + nmemes_sent: Optional[int] = None, + retriever: Optional[CandidatesRetriever] = None, + random_seed: int = 42, + ) -> list[dict[str, Any]]: + """Uses blender to mix candidates from different engines + + The function aims to keep the same logic as generate_candidates but + with blending. + + Will be refactored + """ + + if nmemes_sent is None: + user_info = await get_user_info(user_id) + nmemes_sent = user_info['nmemes_sent'] + + queue_key = redis.get_meme_queue_key(user_id) + + meme_ids_in_queue = [] + memes_in_queue = await redis.get_all_memes_in_queue_by_key(queue_key) + meme_ids_in_queue = [meme["id"] for meme in memes_in_queue] + + if retriever is None: + retriever = CandidatesRetriever() + + async def get_candidates(user_id, limit): + """A helper function to avoid copy-paste""" + + # <30 is treated as cold start. no blending + if nmemes_sent < 30: + candidates = await retriever.get_candidates( + 'fast_dopamine', user_id, limit, exclude_mem_ids=meme_ids_in_queue) + + + if len(candidates) == 0: + candidates = await retriever.get_candidates( + 'best_memes_from_each_source', user_id, limit, exclude_mem_ids=meme_ids_in_queue) + + return candidates + + if nmemes_sent < 100: + weights = { + 'uploaded_memes': 0.2, + 'fast_dopamine': 0.2, + 'best_memes_from_each_source': 0.2, + 'lr_smoothed': 0.4, + } + + engines = ['uploaded_memes', 'fast_dopamine', + 'best_memes_from_each_source', 'lr_smoothed'] + candidates_dict = await retriever.get_candidates_dict( + engines, user_id, limit, exclude_mem_ids=meme_ids_in_queue) + + fixed_pos = {0: 'lr_smoothed', 1: 'lr_smoothed'} + return blend(candidates_dict, weights, fixed_pos, limit, random_seed) + + # >=100 + weights = { + 'uploaded_memes': 0.3, + 'like_spread_and_recent_memes': 0.3, + 'lr_smoothed': 0.4, + } + + engines = ['uploaded_memes', 'like_spread_and_recent_memes', 'lr_smoothed'] + candidates_dict = await retriever.get_candidates_dict( + engines, user_id, limit, exclude_mem_ids=meme_ids_in_queue) + + fixed_pos = {0: 'lr_smoothed', 1: 'lr_smoothed'} + candidates = blend(candidates_dict, weights, fixed_pos, limit, random_seed) + + if len(candidates) == 0 and nmemes_sent > 1000: + candidates = await retriever.get_candidates( + 'less_seen_meme_and_source', user_id, limit, + exclude_mem_ids=meme_ids_in_queue) + + if len(candidates) == 0: + candidates = await retriever.get_candidates( + 'best_memes_from_each_source', user_id, limit, + exclude_mem_ids=meme_ids_in_queue) + + return candidates + + candidates = await get_candidates(user_id, limit) + if len(candidates) > 0: + await redis.add_memes_to_queue_by_key(queue_key, candidates) + + return candidates + diff --git a/tests/recommendations/test_meme_queue.py b/tests/recommendations/test_meme_queue.py new file mode 100644 index 00000000..c5636f11 --- /dev/null +++ b/tests/recommendations/test_meme_queue.py @@ -0,0 +1,248 @@ +from typing import Any + +import pytest + +from src.recommendations.candidates import CandidatesRetriever +from src.recommendations.meme_queue import generate_with_blender + + +@pytest.mark.asyncio +async def test_generate_with_blender_below_30(): + async def get_fast_dopamine( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 1}, + {'id': 2}, + ] + + async def get_fast_dopamine_empty( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [] + + async def get_best_memes_from_each_source( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 3}, + {'id': 4}, + ] + + class TestRetriever(CandidatesRetriever): + engine_map = { + 'fast_dopamine': get_fast_dopamine, + 'best_meme_from_each_source': get_best_memes_from_each_source, + } + + candidates = await generate_with_blender(1, 10, 10, TestRetriever()) + assert len(candidates) == 2 + assert candidates[0]['id'] == 1 + assert candidates[1]['id'] == 2 + + class TestRetriever(CandidatesRetriever): + engine_map = { + 'fast_dopamine': get_fast_dopamine_empty, + 'best_memes_from_each_source': get_best_memes_from_each_source, + } + + candidates = await generate_with_blender(1, 10, 10, TestRetriever()) + assert len(candidates) == 2 + assert candidates[0]['id'] == 3 + assert candidates[1]['id'] == 4 + + +@pytest.mark.asyncio +async def test_generate_with_blender_below_100(): + async def uploaded_memes( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 1}, + {'id': 2}, + ] + async def get_fast_dopamine( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 3}, + {'id': 4}, + ] + async def get_best_memes_from_each_source( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 5}, + {'id': 6}, + ] + async def get_lr_smoothed( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 7}, + {'id': 8}, + {'id': 9}, + {'id': 10}, + ] + + + class TestRetriever(CandidatesRetriever): + engine_map = { + 'uploaded_memes': uploaded_memes, + 'fast_dopamine': get_fast_dopamine, + 'best_memes_from_each_source': get_best_memes_from_each_source, + 'lr_smoothed': get_lr_smoothed, + } + + candidates = await generate_with_blender(1, 10, 40, TestRetriever()) + assert len(candidates) == 10 + # hardcoded values + assert candidates[0]['id'] == 7 + assert candidates[1]['id'] == 8 + assert candidates[2]['id'] == 9 + assert candidates[3]['id'] == 1 + assert candidates[4]['id'] == 3 + assert candidates[5]['id'] == 4 + +@pytest.mark.asyncio +async def test_generate_with_blender_above_100(): + async def uploaded_memes( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 1}, + {'id': 2}, + {'id': 3}, + ] + async def like_spread_and_recent_memes( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 4}, + {'id': 5}, + {'id': 6}, + ] + async def get_lr_smoothed( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 7}, + {'id': 8}, + {'id': 9}, + {'id': 10}, + ] + + + class TestRetriever(CandidatesRetriever): + engine_map = { + 'uploaded_memes': uploaded_memes, + 'like_spread_and_recent_memes': like_spread_and_recent_memes, + 'lr_smoothed': get_lr_smoothed, + } + + candidates = await generate_with_blender(1, 10, 200, TestRetriever(), random_seed=102) + assert len(candidates) == 10 + # hardcoded values + assert candidates[0]['id'] == 7 + assert candidates[1]['id'] == 8 + assert candidates[2]['id'] == 1 + assert candidates[3]['id'] == 9 + assert candidates[4]['id'] == 2 + assert candidates[5]['id'] == 10 + + +@pytest.mark.asyncio +async def test_generate_with_blender_empty_above_100(): + async def uploaded_memes( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [] + + async def like_spread_and_recent_memes( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [] + + async def get_lr_smoothed( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [] + + async def top_memes_from_less_seen_sources( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 1}, + {'id': 2}, + ] + + async def get_best_memes_from_each_source( + self, + user_id: int, + limit: int = 10, + exclude_meme_ids: list[int] = [], + ) -> list[dict[str, Any]]: + return [ + {'id': 3}, + {'id': 4}, + ] + + + class TestRetriever(CandidatesRetriever): + engine_map = { + 'uploaded_memes': uploaded_memes, + 'like_spread_and_recent_memes': like_spread_and_recent_memes, + 'lr_smoothed': get_lr_smoothed, + 'less_seen_meme_and_source': top_memes_from_less_seen_sources, + 'best_memes_from_each_source': get_best_memes_from_each_source, + } + + candidates = await generate_with_blender(1, 10, 200, TestRetriever()) + assert len(candidates) == 2 + assert candidates[0]['id'] == 3 + + candidates = await generate_with_blender(1, 10, 1200, TestRetriever()) + assert len(candidates) == 2 + assert candidates[0]['id'] == 1 \ No newline at end of file