Skip to content

Commit c323a44

Browse files
author
jayyoung0802
committed
fix(yzj): fix ptz simple ez eval muzero
1 parent 0e6dfd3 commit c323a44

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

zoo/petting_zoo/entry/eval_muzero.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def eval_muzero(main_cfg, create_cfg, seed=0):
2929
from lzero.model.muzero_model_mlp import MuZeroModelMLP as Encoder
3030
elif create_cfg.policy.type == 'efficientzero':
3131
from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
32+
from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP as Encoder
3233
elif create_cfg.policy.type == 'sampled_efficientzero':
3334
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
3435
elif create_cfg.policy.type == 'gumbel_muzero':
@@ -52,7 +53,8 @@ def eval_muzero(main_cfg, create_cfg, seed=0):
5253
evaluator_env.seed(cfg.seed, dynamic_seed=False)
5354
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
5455

55-
model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg))
56+
# model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg), state_prediction=PettingZooPrediction(cfg), state_dynamics=PettingZooDynamics(cfg))
57+
model = Encoder(**cfg.policy.model, state_encoder=PettingZooEncoder(cfg))
5658
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
5759
policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location=cfg.policy.device))
5860

@@ -78,5 +80,6 @@ def eval_muzero(main_cfg, create_cfg, seed=0):
7880
return stop, reward
7981

8082
if __name__ == '__main__':
81-
from zoo.petting_zoo.config.ptz_simple_mz_config import main_config, create_config
83+
# from zoo.petting_zoo.config.ptz_simple_mz_config import main_config, create_config
84+
from zoo.petting_zoo.config.ptz_simple_ez_config import main_config, create_config
8285
eval_muzero(main_config, create_config, seed=0)

0 commit comments

Comments
 (0)