@@ -29,6 +29,7 @@ def eval_muzero(main_cfg, create_cfg, seed=0):
         from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder
     elif create_cfg.policy.type == 'efficientzero':
         from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
+        from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder
     elif create_cfg.policy.type == 'sampled_efficientzero':
         from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
     elif create_cfg.policy.type == 'gumbel_muzero':
@@ -52,7 +53,8 @@ def eval_muzero(main_cfg, create_cfg, seed=0):
     evaluator_env.seed(cfg.seed, dynamic_seed=False)
     set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
 
-    model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg))
+    # model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg))
+    model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg))
     policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
     policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location=cfg.policy.device))
 
@@ -78,5 +80,6 @@ def eval_muzero(main_cfg, create_cfg, seed=0):
     return stop, reward
 
 if __name__ == '__main__':
-    from zoo.petting_zoo.config.ptz_simple_mz_config import main_config, create_config
+    # from zoo.petting_zoo.config.ptz_simple_mz_config import main_config, create_config
+    from zoo.petting_zoo.config.ptz_simple_ez_config import main_config, create_config
     eval_muzero(main_config, create_config, seed=0)
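For reference, a minimal sketch of how the switched-in EfficientZero entry point might be driven across several seeds. The `eval_muzero(main_config, create_config, seed=...)` signature, the `(stop, reward)` return value, and the `ptz_simple_ez_config` module come from the diff above; the module path of the eval script itself is an assumption, and `cfg.policy.load_path` must point to an existing trained checkpoint before this runs.

# Hypothetical driver: evaluate the EfficientZero checkpoint over a few seeds.
# The import path of the eval script below is an assumption; adjust to where
# the changed file actually lives in the repo.
from zoo.petting_zoo.config.ptz_simple_ez_config import main_config, create_config
from zoo.petting_zoo.eval.eval_muzero import eval_muzero  # hypothetical module path

for seed in range(3):
    # eval_muzero seeds the evaluator env and torch/numpy internally via set_pkg_seed
    stop, reward = eval_muzero(main_config, create_config, seed=seed)
    print(f'seed={seed}: stop={stop}, reward={reward}')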