generated from zhanymkanov/fastapi_production_template
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
210 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import random | ||
from typing import Any | ||
|
||
EPS = 1e-6 | ||
|
||
|
||
def blend( | ||
candidates_dict: dict[str, list[dict[str, Any]]], | ||
weights_dict: dict[str, float], | ||
fixed_pos: dict[int, str] = None, | ||
limit: int = 0, | ||
random_seed: int = 42 | ||
) -> list[dict[str, Any]]: | ||
""" | ||
Blends candidates from multiple recommendation engines. Blending is implemented | ||
as sampling with weights. Besides of that, it is possible to set fixed engines | ||
to some positions. | ||
Args: | ||
- candidates_dict: Contains recommendation engine names with their outputs | ||
Items in candidate lists must have "id" field | ||
- weights_dict: Contains weights for each engine. Should have the same keys | ||
as candidates_dict. Weights may not sum to 1 | ||
- fixed_pos: Allows to set fixed engines to provided positions. Starts from 0 | ||
- limit | ||
- random_seed | ||
""" | ||
|
||
random.seed(random_seed) | ||
|
||
# input validation and processing | ||
if set(candidates_dict.keys()) != set(weights_dict.keys()): | ||
raise ValueError('Keys in candidates_dict and weights_dict do not match') | ||
|
||
if fixed_pos: | ||
for engine in fixed_pos.values(): | ||
if engine not in candidates_dict: | ||
raise ValueError(f'Engine {engine} does not present in candidates_dict') | ||
|
||
if limit == 0: | ||
for candidates in candidates_dict.values(): | ||
limit += len(candidates) | ||
|
||
# candidates_dict will be changed inplace further | ||
candidates_dict = candidates_dict.copy() | ||
for engine in candidates_dict.keys(): | ||
candidates_dict[engine] = candidates_dict[engine].copy() | ||
|
||
# engines list is ensured to have non-empty engines | ||
engines = [ | ||
engine for engine in candidates_dict.keys() | ||
if len(candidates_dict[engine]) > 0 | ||
] | ||
|
||
weights = [ (weights_dict[engine] + EPS) for engine in engines] | ||
if len(engines) == 0: | ||
return [] | ||
|
||
res = [] | ||
|
||
for res_idx in range(limit): | ||
engine = None | ||
|
||
# process fixed positions | ||
if fixed_pos and res_idx in fixed_pos: | ||
engine = fixed_pos[res_idx] if fixed_pos[res_idx] in engines else None | ||
|
||
# sample engine | ||
if engine is None: | ||
engine = random.choices(population=engines, weights=weights)[0] | ||
|
||
next_item = candidates_dict[engine][0].copy() | ||
res.append(next_item) | ||
|
||
# process candidates intersection | ||
for engine in engines: | ||
# remove all matches with next_item | ||
stop = False | ||
while not stop: | ||
stop = True | ||
for idx in range(len(candidates_dict[engine])): | ||
if next_item['id'] == candidates_dict[engine][idx]['id']: | ||
candidates_dict[engine].pop(idx) | ||
stop = False | ||
break | ||
|
||
# maintain non-empty engines | ||
engines = [engine for engine in engines if len(candidates_dict[engine]) > 0] | ||
weights = [ (weights_dict[engine] + EPS) for engine in engines ] | ||
|
||
if len(engines) == 0: | ||
break | ||
|
||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from src.recommendations.blender import blend | ||
|
||
|
||
def test_blender_few_candidates(): | ||
candidates_dict = { | ||
'engine_1': [ | ||
{'id': 1}, | ||
{'id': 2}, | ||
], | ||
'engine_2': [ | ||
{'id': 3}, | ||
{'id': 4}, | ||
], | ||
} | ||
weights_dict = { | ||
'engine_1': 1, | ||
'engine_2': 1, | ||
} | ||
|
||
res = blend(candidates_dict, weights_dict) | ||
assert(len(res) == 4) | ||
|
||
|
||
def test_blender_item_intersection_and_zero_weight(): | ||
|
||
candidates_dict = { | ||
'engine_1': [ | ||
{'id': 1}, | ||
{'id': 2}, | ||
{'id': 3}, | ||
], | ||
'engine_2': [ | ||
{'id': 3}, | ||
{'id': 4}, | ||
], | ||
} | ||
weights_dict = { | ||
'engine_1': 1, | ||
'engine_2': 0, | ||
} | ||
|
||
res = blend(candidates_dict, weights_dict) | ||
assert(len(res) == 4) | ||
assert res[0]['id'] == 1 | ||
assert res[1]['id'] == 2 | ||
assert res[2]['id'] == 3 | ||
assert res[3]['id'] == 4 | ||
|
||
|
||
def test_blender_stats_test(): | ||
candidates_dict = { | ||
'engine_1': [ | ||
{'id': 1}, | ||
{'id': 2}, | ||
], | ||
'engine_2': [ | ||
{'id': 3}, | ||
{'id': 4}, | ||
], | ||
} | ||
weights_dict = { | ||
'engine_1': 1, | ||
'engine_2': 3, | ||
} | ||
|
||
engine_1_cnt = 0 | ||
n_iter = 10000 | ||
for i in range(n_iter): | ||
res = blend(candidates_dict, weights_dict, random_seed=(i + 10000)) | ||
if res[0]['id'] == 1: | ||
engine_1_cnt += 1 | ||
assert abs(engine_1_cnt / n_iter - 0.25) < 0.01 | ||
|
||
|
||
def test_blender_fixed_pos(): | ||
candidates_dict = { | ||
'engine_1': [ | ||
{'id': 1}, | ||
{'id': 2}, | ||
], | ||
'engine_2': [ | ||
{'id': 3}, | ||
{'id': 4}, | ||
], | ||
} | ||
weights_dict = { | ||
'engine_1': 0, | ||
'engine_2': 1, | ||
} | ||
|
||
fixed_pos = {0: 'engine_1'} | ||
|
||
res = blend(candidates_dict, weights_dict, fixed_pos=fixed_pos, random_seed=42) | ||
assert res[0]['id'] == 1 | ||
|
||
|
||
def test_blender_limit(): | ||
|
||
candidates_dict = { | ||
'engine_1': [ | ||
{'id': 1}, | ||
{'id': 2}, | ||
], | ||
'engine_2': [ | ||
{'id': 3}, | ||
{'id': 4}, | ||
], | ||
} | ||
weights_dict = { | ||
'engine_1': 1, | ||
'engine_2': 1, | ||
} | ||
|
||
res = blend(candidates_dict, weights_dict, limit=2) | ||
assert len(res) == 2 |