Commit 474bcbc

Add bootstrap for elo variance estimation

1 parent c990362


kaggle_environments/envs/werewolf/eval/metrics.py

Lines changed: 84 additions & 39 deletions
@@ -322,6 +322,7 @@ def __init__(self, agent_name: str, openskill_model):
 
         # Ratings
         self.elo: float = 1200.0
+        self.elo_std: float = 0.0
         self.openskill_model = openskill_model
         self.openskill_rating = None

@@ -387,6 +388,7 @@ def __init__(self, input_dir: Union[str, List[str]], gte_tasks: List[str] = None
         self.gte_game = None
         self.gte_joint = None
         self.gte_ratings = None
+        self.gte_marginals = None
         self.gte_contributions_raw = None
 
         if gte_tasks is None:
@@ -398,7 +400,66 @@ def __init__(self, input_dir: Union[str, List[str]], gte_tasks: List[str] = None
         else:
             self.gte_tasks = gte_tasks
 
-    def evaluate(self, gte_samples=3):
+    def _compute_elo_ratings(self, games: List) -> Dict[str, float]:
+        """Computes Elo ratings for a given sequence of games."""
+        elos = defaultdict(lambda: 1200.0)
+
+        for game in games:
+            villager_agents = []
+            werewolf_agents = []
+
+            for player in game.players:
+                agent_name = player.agent.display_name
+                if player.role.team == Team.VILLAGERS:
+                    villager_agents.append(agent_name)
+                else:
+                    werewolf_agents.append(agent_name)
+
+            v_elos = [elos[a] for a in villager_agents]
+            w_elos = [elos[a] for a in werewolf_agents]
+
+            if v_elos and w_elos:
+                avg_v_elo = np.mean(v_elos)
+                avg_w_elo = np.mean(w_elos)
+
+                result_v = 1 if game.winner_team == Team.VILLAGERS else 0
+
+                for agent in villager_agents:
+                    elos[agent] += calculate_elo_change(elos[agent], avg_w_elo, result_v)
+
+                for agent in werewolf_agents:
+                    elos[agent] += calculate_elo_change(elos[agent], avg_v_elo, 1 - result_v)
+
+        return elos
+
+    def _bootstrap_elo(self, num_samples=100):
+        """Estimates Elo standard error via bootstrapping."""
+        if not self.games:
+            return
+
+        rnd = np.random.default_rng(42)
+        bootstrapped_elos = defaultdict(list)
+
+        # We need the full agent list up front, in case an agent isn't drawn in a sample.
+        all_agents = list(self.metrics.keys())
+
+        for _ in range(num_samples):
+            sampled_games = rnd.choice(self.games, size=len(self.games), replace=True)
+            sample_elos = self._compute_elo_ratings(sampled_games)
+
+            for agent in all_agents:
+                # If an agent wasn't drawn in this resample, they keep the
+                # 1200 default. Skipping them instead would avoid biasing the
+                # spread for agents that rarely play, and tracking only games
+                # actually played would be more precise still; for simplicity
+                # we take the computed value, falling back to 1200.
+                bootstrapped_elos[agent].append(sample_elos.get(agent, 1200.0))
+
+        for agent, values in bootstrapped_elos.items():
+            if len(values) > 1:
+                self.metrics[agent].elo_std = float(np.std(values, ddof=1))
+
+    def evaluate(self, gte_samples=3, elo_samples=100):
         """Processes all games and aggregates the metrics."""
         for game in self.games:
             # --- Win Rate & Survival Metrics ---
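
The loops in `_compute_elo_ratings` lean on a `calculate_elo_change` helper that this diff doesn't show. For reference, a minimal sketch of what such a helper conventionally looks like, assuming the textbook logistic Elo expectation and a K-factor of 32 (both assumptions; the helper in metrics.py may differ):

    # Hypothetical stand-in for the calculate_elo_change helper used above.
    # The formula and K-factor are the standard Elo defaults, not necessarily
    # what metrics.py implements.
    def calculate_elo_change(rating: float, opponent_rating: float,
                             result: float, k: float = 32.0) -> float:
        """Rating delta for one game (result: 1 = win, 0 = loss, 0.5 = draw)."""
        expected = 1.0 / (1.0 + 10 ** ((opponent_rating - rating) / 400.0))
        return k * (result - expected)

With result_v being 1 or 0, each villager is scored against the werewolves' average rating and vice versa, so a team game reduces to pairwise updates against a pooled opponent.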
@@ -422,22 +483,25 @@ def evaluate(self, gte_samples=3):
         for agent_name, score in vss_results:
             self.metrics[agent_name].vss_scores.append(score)
 
-        # --- Rating Updates ---
-        for game in self.games:
-            villager_agents = []
-            werewolf_agents = []
-
-            for player in game.players:
-                agent_name = player.agent.display_name
-                if player.role.team == Team.VILLAGERS:
-                    villager_agents.append(agent_name)
-                else:
-                    werewolf_agents.append(agent_name)
-
-            # TrueSkill Update
-            if OPENSKILL_AVAILABLE:
-                # openskill expects [[r1, r2], [r3, r4]]
-                # ranks=[0, 1] means first team won.
+        # --- Rating Updates (Point Estimates) ---
+        # 1. Elo
+        final_elos = self._compute_elo_ratings(self.games)
+        for agent, rating in final_elos.items():
+            self.metrics[agent].elo = rating
+
+        # 2. TrueSkill (OpenSkill)
+        # OpenSkill is also order dependent, but we run it just once, sequentially.
+        if OPENSKILL_AVAILABLE and self.openskill_model:
+            for game in self.games:
+                villager_agents = []
+                werewolf_agents = []
+                for player in game.players:
+                    agent_name = player.agent.display_name
+                    if player.role.team == Team.VILLAGERS:
+                        villager_agents.append(agent_name)
+                    else:
+                        werewolf_agents.append(agent_name)
+
                 team_v = [self.metrics[a].openskill_rating for a in villager_agents]
                 team_w = [self.metrics[a].openskill_rating for a in werewolf_agents]
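
The "(Point Estimates)" label marks the split this commit introduces: a single chronological pass over self.games produces the headline Elo numbers, while _bootstrap_elo resamples the game list to attach a standard error to them. A self-contained toy version of the same bootstrap, with made-up per-game values standing in for Elo:

    # Toy bootstrap standard error: resample with replacement, recompute the
    # statistic on each resample, and take the std across resamples.
    # The data and statistic (a plain mean) are invented for illustration.
    import numpy as np

    rng = np.random.default_rng(42)
    game_scores = np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])

    estimates = [
        np.mean(rng.choice(game_scores, size=len(game_scores), replace=True))
        for _ in range(100)
    ]
    std_err = np.std(estimates, ddof=1)  # analogous to metrics[agent].elo_std
    print(f"mean={np.mean(game_scores):.2f} +/- {std_err:.2f}")

Because each resample draws games with replacement, it also reshuffles their order, so the resulting spread reflects both sampling noise and the order dependence of sequential Elo updates.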

@@ -447,33 +511,14 @@ def evaluate(self, gte_samples=3):
                 elif game.winner_team == Team.WEREWOLVES:
                     teams = [team_w, team_v]
 
-                # If a team is empty (shouldn't happen in normal games but safety check), skip
                 if teams:
                     new_ratings = self.openskill_model.rate(teams)
                     openskill_ratings = [rate for team in new_ratings for rate in team]
                     for rating in openskill_ratings:
                         self.metrics[rating.name].openskill_rating = rating
 
-            # Elo Update
-            # Calculate team average Elo
-            v_elos = [self.metrics[a].elo for a in villager_agents]
-            w_elos = [self.metrics[a].elo for a in werewolf_agents]
-
-            if v_elos and w_elos:
-                avg_v_elo = np.mean(v_elos)
-                avg_w_elo = np.mean(w_elos)
-
-                # Result for Villagers
-                result_v = 1 if game.winner_team == Team.VILLAGERS else 0
-
-                for agent in villager_agents:
-                    change = calculate_elo_change(self.metrics[agent].elo, avg_w_elo, result_v)
-                    self.metrics[agent].elo += change
-
-                for agent in werewolf_agents:
-                    change = calculate_elo_change(self.metrics[agent].elo, avg_v_elo, 1 - result_v)
-                    self.metrics[agent].elo += change
-
+        # --- Bootstrapping for Errors ---
+        self._bootstrap_elo(num_samples=elo_samples)
         self._run_gte_evaluation(num_samples=gte_samples)
 
     def _run_gte_evaluation(self, num_samples: int):
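
For context on the rate(teams) call above: openskill rates whole teams at once, and when the ranks argument is omitted the first team in the list is taken as the winner, which is why evaluate() orders teams by game.winner_team. A minimal sketch, assuming openskill.py's PlackettLuce model (metrics.py receives its model from outside via openskill_model, so the model choice and agent names here are invented):

    # Illustration only: PlackettLuce is an assumed model choice and the
    # agent names are made up.
    from openskill.models import PlackettLuce

    model = PlackettLuce()
    villagers = [model.rating(name="agent_a"), model.rating(name="agent_b")]
    werewolves = [model.rating(name="agent_c")]

    # With ranks omitted, the first team in the list is treated as the winner.
    villagers, werewolves = model.rate([villagers, werewolves])
    print(villagers[0].name, villagers[0].mu, villagers[0].sigma)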
@@ -750,7 +795,7 @@ def _prepare_plot_data(self):
 
         # Ratings
         plot_data.append(
-            {'agent': agent_name, 'metric': 'Elo', 'value': metrics.elo, 'std': 0.0, 'category': 'Elo Rating'})
+            {'agent': agent_name, 'metric': 'Elo', 'value': metrics.elo, 'std': metrics.elo_std, 'category': 'Elo Rating'})
         if OPENSKILL_AVAILABLE and metrics.openskill_rating:
             plot_data.append(
                 {'agent': agent_name, 'metric': 'TrueSkill (mu)', 'value': metrics.openskill_rating.mu,
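
With the std field now carrying metrics.elo_std, downstream plots can draw error bars instead of bare points. A minimal sketch of what that might look like with matplotlib (not part of this commit; agents and values are invented):

    # Hypothetical example: plot Elo point estimates with bootstrapped
    # standard errors as error bars. Names and numbers are made up.
    import matplotlib.pyplot as plt

    agents = ["agent_a", "agent_b", "agent_c"]
    elo = [1245.0, 1198.5, 1156.5]
    elo_std = [18.2, 22.7, 15.9]  # the bootstrapped metrics.elo_std values

    fig, ax = plt.subplots()
    ax.errorbar(agents, elo, yerr=elo_std, fmt="o", capsize=4)
    ax.set_ylabel("Elo")
    ax.set_title("Elo ratings with bootstrapped standard errors")
    plt.show()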
