Skip to content

Commit 9635a07

Browse files
committed
Lazy loading of model weights
1 parent b9e837c commit 9635a07

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

axelrod/strategies/attention.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
DEVICES = torch.device("cpu")
2020

21-
model_weights = load_attention_model_weights()
22-
2321

2422
class GameState(IntEnum):
2523
CooperateDefect = 2
@@ -354,13 +352,20 @@ def __init__(
354352
self,
355353
) -> None:
356354
super().__init__()
357-
self.model = PlayerModel(PlayerConfig())
358-
self.model.load_state_dict(model_weights)
359-
self.model.to(DEVICES)
360-
self.model.eval()
355+
self.model = None
356+
357+
def load_model(self) -> None:
358+
"""Load the model weights."""
359+
if self.model is None:
360+
self.model = PlayerModel(PlayerConfig())
361+
self.model.load_state_dict(load_attention_model_weights())
362+
self.model.to(DEVICES)
363+
self.model.eval()
361364

362365
def strategy(self, opponent: Player) -> Action:
363366
"""Actual strategy definition that determines player's action."""
367+
# Load the model if not already loaded
368+
self.load_model()
364369
# Compute features
365370
features = compute_features(self, opponent).unsqueeze(0).to(DEVICES)
366371

axelrod/tests/strategies/test_attention.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Tests for the Attention strategies."""
22

33
import unittest
4+
from unittest.mock import patch
45

56
import torch
67

78
import axelrod as axl
9+
from axelrod.load_data_ import load_attention_model_weights
810
from axelrod.strategies.attention import (
911
MEMORY_LENGTH,
1012
GameState,
@@ -89,7 +91,21 @@ class TestEvolvedAttention(TestPlayer):
8991
def test_model_initialization(self):
9092
"""Test that the model is initialized correctly."""
9193
player = self.player()
92-
self.assertIsInstance(player.model, PlayerModel)
94+
self.assertIsNone(player.model)
95+
96+
def test_load_model(self):
97+
"""Test that the model can be loaded correctly."""
98+
with patch(
99+
"axelrod.strategies.attention.load_attention_model_weights",
100+
wraps=load_attention_model_weights,
101+
) as load_attention_model_weights_spy:
102+
player = self.player()
103+
self.assertIsNone(player.model)
104+
player.load_model()
105+
self.assertIsInstance(player.model, PlayerModel)
106+
player.load_model()
107+
self.assertIsInstance(player.model, PlayerModel)
108+
load_attention_model_weights_spy.assert_called_once()
93109

94110
def test_versus_cooperator(self):
95111
actions = [(C, C)] * 5

0 commit comments

Comments
 (0)