@@ -24,15 +24,18 @@ def setup_env(self, n_envs, **kwargs) -> None:
         )
         self.env.seed(0)
 
-    @pytest.mark.skipif(not sys.platform.startswith("linux"), reason="Test needs linux")
-    def test_ai_vs_random(self, n_envs=4, n_agents=3):
+    @pytest.mark.skipif(
+        sys.platform.startswith("win32"), reason="Test does not work on windows"
+    )
+    def test_ai_vs_random(self, n_envs=4, n_agents=3, scoring_reward=1):
         self.setup_env(
             n_red_agents=n_agents,
             n_blue_agents=n_agents,
             ai_red_agents=True,
             ai_blue_agents=False,
-            dense_reward_ratio=0,
+            dense_reward=False,
             n_envs=n_envs,
+            scoring_reward=scoring_reward,
         )
         all_done = torch.full((n_envs,), False)
         obs = self.env.reset()
@@ -49,10 +52,10 @@ def test_ai_vs_random(self, n_envs=4, n_agents=3):
                 total_rew[:, i] += rews[i]
             if dones.any():
                 # Done envs should have exactly sum of rewards equal to num_agents
-                actual_rew = -1 * n_agents
+                actual_rew = -scoring_reward * n_agents
                 assert torch.equal(
                     total_rew[dones].sum(-1).to(torch.long),
-                    torch.full((dones.sum(),), actual_rew),
+                    torch.full((dones.sum(),), actual_rew, dtype=torch.long),
                 )
                 total_rew[dones] = 0
                 all_done += dones
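
For context, here is a minimal standalone sketch (not part of the diff) of the check the updated assertion performs, assuming dense rewards are disabled so each of the n_agents blue agents receives -scoring_reward when the AI-controlled red team scores; the added dtype=torch.long keeps both sides of torch.equal in the same integer dtype. The names n_done and total_rew_done are illustrative only.

import torch

# Hypothetical values mirroring the test's defaults (n_agents=3, scoring_reward=1).
n_agents, scoring_reward = 3, 1
n_done = 2  # pretend two environments finished this step

# Expected per-env reward sum once an episode ends: every blue agent is
# penalised by scoring_reward, hence -scoring_reward * n_agents.
actual_rew = -scoring_reward * n_agents
expected = torch.full((n_done,), actual_rew, dtype=torch.long)

# Simulated accumulated rewards for the finished envs, cast to long as in the test.
total_rew_done = torch.full((n_done, n_agents), -float(scoring_reward))
assert torch.equal(total_rew_done.sum(-1).to(torch.long), expected)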