Skip to content

Commit b51a763

Browse files
committed
added blender
1 parent 72def94 commit b51a763

File tree

2 files changed

+210
-0
lines changed

2 files changed

+210
-0
lines changed

src/recommendations/blender.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import random
2+
from typing import Any
3+
4+
# Added to every engine weight before sampling, so an engine whose weight is
# 0 still has a tiny chance of being picked and the total weight of the
# remaining engines is always positive.
EPS = 1e-6


def blend(
    candidates_dict: dict[str, list[dict[str, Any]]],
    weights_dict: dict[str, float],
    fixed_pos: dict[int, str] = None,
    limit: int = 0,
    random_seed: int = 42
) -> list[dict[str, Any]]:
    """
    Blend candidates from multiple recommendation engines.

    Each output position is filled by sampling one engine (weighted by
    ``weights_dict``) and taking that engine's first remaining candidate.
    Once an item is emitted, every candidate with the same "id" is dropped
    from all engines, so an id appears at most once in the result. In
    addition, specific output positions can be pinned to specific engines.

    Args:
        candidates_dict: Maps an engine name to its ordered candidate list.
            Every candidate must have an "id" field. The mapping and its
            lists are not modified.
        weights_dict: Maps an engine name to its sampling weight. Must have
            exactly the same keys as candidates_dict. Weights need not sum
            to 1; ``random.choices`` expects them to be non-negative.
        fixed_pos: Optional mapping from a 0-based output position to the
            engine that must fill it. If that engine has no candidates left
            when the position is reached, an engine is sampled instead.
        limit: Maximum number of items to return. 0 means "the total number
            of candidates over all engines".
        random_seed: Seed of the random generator used for sampling; the
            same inputs and seed give the same output.

    Returns:
        The blended list of shallow copies of the chosen candidates. It may
        be shorter than ``limit`` when the engines run out of candidates.

    Raises:
        ValueError: If candidates_dict and weights_dict have different keys,
            or fixed_pos refers to an engine missing from candidates_dict.
    """
    # A private generator keeps the sampling reproducible without reseeding
    # the process-wide ``random`` state that other code may rely on.
    rng = random.Random(random_seed)

    # input validation and processing
    if set(candidates_dict) != set(weights_dict):
        raise ValueError('Keys in candidates_dict and weights_dict do not match')

    if fixed_pos:
        for engine in fixed_pos.values():
            if engine not in candidates_dict:
                raise ValueError(f'Engine {engine} does not present in candidates_dict')

    if limit == 0:
        limit = sum(len(candidates) for candidates in candidates_dict.values())

    # Work on copies of the candidate lists: they are shrunk as items are
    # emitted and the caller's data must stay intact.
    remaining = {
        engine: list(candidates)
        for engine, candidates in candidates_dict.items()
    }

    # ``engines`` only ever holds engines that still have candidates; it
    # keeps the key order of candidates_dict, which sampling depends on.
    engines = [engine for engine in remaining if len(remaining[engine]) > 0]
    weights = [(weights_dict[engine] + EPS) for engine in engines]

    res = []

    for res_idx in range(limit):
        if len(engines) == 0:
            break

        engine = None

        # A fixed position is honoured only while its engine is non-empty.
        if fixed_pos and res_idx in fixed_pos and fixed_pos[res_idx] in engines:
            engine = fixed_pos[res_idx]

        # Otherwise sample an engine proportionally to its weight.
        if engine is None:
            engine = rng.choices(population=engines, weights=weights)[0]

        next_item = remaining[engine][0].copy()
        res.append(next_item)

        # Drop the emitted id from every engine (including the chosen one)
        # in a single pass per engine.
        emitted_id = next_item['id']
        for name in engines:
            remaining[name] = [
                candidate for candidate in remaining[name]
                if not (candidate['id'] == emitted_id)
            ]

        # maintain non-empty engines and their matching weights
        engines = [name for name in engines if len(remaining[name]) > 0]
        weights = [(weights_dict[name] + EPS) for name in engines]

    return res

tests/recommendations/test_blender.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from src.recommendations.blender import blend
2+
3+
4+
def test_blender_few_candidates():
    """With no limit, two 2-item engines of equal weight blend into 4 items."""
    first_engine = [{'id': 1}, {'id': 2}]
    second_engine = [{'id': 3}, {'id': 4}]

    blended = blend(
        {'engine_1': first_engine, 'engine_2': second_engine},
        {'engine_1': 1, 'engine_2': 1},
    )

    assert len(blended) == 4
22+
23+
24+
def test_blender_item_intersection_and_zero_weight():
    """Shared id 3 is emitted once; the zero-weight engine supplies id 4 last."""
    candidates = {
        'engine_1': [{'id': item_id} for item_id in (1, 2, 3)],
        'engine_2': [{'id': item_id} for item_id in (3, 4)],
    }
    engine_weights = {'engine_1': 1, 'engine_2': 0}

    blended = blend(candidates, engine_weights)

    assert len(blended) == 4
    # Compare position by position, in the order the original checks them.
    for position, expected_id in enumerate((1, 2, 3, 4)):
        assert blended[position]['id'] == expected_id
48+
49+
50+
def test_blender_stats_test():
    """With weights 1:3, engine_1 fills the first slot in about 25% of runs."""
    candidates = {
        'engine_1': [{'id': 1}, {'id': 2}],
        'engine_2': [{'id': 3}, {'id': 4}],
    }
    engine_weights = {'engine_1': 1, 'engine_2': 3}

    n_iter = 10000
    # A different seed per run; count runs whose first item is engine_1's head.
    first_slot_hits = sum(
        1
        for run in range(n_iter)
        if blend(candidates, engine_weights, random_seed=(run + 10000))[0]['id'] == 1
    )

    observed_share = first_slot_hits / n_iter
    assert abs(observed_share - 0.25) < 0.01
73+
74+
75+
def test_blender_fixed_pos():
    """Position 0 pinned to the zero-weight engine_1 yields its first item."""
    candidates = {
        'engine_1': [{'id': 1}, {'id': 2}],
        'engine_2': [{'id': 3}, {'id': 4}],
    }
    engine_weights = {'engine_1': 0, 'engine_2': 1}
    pinned = {0: 'engine_1'}

    blended = blend(candidates, engine_weights, fixed_pos=pinned, random_seed=42)

    first_item = blended[0]
    assert first_item['id'] == 1
95+
96+
97+
def test_blender_limit():
    """limit=2 caps the blended output at two items."""
    candidates = {
        'engine_1': [{'id': 1}, {'id': 2}],
        'engine_2': [{'id': 3}, {'id': 4}],
    }
    engine_weights = dict.fromkeys(('engine_1', 'engine_2'), 1)

    blended = blend(candidates, engine_weights, limit=2)

    assert len(blended) == 2

0 commit comments

Comments
 (0)