File tree Expand file tree Collapse file tree 2 files changed +28
-7
lines changed Expand file tree Collapse file tree 2 files changed +28
-7
lines changed Original file line number Diff line number Diff line change 18
18
19
19
DEVICES = torch .device ("cpu" )
20
20
21
- model_weights = load_attention_model_weights ()
22
-
23
21
24
22
class GameState (IntEnum ):
25
23
CooperateDefect = 2
@@ -354,13 +352,20 @@ def __init__(
354
352
self ,
355
353
) -> None :
356
354
super ().__init__ ()
357
- self .model = PlayerModel (PlayerConfig ())
358
- self .model .load_state_dict (model_weights )
359
- self .model .to (DEVICES )
360
- self .model .eval ()
355
+ self .model = None
356
+
357
+ def load_model (self ) -> None :
358
+ """Load the model weights."""
359
+ if self .model is None :
360
+ self .model = PlayerModel (PlayerConfig ())
361
+ self .model .load_state_dict (load_attention_model_weights ())
362
+ self .model .to (DEVICES )
363
+ self .model .eval ()
361
364
362
365
def strategy (self , opponent : Player ) -> Action :
363
366
"""Actual strategy definition that determines player's action."""
367
+ # Load the model if not already loaded
368
+ self .load_model ()
364
369
# Compute features
365
370
features = compute_features (self , opponent ).unsqueeze (0 ).to (DEVICES )
366
371
Original file line number Diff line number Diff line change 1
1
"""Tests for the Attention strategies."""
2
2
3
3
import unittest
4
+ from unittest .mock import patch
4
5
5
6
import torch
6
7
7
8
import axelrod as axl
9
+ from axelrod .load_data_ import load_attention_model_weights
8
10
from axelrod .strategies .attention import (
9
11
MEMORY_LENGTH ,
10
12
GameState ,
@@ -89,7 +91,21 @@ class TestEvolvedAttention(TestPlayer):
89
91
def test_model_initialization (self ):
90
92
"""Test that the model is initialized correctly."""
91
93
player = self .player ()
92
- self .assertIsInstance (player .model , PlayerModel )
94
+ self .assertIsNone (player .model )
95
+
96
+ def test_load_model (self ):
97
+ """Test that the model can be loaded correctly."""
98
+ with patch (
99
+ "axelrod.strategies.attention.load_attention_model_weights" ,
100
+ wraps = load_attention_model_weights ,
101
+ ) as load_attention_model_weights_spy :
102
+ player = self .player ()
103
+ self .assertIsNone (player .model )
104
+ player .load_model ()
105
+ self .assertIsInstance (player .model , PlayerModel )
106
+ player .load_model ()
107
+ self .assertIsInstance (player .model , PlayerModel )
108
+ load_attention_model_weights_spy .assert_called_once ()
93
109
94
110
def test_versus_cooperator (self ):
95
111
actions = [(C , C )] * 5
You can’t perform that action at this time.
0 commit comments