diff --git a/android_world/task_evals/task_eval.py b/android_world/task_evals/task_eval.py index fd80d41e..c08b28cf 100644 --- a/android_world/task_evals/task_eval.py +++ b/android_world/task_evals/task_eval.py @@ -25,7 +25,16 @@ from android_world.env.setup_device import setup from android_world.utils import app_snapshot from android_world.utils import datetime_utils +from android_world.env.setup_device import setup, apps +RESET_APPS = { + "audio recorder": apps.AudioRecorder, + "camera": apps.CameraApp, + "chrome": apps.ChromeApp, + "markor": apps.MarkorApp, + "simple calendar pro": apps.SimpleCalendarProApp, + "tasks": apps.TasksApp, +} class TaskEval(abc.ABC): """Interface for a task and its evaluation. @@ -124,6 +133,10 @@ def _initialize_apps(self, env: interface.AsyncEnv) -> None: except RuntimeError as error: logging.warning("Skipping app snapshot loading : %s", error) + if app_name in RESET_APPS: + logging.info("Reset app for %s", app_name) + setup.setup_app(RESET_APPS[app_name], env) + def install_apps_if_not_installed(self, env: interface.AsyncEnv) -> None: for app_name in self.app_names: setup.install_app_if_not_installed(app_name, env)