Skip to content

Commit 1ebb875

Browse files
author
The android_world Authors
committed
Changing the argument type of terminate_fn to AsyncEnv so that the is_successful function of TaskEval can be used as terminate_fn.
PiperOrigin-RevId: 716368342
1 parent 4c966eb commit 1ebb875

File tree

2 files changed

+5
-8
lines changed

2 files changed

+5
-8
lines changed

android_world/episode_runner.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616

1717
import dataclasses
1818
from typing import Any, Callable, Optional
19-
20-
from android_env import env_interface
2119
from android_world import constants
2220
from android_world.agents import base_agent
21+
from android_world.env import interface
2322
import termcolor
2423

2524

@@ -45,9 +44,7 @@ def run_episode(
4544
agent: base_agent.EnvironmentInteractingAgent,
4645
max_n_steps: int = 10,
4746
start_on_home_screen: bool = False,
48-
termination_fn: (
49-
Callable[[env_interface.AndroidEnvInterface], float] | None
50-
) = None,
47+
termination_fn: Callable[[interface.AsyncEnv], float] | None = None,
5148
) -> EpisodeResult:
5249
"""Runs an agent on goal, e.g., "turn off wifi".
5350
@@ -83,7 +80,7 @@ def run_episode(
8380
print('Completed step {:d}.'.format(step_n + 1))
8481
assert constants.STEP_NUMBER not in result.data
8582
output.append(result.data | {constants.STEP_NUMBER: step_n})
86-
if termination_fn(agent.env.controller):
83+
if termination_fn(agent.env):
8784
print('Environment ends episode.')
8885
return EpisodeResult(
8986
done=True,

android_world/task_evals/miniwob/miniwob_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ def get_episode_reward(env: env_interface.AndroidEnvInterface) -> float:
5959
return float(int(reward))
6060

6161

62-
def is_episode_terminated(env: env_interface.AndroidEnvInterface) -> bool:
62+
def is_episode_terminated(env: interface.AsyncEnv) -> bool:
6363
"""Checks if the current episode is terminated."""
64-
return get_episode_reward(env) != 0.0
64+
return get_episode_reward(env.controller.env) != 0.0
6565

6666

6767
class MiniWoBTask(task_eval.TaskEval, abc.ABC):

0 commit comments

Comments
 (0)