From 83f590e4a84a8ae0fda880490621c4a0df10075c Mon Sep 17 00:00:00 2001
From: Bartosz Prusak
Date: Sun, 26 Nov 2023 22:57:58 +0100
Subject: [PATCH 1/5] Organize imports

---
 baselines/red_gym_env.py                | 21 ++++++++++-----------
 baselines/run_baseline_parallel_fast.py | 13 +++++++------
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py
index 1b667ff9c..9f4502488 100644
--- a/baselines/red_gym_env.py
+++ b/baselines/red_gym_env.py
@@ -1,23 +1,22 @@
-import sys
-import uuid
-import os
-from math import floor, sqrt
 import json
+import sys
+import uuid
+from math import floor
 from pathlib import Path
 
-import numpy as np
-from einops import rearrange
-import matplotlib.pyplot as plt
-from skimage.transform import resize
-from pyboy import PyBoy
-from pyboy.logger import log_level
 import hnswlib
+import matplotlib.pyplot as plt
 import mediapy as media
+import numpy as np
 import pandas as pd
-
+from einops import rearrange
 from gymnasium import Env, spaces
+from pyboy import PyBoy
+from pyboy.logger import log_level
 from pyboy.utils import WindowEvent
+from skimage.transform import resize
+
 
 class RedGymEnv(Env):
diff --git a/baselines/run_baseline_parallel_fast.py b/baselines/run_baseline_parallel_fast.py
index de82739d9..78f6ec95e 100644
--- a/baselines/run_baseline_parallel_fast.py
+++ b/baselines/run_baseline_parallel_fast.py
@@ -1,14 +1,16 @@
+import uuid
 from os.path import exists
 from pathlib import Path
-import uuid
-from red_gym_env import RedGymEnv
+
 from stable_baselines3 import PPO
-from stable_baselines3.common import env_checker
-from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
-from stable_baselines3.common.utils import set_random_seed
 from stable_baselines3.common.callbacks import CheckpointCallback, CallbackList
+from stable_baselines3.common.utils import set_random_seed
+from stable_baselines3.common.vec_env import SubprocVecEnv
+
+from red_gym_env import RedGymEnv
 from tensorboard_callback import TensorboardCallback
+
 
 def make_env(rank, env_conf, seed=0):
     """
     Utility function for multiprocessed env.
@@ -63,7 +65,6 @@ def _init():
         )
         callbacks.append(WandbCallback())
 
-    #env_checker.check_env(env)
     learn_steps = 40
     # put a checkpoint here you want to start from
     file_name = 'session_e41c9eff/poke_38207488_steps'

From ba9ac05839f13dbff225596071a2249803e906c7 Mon Sep 17 00:00:00 2001
From: Bartosz Prusak
Date: Mon, 27 Nov 2023 00:35:17 +0100
Subject: [PATCH 2/5] Style changes and typing, typing everywhere

* reorganized imports - automatically changed order of imports in some
  files and removed unused,
* added explicit typing to all fields and methods in RedGymEnv (and all
  files that use this class)
* as almost all methods and fields in RedGymEnv are private, they were
  renamed to start with _underscore

Warning: not tested against older versions of python 3
---
 baselines/red_gym_env.py                | 641 +++++++++++++-----------
 baselines/render_all_needed_grids.py    |  31 +-
 baselines/run_baseline_parallel.py      |  27 +-
 baselines/run_baseline_parallel_fast.py |  17 +-
 baselines/run_pretrained_interactive.py |  30 +-
 baselines/run_recorded_actions.py       |   2 +-
 6 files changed, 407 insertions(+), 341 deletions(-)

diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py
index 9f4502488..6c393d66e 100644
--- a/baselines/red_gym_env.py
+++ b/baselines/red_gym_env.py
@@ -12,49 +12,87 @@ import pandas as pd
 from einops import rearrange
 from gymnasium import Env, spaces
+from numpy.typing import NDArray
 from pyboy import PyBoy
 from pyboy.logger import log_level
 from pyboy.utils import WindowEvent
 from skimage.transform import resize
+from typing import Any, Optional, TypedDict
+
+
+class RedGymEnvConfig(TypedDict):
+    debug: bool
+    session_path: Path
+    save_final_state: bool
+    print_rewards: bool
+    headless: bool
+    init_state: str
+    action_freq: int
+    max_steps: int
+    early_stop: bool
+    save_video: bool
+    fast_video: bool
+    explore_weight: Optional[float]
+    use_screen_explore: Optional[bool]
+    sim_frame_dist: float
+    reward_scale: Optional[float]
+    extra_buttons: Optional[bool]
+    instance_id: Optional[str]
+
+
+class _AgentStats(TypedDict):
+    step: int
+    x: int
+    y: int
+    map: int
+    map_location: str
+    last_action: int
+    pcount: int
+    levels: list[int]
+    levels_sum: int
+    ptypes: int
+    hp: float
+    frames: Optional[int]
+    coord_count: Optional[int]
+    deaths: int
+    badge: int
+    event: float
+    healr: float
 
 
 class RedGymEnv(Env):
-
-    def __init__(
-        self, config=None):
-
-        self.debug = config['debug']
-        self.s_path = config['session_path']
-        self.save_final_state = config['save_final_state']
-        self.print_rewards = config['print_rewards']
-        self.vec_dim = 4320 #1000
-        self.headless = config['headless']
-        self.num_elements = 20000 # max
-        self.init_state = config['init_state']
-        self.act_freq = config['action_freq']
-        self.max_steps = config['max_steps']
-        self.early_stopping = config['early_stop']
-        self.save_video = config['save_video']
-        self.fast_video = config['fast_video']
-        self.video_interval = 256 * self.act_freq
-        self.downsample_factor = 2
-        self.frame_stacks = 3
-        self.explore_weight = 1 if 'explore_weight' not in config else config['explore_weight']
-        self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore']
-        self.similar_frame_dist = config['sim_frame_dist']
-        self.reward_scale = 1 if 'reward_scale' not in config else config['reward_scale']
-        self.extra_buttons = False if 'extra_buttons' not in config else config['extra_buttons']
-        self.instance_id = str(uuid.uuid4())[:8] if 'instance_id' not in config else config['instance_id']
-        self.s_path.mkdir(exist_ok=True)
-
-        self.reset_count = 0
-        self.all_runs = []
+    def __init__(self, config: Optional[RedGymEnvConfig] = None):
+
+        self._debug: bool = config['debug'] # unused
+        self._s_path: Path = config['session_path']
+        self._save_final_state: bool = config['save_final_state']
+        self._print_rewards: bool = config['print_rewards']
+        self._vec_dim: int = 4320 #1000
+        self._headless: bool = config['headless']
+        self._num_elements: int = 20000 # max
+        self._init_state: str = config['init_state']
+        self._act_freq: int = config['action_freq']
+        self._max_steps: int = config['max_steps']
+        self._early_stopping: bool = config['early_stop']
+        self._save_video: bool = config['save_video']
+        self._fast_video: bool = config['fast_video']
+        self._frame_stacks: int = 3
+        self._explore_weight: float = 1 if 'explore_weight' not in config else config['explore_weight']
+        self._use_screen_explore: bool = True if 'use_screen_explore' not in config else config['use_screen_explore']
+        self._similar_frame_dist: float = config['sim_frame_dist']
+        self._reward_scale: float = 1 if 'reward_scale' not in config else config['reward_scale']
+        self._extra_buttons: bool = False if 'extra_buttons' not in config else config['extra_buttons']
+        self._instance_id: str = str(uuid.uuid4())[:8] if 'instance_id' not in config else config['instance_id']
+        self._s_path.mkdir(exist_ok=True)
+        self._reset_count: int = 0
+        self._all_runs: list[dict[str, float]] = []
 
         # Set this in SOME subclasses
-        self.metadata = {"render.modes": []}
-        self.reward_range = (0, 15000)
+        self.metadata: dict[str, Any] = {"render.modes": []}
+        self.reward_range: tuple[int, int] = (0, 15000)
 
-        self.valid_actions = [
+        self.valid_actions: list[WindowEvent] = [
             WindowEvent.PRESS_ARROW_DOWN,
             WindowEvent.PRESS_ARROW_LEFT,
             WindowEvent.PRESS_ARROW_RIGHT,
@@ -63,42 +101,42 @@ def __init__(
             WindowEvent.PRESS_BUTTON_B,
         ]
 
-        if self.extra_buttons:
+        if self._extra_buttons:
             self.valid_actions.extend([
                 WindowEvent.PRESS_BUTTON_START,
                 WindowEvent.PASS
             ])
 
-        self.release_arrow = [
+        self._release_arrow: list[WindowEvent] = [
             WindowEvent.RELEASE_ARROW_DOWN,
             WindowEvent.RELEASE_ARROW_LEFT,
             WindowEvent.RELEASE_ARROW_RIGHT,
             WindowEvent.RELEASE_ARROW_UP
         ]
 
-        self.release_button = [
+        self._release_button: list[WindowEvent] = [
             WindowEvent.RELEASE_BUTTON_A,
             WindowEvent.RELEASE_BUTTON_B
         ]
 
-        self.output_shape = (36, 40, 3)
-        self.mem_padding = 2
-        self.memory_height = 8
-        self.col_steps = 16
-        self.output_full = (
-            self.output_shape[0] * self.frame_stacks + 2 * (self.mem_padding + self.memory_height),
-            self.output_shape[1],
-            self.output_shape[2]
+        self._output_shape: tuple[int, int, int] = (36, 40, 3)
+        self._mem_padding: int = 2
+        self._memory_height: int = 8
+        self._col_steps: int = 16
+        self._output_full: tuple[int, int, int] = (
+            self._output_shape[0] * self._frame_stacks + 2 * (self._mem_padding + self._memory_height),
+            self._output_shape[1],
+            self._output_shape[2]
         )
 
         # Set these in ALL subclasses
         self.action_space = spaces.Discrete(len(self.valid_actions))
-        self.observation_space = spaces.Box(low=0, high=255, shape=self.output_full, dtype=np.uint8)
+        self.observation_space = spaces.Box(low=0, high=255, shape=self._output_full, dtype=np.uint8)
 
         head = 'headless' if config['headless'] else 'SDL2'
 
         log_level("ERROR")
-        self.pyboy = PyBoy(
+        self._pyboy = PyBoy(
             config['gb_path'],
             debugging=False,
             disable_input=False,
@@ -106,247 +144,273 @@ def __init__(
             hide_window='--quiet' in sys.argv,
         )
 
-        self.screen = self.pyboy.botsupport_manager().screen()
+        self._screen = self._pyboy.botsupport_manager().screen()
 
         if not config['headless']:
-            self.pyboy.set_emulation_speed(6)
+            self._pyboy.set_emulation_speed(6)
+
+        # Fields set in reset()
+        self._recent_memory: NDArray = np.zeros((1,), dtype=np.uint8)
+        self._recent_frames: NDArray = np.zeros((1,), dtype=np.uint8)
+        self._agent_stats: list[_AgentStats] = []
+        self._full_frame_writer: Optional[media.VideoWriter] = None
+        self._model_frame_writer: Optional[media.VideoWriter] = None
+        self._levels_satisfied: bool = False
+        self._base_explore: int = 0
+        self._max_opponent_level: int = 0
+        self._max_event_rew: int = 0
+        self._max_level_rew: float = 0.
+        self._last_health: float = 0.
+        self._total_healing_rew: float = 0.
+        self._died_count: int = 0
+        self._party_size: int = 0
+        self._step_count: int = 0
+        self._progress_reward: dict[str, float] = {}
+        self._total_reward: float = 0.
+        self._seen_coords: dict[str, int] = {}
+        self._knn_index: Optional[hnswlib.Index] = None
 
         self.reset()
 
-    def reset(self, seed=None):
-        self.seed = seed
+    def reset(self, *,
+              seed: int | None = None,
+              options: dict[str, Any] | None = None, ) -> tuple[NDArray, dict[str, Any]]:
         # restart game, skipping credits
-        with open(self.init_state, "rb") as f:
-            self.pyboy.load_state(f)
+        with open(self._init_state, "rb") as f:
+            self._pyboy.load_state(f)
 
-        if self.use_screen_explore:
-            self.init_knn()
+        if self._use_screen_explore:
+            self._init_knn()
         else:
-            self.init_map_mem()
+            self._init_map_mem()
 
-        self.recent_memory = np.zeros((self.output_shape[1]*self.memory_height, 3), dtype=np.uint8)
+        self._recent_memory = np.zeros((self._output_shape[1] * self._memory_height, 3), dtype=np.uint8)
 
-        self.recent_frames = np.zeros(
-            (self.frame_stacks, self.output_shape[0],
-             self.output_shape[1], self.output_shape[2]),
+        self._recent_frames = np.zeros(
+            (self._frame_stacks, self._output_shape[0],
+             self._output_shape[1], self._output_shape[2]),
            dtype=np.uint8)
 
-        self.agent_stats = []
+        self._agent_stats = []
 
-        if self.save_video:
-            base_dir = self.s_path / Path('rollouts')
+        if self._save_video:
+            base_dir = self._s_path / Path('rollouts')
             base_dir.mkdir(exist_ok=True)
-            full_name = Path(f'full_reset_{self.reset_count}_id{self.instance_id}').with_suffix('.mp4')
-            model_name = Path(f'model_reset_{self.reset_count}_id{self.instance_id}').with_suffix('.mp4')
-            self.full_frame_writer = media.VideoWriter(base_dir / full_name, (144, 160), fps=60)
-            self.full_frame_writer.__enter__()
-            self.model_frame_writer = media.VideoWriter(base_dir / model_name, self.output_full[:2], fps=60)
-            self.model_frame_writer.__enter__()
+            full_name = Path(f'full_reset_{self._reset_count}_id{self._instance_id}').with_suffix('.mp4')
+            model_name = Path(f'model_reset_{self._reset_count}_id{self._instance_id}').with_suffix('.mp4')
+            self._full_frame_writer = media.VideoWriter(base_dir / full_name, (144, 160), fps=60)
+            self._full_frame_writer.__enter__()
+            self._model_frame_writer = media.VideoWriter(base_dir / model_name, self._output_full[:2], fps=60)
+            self._model_frame_writer.__enter__()
 
-        self.levels_satisfied = False
-        self.base_explore = 0
-        self.max_opponent_level = 0
-        self.max_event_rew = 0
-        self.max_level_rew = 0
-        self.last_health = 1
-        self.total_healing_rew = 0
-        self.died_count = 0
-        self.party_size = 0
-        self.step_count = 0
-        self.progress_reward = self.get_game_state_reward()
-        self.total_reward = sum([val for _, val in self.progress_reward.items()])
-        self.reset_count += 1
+        self._levels_satisfied = False
+        self._base_explore = 0
+        self._max_opponent_level = 0
+        self._max_event_rew = 0
+        self._max_level_rew = 0
+        self._last_health = 1
+        self._total_healing_rew = 0
+        self._died_count = 0
+        self._party_size = 0
+        self._step_count = 0
+        self._progress_reward: dict[str, float] = self._get_game_state_reward()
+        self._total_reward = sum([val for _, val in self._progress_reward.items()])
+        self._reset_count += 1
 
         return self.render(), {}
 
-    def init_knn(self):
+    def _init_knn(self) -> None:
         # Declaring index
-        self.knn_index = hnswlib.Index(space='l2', dim=self.vec_dim) # possible options are l2, cosine or ip
+        self._knn_index = hnswlib.Index(space='l2', dim=self._vec_dim) # possible options are l2, cosine or ip
         # Initing index - the maximum number of elements should be known beforehand
-        self.knn_index.init_index(
-            max_elements=self.num_elements, ef_construction=100, M=16)
+        self._knn_index.init_index(
+            max_elements=self._num_elements, ef_construction=100, M=16)
 
-    def init_map_mem(self):
-        self.seen_coords = {}
+    def _init_map_mem(self):
+        self._seen_coords = {}
 
-    def render(self, reduce_res=True, add_memory=True, update_mem=True):
-        game_pixels_render = self.screen.screen_ndarray() # (144, 160, 3)
+    def render(self, reduce_res: bool = True, add_memory: bool = True, update_mem: bool = True) -> NDArray:
+        game_pixels_render = self._screen.screen_ndarray() # (144, 160, 3)
         if reduce_res:
-            game_pixels_render = (255*resize(game_pixels_render, self.output_shape)).astype(np.uint8)
+            game_pixels_render = (255*resize(game_pixels_render, self._output_shape)).astype(np.uint8)
             if update_mem:
-                self.recent_frames[0] = game_pixels_render
+                self._recent_frames[0] = game_pixels_render
             if add_memory:
                 pad = np.zeros(
-                    shape=(self.mem_padding, self.output_shape[1], 3),
+                    shape=(self._mem_padding, self._output_shape[1], 3),
                     dtype=np.uint8)
                 game_pixels_render = np.concatenate(
                     (
-                        self.create_exploration_memory(),
+                        self._create_exploration_memory(),
                         pad,
-                        self.create_recent_memory(),
+                        self._create_recent_memory(),
                         pad,
-                        rearrange(self.recent_frames, 'f h w c -> (f h) w c')
+                        rearrange(self._recent_frames, 'f h w c -> (f h) w c')
                     ),
                     axis=0)
         return game_pixels_render
 
-    def step(self, action):
+    def step(self, action: int) -> tuple[NDArray, float, bool, bool, dict[str, Any]]:
 
-        self.run_action_on_emulator(action)
-        self.append_agent_stats(action)
+        self._run_action_on_emulator(action)
+        self._append_agent_stats(action)
 
-        self.recent_frames = np.roll(self.recent_frames, 1, axis=0)
+        self._recent_frames = np.roll(self._recent_frames, 1, axis=0)
         obs_memory = self.render()
 
         # trim off memory from frame for knn index
-        frame_start = 2 * (self.memory_height + self.mem_padding)
+        frame_start = 2 * (self._memory_height + self._mem_padding)
         obs_flat = obs_memory[
-            frame_start:frame_start+self.output_shape[0], ...].flatten().astype(np.float32)
+            frame_start:frame_start+self._output_shape[0], ...].flatten().astype(np.float32)
 
-        if self.use_screen_explore:
-            self.update_frame_knn_index(obs_flat)
+        if self._use_screen_explore:
+            self._update_frame_knn_index(obs_flat)
         else:
-            self.update_seen_coords()
+            self._update_seen_coords()
 
-        self.update_heal_reward()
-        self.party_size = self.read_m(0xD163)
+        self._update_heal_reward()
+        self._party_size = self._read_m(0xD163)
 
-        new_reward, new_prog = self.update_reward()
+        new_reward, new_prog = self._update_reward()
 
-        self.last_health = self.read_hp_fraction()
+        self._last_health = self._read_hp_fraction()
 
         # shift over short term reward memory
-        self.recent_memory = np.roll(self.recent_memory, 3)
-        self.recent_memory[0, 0] = min(new_prog[0] * 64, 255)
-        self.recent_memory[0, 1] = min(new_prog[1] * 64, 255)
-        self.recent_memory[0, 2] = min(new_prog[2] * 128, 255)
+        self._recent_memory = np.roll(self._recent_memory, 3)
+        self._recent_memory[0, 0] = min(new_prog[0] * 64, 255)
+        self._recent_memory[0, 1] = min(new_prog[1] * 64, 255)
+        self._recent_memory[0, 2] = min(new_prog[2] * 128, 255)
 
-        step_limit_reached = self.check_if_done()
+        step_limit_reached = self._check_if_done()
 
-        self.save_and_print_info(step_limit_reached, obs_memory)
+        self._save_and_print_info(step_limit_reached, obs_memory)
 
-        self.step_count += 1
+        self._step_count += 1
 
         return obs_memory, new_reward*0.1, False, step_limit_reached, {}
 
-    def run_action_on_emulator(self, action):
+    def _run_action_on_emulator(self, action: int) -> None:
         # press button then release after some steps
-        self.pyboy.send_input(self.valid_actions[action])
+        self._pyboy.send_input(self.valid_actions[action])
         # disable rendering when we don't need it
-        if not self.save_video and self.headless:
-            self.pyboy._rendering(False)
-        for i in range(self.act_freq):
+        if not self._save_video and self._headless:
+            self._pyboy._rendering(False)
+        for i in range(self._act_freq):
             # release action, so they are stateless
             if i == 8:
                 if action < 4:
                     # release arrow
-                    self.pyboy.send_input(self.release_arrow[action])
+                    self._pyboy.send_input(self._release_arrow[action])
                 if action > 3 and action < 6:
                     # release button
-                    self.pyboy.send_input(self.release_button[action - 4])
+                    self._pyboy.send_input(self._release_button[action - 4])
                 if self.valid_actions[action] == WindowEvent.PRESS_BUTTON_START:
-                    self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START)
-            if self.save_video and not self.fast_video:
-                self.add_video_frame()
-            if i == self.act_freq-1:
-                self.pyboy._rendering(True)
-            self.pyboy.tick()
-        if self.save_video and self.fast_video:
-            self.add_video_frame()
+                    self._pyboy.send_input(WindowEvent.RELEASE_BUTTON_START)
+            if self._save_video and not self._fast_video:
+                self._add_video_frame()
+            if i == self._act_freq-1:
+                self._pyboy._rendering(True)
+            self._pyboy.tick()
+        if self._save_video and self._fast_video:
+            self._add_video_frame()
 
-    def add_video_frame(self):
-        self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False))
-        self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False))
+    def _add_video_frame(self) -> None:
+        self._full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False))
+        self._model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False))
 
-    def append_agent_stats(self, action):
-        x_pos = self.read_m(0xD362)
-        y_pos = self.read_m(0xD361)
-        map_n = self.read_m(0xD35E)
-        levels = [self.read_m(a) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]]
-        if self.use_screen_explore:
-            expl = ('frames', self.knn_index.get_current_count())
+    def _append_agent_stats(self, action: int) -> None:
+        x_pos = self._read_m(0xD362)
+        y_pos = self._read_m(0xD361)
+        map_n = self._read_m(0xD35E)
+        levels = [self._read_m(a) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]]
+        if self._use_screen_explore:
+            expl = ('frames', self._knn_index.get_current_count())
         else:
-            expl = ('coord_count', len(self.seen_coords))
-        self.agent_stats.append({
-            'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n,
-            'map_location': self.get_map_location(map_n),
+            expl = ('coord_count', len(self._seen_coords))
+        self._agent_stats.append({
+            'step': self._step_count, 'x': x_pos, 'y': y_pos, 'map': map_n,
+            'map_location': self._get_map_location(map_n),
             'last_action': action,
-            'pcount': self.read_m(0xD163),
+            'pcount': self._read_m(0xD163),
             'levels': levels, 'levels_sum': sum(levels),
-            'ptypes': self.read_party(),
-            'hp': self.read_hp_fraction(),
+            'ptypes': self._read_party(),
+            'hp': self._read_hp_fraction(),
             expl[0]: expl[1],
-            'deaths': self.died_count, 'badge': self.get_badges(),
-            'event': self.progress_reward['event'], 'healr': self.total_healing_rew
+            'deaths': self._died_count, 'badge': self._get_badges(),
+            'event': self._progress_reward['event'], 'healr': self._total_healing_rew
         })
 
-    def update_frame_knn_index(self, frame_vec):
+    def _update_frame_knn_index(self, frame_vec: NDArray) -> None:
 
-        if self.get_levels_sum() >= 22 and not self.levels_satisfied:
-            self.levels_satisfied = True
-            self.base_explore = self.knn_index.get_current_count()
-            self.init_knn()
+        if self._get_levels_sum() >= 22 and not self._levels_satisfied:
+            self._levels_satisfied = True
+            self._base_explore = self._knn_index.get_current_count()
+            self._init_knn()
 
-        if self.knn_index.get_current_count() == 0:
+        if self._knn_index.get_current_count() == 0:
             # if index is empty add current frame
-            self.knn_index.add_items(
-                frame_vec, np.array([self.knn_index.get_current_count()])
+            self._knn_index.add_items(
+                frame_vec, np.array([self._knn_index.get_current_count()])
             )
         else:
             # check for nearest frame and add if current
-            labels, distances = self.knn_index.knn_query(frame_vec, k = 1)
-            if distances[0][0] > self.similar_frame_dist:
+            labels, distances = self._knn_index.knn_query(frame_vec, k=1, filter=None)
+            if distances[0][0] > self._similar_frame_dist:
                 # print(f"distances[0][0] : {distances[0][0]} similar_frame_dist : {self.similar_frame_dist}")
-                self.knn_index.add_items(
-                    frame_vec, np.array([self.knn_index.get_current_count()])
+                self._knn_index.add_items(
+                    frame_vec, np.array([self._knn_index.get_current_count()])
                 )
 
-    def update_seen_coords(self):
-        x_pos = self.read_m(0xD362)
-        y_pos = self.read_m(0xD361)
-        map_n = self.read_m(0xD35E)
+    def _update_seen_coords(self) -> None:
+        x_pos = self._read_m(0xD362)
+        y_pos = self._read_m(0xD361)
+        map_n = self._read_m(0xD35E)
         coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}"
-        if self.get_levels_sum() >= 22 and not self.levels_satisfied:
-            self.levels_satisfied = True
-            self.base_explore = len(self.seen_coords)
-            self.seen_coords = {}
+        if self._get_levels_sum() >= 22 and not self._levels_satisfied:
+            self._levels_satisfied = True
+            self._base_explore = len(self._seen_coords)
+            self._seen_coords = {}
 
-        self.seen_coords[coord_string] = self.step_count
+        self._seen_coords[coord_string] = self._step_count
 
-    def update_reward(self):
+    def _update_reward(self) -> tuple[float, tuple[float, float, float]]:
         # compute reward
-        old_prog = self.group_rewards()
-        self.progress_reward = self.get_game_state_reward()
-        new_prog = self.group_rewards()
-        new_total = sum([val for _, val in self.progress_reward.items()]) #sqrt(self.explore_reward * self.progress_reward)
-        new_step = new_total - self.total_reward
-        if new_step < 0 and self.read_hp_fraction() > 0:
+        old_prog = self._group_rewards()
+        self._progress_reward = self._get_game_state_reward()
+        new_prog = self._group_rewards()
+        new_total = sum([val for _, val in self._progress_reward.items()]) #sqrt(self.explore_reward * self.progress_reward)
+        new_step = new_total - self._total_reward
+        if new_step < 0 and self._read_hp_fraction() > 0:
             #print(f'\n\nreward went down! {self.progress_reward}\n\n')
-            self.save_screenshot('neg_reward')
+            self._save_screenshot('neg_reward')
 
-        self.total_reward = new_total
+        self._total_reward = new_total
         return (new_step,
                 (new_prog[0]-old_prog[0],
                  new_prog[1]-old_prog[1],
                  new_prog[2]-old_prog[2])
                 )
 
-    def group_rewards(self):
-        prog = self.progress_reward
+    def _group_rewards(self) -> tuple[float, float, float]:
+        prog = self._progress_reward
         # these values are only used by memory
-        return (prog['level'] * 100 / self.reward_scale,
-                self.read_hp_fraction()*2000,
-                prog['explore'] * 150 / (self.explore_weight * self.reward_scale))
+        return (prog['level'] * 100 / self._reward_scale,
+                self._read_hp_fraction() * 2000,
+                prog['explore'] * 150 / (self._explore_weight * self._reward_scale))
                #(prog['events'],
                # prog['levels'] + prog['party_xp'],
                # prog['explore'])
 
-    def create_exploration_memory(self):
-        w = self.output_shape[1]
-        h = self.memory_height
+    def _create_exploration_memory(self) -> NDArray:
+        """
+        Prepares the part of image with colored bars for level, hp and exploration rewards.
+        :return: a [h, w, 3] NDArray(uint8) interpretable as an image
+        """
+        w = self._output_shape[1]
+        h = self._memory_height
 
-        def make_reward_channel(r_val):
-            col_steps = self.col_steps
+        def make_reward_channel(r_val: float) -> NDArray:
+            col_steps = self._col_steps
             max_r_val = (w-1) * h * col_steps
             # truncate progress bar. if hitting this
             # you should scale down the reward in group_rewards!
@@ -362,122 +426,122 @@ def make_reward_channel(r_val):
                 memory[col, row] = last_pixel * (255 // col_steps)
             return memory
 
-        level, hp, explore = self.group_rewards()
+        level, hp, explore = self._group_rewards()
         full_memory = np.stack((
             make_reward_channel(level),
             make_reward_channel(hp),
             make_reward_channel(explore)
         ), axis=-1)
 
-        if self.get_badges() > 0:
+        if self._get_badges() > 0:
             full_memory[:, -1, :] = 255
 
         return full_memory
 
-    def create_recent_memory(self):
+    def _create_recent_memory(self) -> NDArray:
         return rearrange(
-            self.recent_memory,
+            self._recent_memory,
             '(w h) c -> h w c',
-            h=self.memory_height)
+            h=self._memory_height)
 
-    def check_if_done(self):
-        if self.early_stopping:
+    def _check_if_done(self) -> bool:
+        if self._early_stopping:
             done = False
-            if self.step_count > 128 and self.recent_memory.sum() < (255 * 1):
+            if self._step_count > 128 and self._recent_memory.sum() < (255 * 1):
                 done = True
         else:
-            done = self.step_count >= self.max_steps
+            done = self._step_count >= self._max_steps
         #done = self.read_hp_fraction() == 0
         return done
 
-    def save_and_print_info(self, done, obs_memory):
-        if self.print_rewards:
-            prog_string = f'step: {self.step_count:6d}'
-            for key, val in self.progress_reward.items():
+    def _save_and_print_info(self, done: bool, obs_memory: NDArray) -> None:
+        if self._print_rewards:
+            prog_string = f'step: {self._step_count:6d}'
+            for key, val in self._progress_reward.items():
                 prog_string += f' {key}: {val:5.2f}'
-            prog_string += f' sum: {self.total_reward:5.2f}'
+            prog_string += f' sum: {self._total_reward:5.2f}'
             print(f'\r{prog_string}', end='', flush=True)
 
-        if self.step_count % 50 == 0:
+        if self._step_count % 50 == 0:
             plt.imsave(
-                self.s_path / Path(f'curframe_{self.instance_id}.jpeg'),
+                self._s_path / Path(f'curframe_{self._instance_id}.jpeg'),
                 self.render(reduce_res=False))
 
-        if self.print_rewards and done:
+        if self._print_rewards and done:
             print('', flush=True)
-            if self.save_final_state:
-                fs_path = self.s_path / Path('final_states')
+            if self._save_final_state:
+                fs_path = self._s_path / Path('final_states')
                 fs_path.mkdir(exist_ok=True)
                 plt.imsave(
-                    fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'),
+                    fs_path / Path(f'frame_r{self._total_reward:.4f}_{self._reset_count}_small.jpeg'),
                    obs_memory)
                 plt.imsave(
-                    fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'),
+                    fs_path / Path(f'frame_r{self._total_reward:.4f}_{self._reset_count}_full.jpeg'),
                     self.render(reduce_res=False))
 
-        if self.save_video and done:
-            self.full_frame_writer.close()
-            self.model_frame_writer.close()
+        if self._save_video and done:
+            self._full_frame_writer.close()
+            self._model_frame_writer.close()
 
         if done:
-            self.all_runs.append(self.progress_reward)
-            with open(self.s_path / Path(f'all_runs_{self.instance_id}.json'), 'w') as f:
-                json.dump(self.all_runs, f)
-            pd.DataFrame(self.agent_stats).to_csv(
-                self.s_path / Path(f'agent_stats_{self.instance_id}.csv.gz'), compression='gzip', mode='a')
+            self._all_runs.append(self._progress_reward)
+            with open(self._s_path / Path(f'all_runs_{self._instance_id}.json'), 'w') as f:
+                json.dump(self._all_runs, f)
+            pd.DataFrame(self._agent_stats).to_csv(
+                self._s_path / Path(f'agent_stats_{self._instance_id}.csv.gz'), compression='gzip', mode='a')
 
-    def read_m(self, addr):
-        return self.pyboy.get_memory_value(addr)
+    def _read_m(self, addr: int) -> int:
+        return self._pyboy.get_memory_value(addr)
 
-    def read_bit(self, addr, bit: int) -> bool:
+    def _read_bit(self, addr: int, bit: int) -> bool:
         # add padding so zero will read '0b100000000' instead of '0b0'
-        return bin(256 + self.read_m(addr))[-bit-1] == '1'
+        return bin(256 + self._read_m(addr))[-bit - 1] == '1'
 
-    def get_levels_sum(self):
-        poke_levels = [max(self.read_m(a) - 2, 0) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]]
+    def _get_levels_sum(self) -> int:
+        poke_levels = [max(self._read_m(a) - 2, 0) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]]
         return max(sum(poke_levels) - 4, 0) # subtract starting pokemon level
 
-    def get_levels_reward(self):
+    def _get_levels_reward(self) -> float:
         explore_thresh = 22
         scale_factor = 4
-        level_sum = self.get_levels_sum()
+        level_sum = self._get_levels_sum()
         if level_sum < explore_thresh:
             scaled = level_sum
         else:
             scaled = (level_sum-explore_thresh) / scale_factor + explore_thresh
-        self.max_level_rew = max(self.max_level_rew, scaled)
-        return self.max_level_rew
+        self._max_level_rew = max(self._max_level_rew, scaled)
+        return self._max_level_rew
 
-    def get_knn_reward(self):
+    def _get_knn_reward(self) -> float:
 
-        pre_rew = self.explore_weight * 0.005
-        post_rew = self.explore_weight * 0.01
-        cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords)
-        base = (self.base_explore if self.levels_satisfied else cur_size) * pre_rew
-        post = (cur_size if self.levels_satisfied else 0) * post_rew
+        pre_rew = self._explore_weight * 0.005
+        post_rew = self._explore_weight * 0.01
+        cur_size = self._knn_index.get_current_count() if self._use_screen_explore else len(self._seen_coords)
+        base = (self._base_explore if self._levels_satisfied else cur_size) * pre_rew
+        post = (cur_size if self._levels_satisfied else 0) * post_rew
         return base + post
 
-    def get_badges(self):
-        return self.bit_count(self.read_m(0xD356))
+    def _get_badges(self) -> int:
+        return self._bit_count(self._read_m(0xD356))
 
-    def read_party(self):
-        return [self.read_m(addr) for addr in [0xD164, 0xD165, 0xD166, 0xD167, 0xD168, 0xD169]]
+    def _read_party(self) -> list[int]:
+        return [self._read_m(addr) for addr in [0xD164, 0xD165, 0xD166, 0xD167, 0xD168, 0xD169]]
 
-    def update_heal_reward(self):
-        cur_health = self.read_hp_fraction()
+    def _update_heal_reward(self) -> None:
+        cur_health = self._read_hp_fraction()
         # if health increased and party size did not change
-        if (cur_health > self.last_health and
-                self.read_m(0xD163) == self.party_size):
-            if self.last_health > 0:
-                heal_amount = cur_health - self.last_health
+        if (cur_health > self._last_health and
+                self._read_m(0xD163) == self._party_size):
+            if self._last_health > 0:
+                heal_amount = cur_health - self._last_health
                 if heal_amount > 0.5:
                     print(f'healed: {heal_amount}')
-                    self.save_screenshot('healing')
-                self.total_healing_rew += heal_amount * 4
+                    self._save_screenshot('healing')
+                self._total_healing_rew += heal_amount * 4
             else:
-                self.died_count += 1
-
-    def get_all_events_reward(self):
+                self._died_count += 1
+
+    def _get_all_events_reward(self) -> int:
         # adds up all event flags, exclude museum ticket
         event_flags_start = 0xD747
         event_flags_end = 0xD886
@@ -486,16 +550,16 @@ def get_all_events_reward(self):
         return max(
             sum(
                 [
-                    self.bit_count(self.read_m(i))
+                    self._bit_count(self._read_m(i))
                     for i in range(event_flags_start, event_flags_end)
                ]
             )
             - base_event_flags
-            - int(self.read_bit(museum_ticket[0], museum_ticket[1])),
-            0,
-        )
+            - int(self._read_bit(museum_ticket[0], museum_ticket[1])),
+            0,
+        )
 
-    def get_game_state_reward(self, print_stats=False):
+    def _get_game_state_reward(self) -> dict[str, float]:
         # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map
         # https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm
         '''
@@ -520,67 +584,70 @@ def get_game_state_reward(self, print_stats=False):
         print(f'oak_parcel: {oak_parcel} oak_pokedex: {oak_pokedex} all_events_score: {all_events_score}')
         '''
 
-        state_scores = {
-            'event': self.reward_scale*self.update_max_event_rew(),
+        state_scores: dict[str, float] = {
+            'event': self._reward_scale * self._update_max_event_rew(),
             #'party_xp': self.reward_scale*0.1*sum(poke_xps),
-            'level': self.reward_scale*self.get_levels_reward(),
-            'heal': self.reward_scale*self.total_healing_rew,
-            'op_lvl': self.reward_scale*self.update_max_op_level(),
-            'dead': self.reward_scale*-0.1*self.died_count,
-            'badge': self.reward_scale*self.get_badges() * 5,
+            'level': self._reward_scale * self._get_levels_reward(),
+            'heal': self._reward_scale * self._total_healing_rew,
+            'op_lvl': self._reward_scale * self._update_max_op_level(),
+            'dead': self._reward_scale * -0.1 * self._died_count,
+            'badge': self._reward_scale * self._get_badges() * 5,
            #'op_poke': self.reward_scale*self.max_opponent_poke * 800,
            #'money': self.reward_scale* money * 3,
            #'seen_poke': self.reward_scale * seen_poke_count * 400,
-            'explore': self.reward_scale * self.get_knn_reward()
+            'explore': self._reward_scale * self._get_knn_reward()
         }
 
         return state_scores
 
-    def save_screenshot(self, name):
-        ss_dir = self.s_path / Path('screenshots')
+    def _save_screenshot(self, name: str) -> None:
+        ss_dir = self._s_path / Path('screenshots')
         ss_dir.mkdir(exist_ok=True)
         plt.imsave(
-            ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'),
+            ss_dir / Path(f'frame{self._instance_id}_r{self._total_reward:.4f}_{self._reset_count}_{name}.jpeg'),
            self.render(reduce_res=False))
 
-    def update_max_op_level(self):
+    def _update_max_op_level(self) -> float:
         #opponent_level = self.read_m(0xCFE8) - 5 # base level
-        opponent_level = max([self.read_m(a) for a in [0xD8C5, 0xD8F1, 0xD91D, 0xD949, 0xD975, 0xD9A1]]) - 5
+        opponent_level = max([self._read_m(a) for a in [0xD8C5, 0xD8F1, 0xD91D, 0xD949, 0xD975, 0xD9A1]]) - 5
         #if opponent_level >= 7:
         #    self.save_screenshot('highlevelop')
-        self.max_opponent_level = max(self.max_opponent_level, opponent_level)
-        return self.max_opponent_level * 0.2
+        self._max_opponent_level = max(self._max_opponent_level, opponent_level)
+        return self._max_opponent_level * 0.2
 
-    def update_max_event_rew(self):
-        cur_rew = self.get_all_events_reward()
-        self.max_event_rew = max(cur_rew, self.max_event_rew)
-        return self.max_event_rew
-
-    def read_hp_fraction(self):
-        hp_sum = sum([self.read_hp(add) for add in [0xD16C, 0xD198, 0xD1C4, 0xD1F0, 0xD21C, 0xD248]])
-        max_hp_sum = sum([self.read_hp(add) for add in [0xD18D, 0xD1B9, 0xD1E5, 0xD211, 0xD23D, 0xD269]])
+    def _update_max_event_rew(self) -> int:
+        cur_rew = self._get_all_events_reward()
+        self._max_event_rew = max(cur_rew, self._max_event_rew)
+        return self._max_event_rew
+
+    def _read_hp_fraction(self) -> float:
+        hp_sum = sum([self._read_hp(add) for add in [0xD16C, 0xD198, 0xD1C4, 0xD1F0, 0xD21C, 0xD248]])
+        max_hp_sum = sum([self._read_hp(add) for add in [0xD18D, 0xD1B9, 0xD1E5, 0xD211, 0xD23D, 0xD269]])
         max_hp_sum = max(max_hp_sum, 1)
         return hp_sum / max_hp_sum
 
-    def read_hp(self, start):
-        return 256 * self.read_m(start) + self.read_m(start+1)
+    def _read_hp(self, start: int) -> int:
+        return 256 * self._read_m(start) + self._read_m(start + 1)
 
     # built-in since python 3.10
-    def bit_count(self, bits):
+    @staticmethod
+    def _bit_count(bits: int) -> int:
         return bin(bits).count('1')
 
-    def read_triple(self, start_add):
-        return 256*256*self.read_m(start_add) + 256*self.read_m(start_add+1) + self.read_m(start_add+2)
+    def _read_triple(self, start_add: int) -> int:
        return 256*256*self._read_m(start_add) + 256*self._read_m(start_add + 1) + self._read_m(start_add + 2)
 
-    def read_bcd(self, num):
+    @staticmethod
+    def _from_bcd(num: int) -> int:
         return 10 * ((num >> 4) & 0x0f) + (num & 0x0f)
 
-    def read_money(self):
-        return (100 * 100 * self.read_bcd(self.read_m(0xD347)) +
-                100 * self.read_bcd(self.read_m(0xD348)) +
-                self.read_bcd(self.read_m(0xD349)))
+    def _read_money(self) -> int:
+        return (100 * 100 * self._from_bcd(self._read_m(0xD347)) +
+                100 * self._from_bcd(self._read_m(0xD348)) +
+                self._from_bcd(self._read_m(0xD349)))
 
-    def get_map_location(self, map_idx):
+    @staticmethod
+    def _get_map_location(map_idx: int) -> str:
         map_locations = {
             0: "Pallet Town",
             1: "Viridian City",
diff --git a/baselines/render_all_needed_grids.py b/baselines/render_all_needed_grids.py
index 269fc56e2..e96e46ea2 100644
--- a/baselines/render_all_needed_grids.py
+++ b/baselines/render_all_needed_grids.py
@@ -1,30 +1,29 @@
+import sys
 from os.path import exists
 from pathlib import Path
-import sys
-import uuid
-from red_gym_env import RedGymEnv
-from stable_baselines3 import A2C, PPO
-from stable_baselines3.common import env_checker
-from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
-from stable_baselines3.common.utils import set_random_seed
+from typing import Callable
+
+from stable_baselines3 import PPO
 from stable_baselines3.common.callbacks import CheckpointCallback
+from stable_baselines3.common.utils import set_random_seed
+from stable_baselines3.common.vec_env import SubprocVecEnv
+
+from red_gym_env import RedGymEnv, RedGymEnvConfig
 
-def make_env(rank, env_conf, seed=0):
+
+def make_env(env_conf: RedGymEnvConfig, seed: int = 0) -> Callable[[], RedGymEnv]:
     """
     Utility function for multiprocessed env.
 
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
+    :param env_conf: (dict) various environment config parameters
     :param seed: (int) the initial seed for RNG
-    :param rank: (int) index of the subprocess
     """
     def _init():
-        env = RedGymEnv(env_conf)
-        env.seed(seed + rank)
-        return env
+        return RedGymEnv(env_conf)
     set_random_seed(seed)
     return _init
 
-def run_save(save):
+
+def run_save(save: str):
     save = Path(save)
     ep_length = 2048 * 8
     sess_path = f'grid_renders/session_{save.stem}'
@@ -35,7 +34,7 @@ def run_save(save):
         'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0
     }
     num_cpu = 40 # Also sets the number of episodes per training iteration
-    env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)])
+    env = SubprocVecEnv([make_env(env_config) for _ in range(num_cpu)])
 
     checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path,
                                              name_prefix='poke')
diff --git a/baselines/run_baseline_parallel.py b/baselines/run_baseline_parallel.py
index f4423a3a5..f7e46d4aa 100644
--- a/baselines/run_baseline_parallel.py
+++ b/baselines/run_baseline_parallel.py
@@ -1,20 +1,21 @@
+import uuid
 from os.path import exists
 from pathlib import Path
-import uuid
-from red_gym_env import RedGymEnv
-from stable_baselines3 import A2C, PPO
-from stable_baselines3.common import env_checker
-from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
-from stable_baselines3.common.utils import set_random_seed
+
+from stable_baselines3 import PPO
 from stable_baselines3.common.callbacks import CheckpointCallback
+from stable_baselines3.common.utils import set_random_seed
+from stable_baselines3.common.vec_env import SubprocVecEnv
+
+from red_gym_env import RedGymEnv, RedGymEnvConfig
 
-def make_env(rank, env_conf, seed=0):
+
+def make_env(rank: int, env_conf: RedGymEnvConfig, seed: int = 0):
     """
     Utility function for multiprocessed env.
 
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
-    :param seed: (int) the initial seed for RNG
     :param rank: (int) index of the subprocess
+    :param env_conf: (dict) various environment config parameters
+    :param seed: (int) the initial seed for RNG
     """
     def _init():
         env = RedGymEnv(env_conf)
@@ -23,13 +24,13 @@ def _init():
     set_random_seed(seed)
     return _init
 
-if __name__ == '__main__':
 
+if __name__ == '__main__':
     ep_length = 2048 * 8
     sess_path = Path(f'session_{str(uuid.uuid4())[:8]}')
 
-    env_config = {
+    env_config: RedGymEnvConfig = {
                 'headless': True, 'save_final_state': True, 'early_stop': False,
                 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length,
                 'print_rewards': True, 'save_video': False, 'fast_video': True, 'session_path': sess_path,
@@ -41,7 +42,7 @@ def _init():
     num_cpu = 44 #64 #46 # Also sets the number of episodes per training iteration
     env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)])
 
-    checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path,
+    checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=str(sess_path),
                                              name_prefix='poke')
     #env_checker.check_env(env)
     learn_steps = 40
diff --git a/baselines/run_baseline_parallel_fast.py b/baselines/run_baseline_parallel_fast.py
index 78f6ec95e..9c0586103 100644
--- a/baselines/run_baseline_parallel_fast.py
+++ b/baselines/run_baseline_parallel_fast.py
@@ -1,23 +1,23 @@
 import uuid
 from os.path import exists
 from pathlib import Path
+from typing import Callable
 
 from stable_baselines3 import PPO
 from stable_baselines3.common.callbacks import CheckpointCallback, CallbackList
 from stable_baselines3.common.utils import set_random_seed
 from stable_baselines3.common.vec_env import SubprocVecEnv
 
-from red_gym_env import RedGymEnv
+from red_gym_env import RedGymEnv, RedGymEnvConfig
 from tensorboard_callback import TensorboardCallback
 
 
-def make_env(rank, env_conf, seed=0):
+def make_env(rank: int, env_conf: RedGymEnvConfig, seed: int = 0) -> Callable[[], RedGymEnv]:
     """
     Utility function for multiprocessed env.
 
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
-    :param seed: (int) the initial seed for RNG
     :param rank: (int) index of the subprocess
+    :param env_conf: (dict) various environment config parameters
+    :param seed: (int) the initial seed for RNG
     """
     def _init():
         env = RedGymEnv(env_conf)
@@ -26,6 +26,7 @@ def _init():
     set_random_seed(seed)
     return _init
 
+
 if __name__ == '__main__':
 
     use_wandb_logging = False
@@ -33,7 +34,7 @@ def _init():
     sess_id = str(uuid.uuid4())[:8]
     sess_path = Path(f'session_{sess_id}')
 
-    env_config = {
+    env_config: RedGymEnvConfig = {
                 'headless': True, 'save_final_state': True, 'early_stop': False,
                 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length,
                 'print_rewards': True, 'save_video': False, 'fast_video': True, 'session_path': sess_path,
@@ -47,7 +48,7 @@ def _init():
     num_cpu = 16 # Also sets the number of episodes per training iteration
     env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)])
 
-    checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path,
+    checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=str(sess_path),
                                              name_prefix='poke')
 
     callbacks = [checkpoint_callback, TensorboardCallback()]
@@ -78,7 +79,7 @@ def _init():
         model.rollout_buffer.n_envs = num_cpu
         model.rollout_buffer.reset()
     else:
-        model = PPO('CnnPolicy', env, verbose=1, n_steps=ep_length // 8, batch_size=128, n_epochs=3, gamma=0.998, tensorboard_log=sess_path)
+        model = PPO('CnnPolicy', env, verbose=1, n_steps=ep_length // 8, batch_size=128, n_epochs=3, gamma=0.998, tensorboard_log=str(sess_path))
 
     for i in range(learn_steps):
         model.learn(total_timesteps=(ep_length)*num_cpu*1000, callback=CallbackList(callbacks))
diff --git a/baselines/run_pretrained_interactive.py b/baselines/run_pretrained_interactive.py
index 2c2671332..7172c4ddc 100644
--- a/baselines/run_pretrained_interactive.py
+++ b/baselines/run_pretrained_interactive.py
@@ -1,34 +1,32 @@
-from os.path import exists
-from pathlib import Path
 import uuid
-from red_gym_env import RedGymEnv
-from stable_baselines3 import A2C, PPO
-from stable_baselines3.common import env_checker
-from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
+from pathlib import Path
+from typing import Callable
+
+from stable_baselines3 import PPO
 from stable_baselines3.common.utils import set_random_seed
-from stable_baselines3.common.callbacks import CheckpointCallback
 
-def make_env(rank, env_conf, seed=0):
+from red_gym_env import RedGymEnv
+from red_gym_env import RedGymEnvConfig
+
+
+def make_env(env_conf: RedGymEnvConfig, seed=0) -> Callable[[], RedGymEnv]:
     """
     Utility function for multiprocessed env.
 
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
+    :param env_conf: (dict) various environment config parameters
     :param seed: (int) the initial seed for RNG
-    :param rank: (int) index of the subprocess
     """
     def _init():
-        env = RedGymEnv(env_conf)
-        #env.seed(seed + rank)
-        return env
+        return RedGymEnv(env_conf)
     set_random_seed(seed)
     return _init
 
+
 if __name__ == '__main__':
 
     sess_path = Path(f'session_{str(uuid.uuid4())[:8]}')
     ep_length = 2**23
 
-    env_config = {
+    env_config: RedGymEnvConfig = {
                 'headless': False, 'save_final_state': True, 'early_stop': False,
                 'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length,
                 'print_rewards': True, 'save_video': False, 'fast_video': True, 'session_path': sess_path,
@@ -36,7 +34,7 @@ def _init():
     }
     num_cpu = 1 #64 #46 # Also sets the number of episodes per training iteration
-    env = make_env(0, env_config)() #SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)])
+    env = make_env(env_config)() #SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)])
 
     #env_checker.check_env(env)
     file_name = 'session_4da05e87_main_good/poke_439746560_steps'
diff --git a/baselines/run_recorded_actions.py b/baselines/run_recorded_actions.py
index 31d2d033b..896815d61 100644
--- a/baselines/run_recorded_actions.py
+++ b/baselines/run_recorded_actions.py
@@ -18,7 +18,7 @@ def run_recorded_actions_on_emulator_and_save_video(sess_id, instance_id, run_in
         'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0,
         'instance_id': f'{instance_id}_recorded'
     }
     env = RedGymEnv(env_config)
-    env.reset_count = run_index
+    env._reset_count = run_index
     obs = env.reset()
 
     for action in action_list:

From fc8dc4b09923831bcbc0e3d8dfacfb70fda6e52f Mon Sep 17 00:00:00 2001
From: Bartosz Prusak
Date: Mon, 27 Nov 2023 00:42:10 +0100
Subject: [PATCH 3/5] `check_if_done()` is not private

---
 baselines/red_gym_env.py                | 4 ++--
 baselines/run_baseline_parallel_fast.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py
index 6c393d66e..1bfa8d90a 100644
--- a/baselines/red_gym_env.py
+++ b/baselines/red_gym_env.py
@@ -280,7 +280,7 @@ def step(self, action: int) -> tuple[NDArray, float, bool, bool, dict[str, Any]]
         self._recent_memory[0, 1] = min(new_prog[1] * 64, 255)
         self._recent_memory[0, 2] = min(new_prog[2] * 128, 255)
 
-        step_limit_reached = self._check_if_done()
+        step_limit_reached = self.check_if_done()
 
         self._save_and_print_info(step_limit_reached, obs_memory)
 
@@ -444,7 +444,7 @@ def _create_recent_memory(self) -> NDArray:
             '(w h) c -> h w c',
             h=self._memory_height)
 
-    def _check_if_done(self) -> bool:
+    def check_if_done(self) -> bool:
         if self._early_stopping:
             done = False
             if self._step_count > 128 and self._recent_memory.sum() < (255 * 1):
diff --git a/baselines/run_baseline_parallel_fast.py b/baselines/run_baseline_parallel_fast.py
index 9c0586103..7ab4a1247 100644
--- a/baselines/run_baseline_parallel_fast.py
+++ b/baselines/run_baseline_parallel_fast.py
@@ -45,7 +45,7 @@ def _init():
 
     print(env_config)
 
-    num_cpu = 16 # Also sets the number of episodes per training iteration
+    num_cpu = 4 # 16 # Also sets the number of episodes per training iteration
     env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)])
 
     checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=str(sess_path),

From 8337960b52d096dcea43ad66711529ef6632e05f Mon Sep 17 00:00:00 2001
From: Bartosz Prusak
Date: Mon, 27 Nov 2023 00:47:24 +0100
Subject: [PATCH 4/5] `agent_stats` is not private (damn you, reflections)

Also added some missed type hints
---
 baselines/red_gym_env.py             | 8 ++++----
 baselines/render_all_needed_grids.py | 2 +-
 baselines/run_baseline_parallel.py   | 3 ++-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/baselines/red_gym_env.py b/baselines/red_gym_env.py
index 1bfa8d90a..b88b47250 100644
--- a/baselines/red_gym_env.py
+++ b/baselines/red_gym_env.py
@@ -152,7 +152,7 @@ def __init__(self, config: Optional[RedGymEnvConfig] = None):
         # Fields set in reset()
         self._recent_memory: NDArray = np.zeros((1,), dtype=np.uint8)
         self._recent_frames: NDArray = np.zeros((1,), dtype=np.uint8)
-        self._agent_stats: list[_AgentStats] = []
+        self.agent_stats: list[_AgentStats] = []
         self._full_frame_writer: Optional[media.VideoWriter] = None
         self._model_frame_writer: Optional[media.VideoWriter] = None
         self._levels_satisfied: bool = False
@@ -191,7 +191,7 @@ def reset(self, *,
                 self._output_shape[1], self._output_shape[2]),
             dtype=np.uint8)
 
-        self._agent_stats = []
+        self.agent_stats = []
 
         if self._save_video:
             base_dir = self._s_path / Path('rollouts')
@@ -326,7 +326,7 @@ def _append_agent_stats(self, action: int) -> None:
             expl = ('frames', self._knn_index.get_current_count())
         else:
             expl = ('coord_count', len(self._seen_coords))
-        self._agent_stats.append({
+        self.agent_stats.append({
             'step': self._step_count, 'x': x_pos, 'y': y_pos, 'map': map_n,
             'map_location': self._get_map_location(map_n),
             'last_action': action,
@@ -487,7 +487,7 @@ def _save_and_print_info(self, done: bool, obs_memory: NDArray) -> None:
             self._all_runs.append(self._progress_reward)
             with open(self._s_path / Path(f'all_runs_{self._instance_id}.json'), 'w') as f:
                 json.dump(self._all_runs, f)
-            pd.DataFrame(self._agent_stats).to_csv(
+            pd.DataFrame(self.agent_stats).to_csv(
                 self._s_path / Path(f'agent_stats_{self._instance_id}.csv.gz'), compression='gzip', mode='a')
 
     def _read_m(self, addr: int) -> int:
diff --git a/baselines/render_all_needed_grids.py b/baselines/render_all_needed_grids.py
index e96e46ea2..47e6df8ac 100644
--- a/baselines/render_all_needed_grids.py
+++ b/baselines/render_all_needed_grids.py
@@ -23,7 +23,7 @@ def _init():
     return _init
 
 
-def run_save(save: str):
+def run_save(save: str) -> None:
     save = Path(save)
     ep_length = 2048 * 8
     sess_path = f'grid_renders/session_{save.stem}'
diff --git a/baselines/run_baseline_parallel.py b/baselines/run_baseline_parallel.py
index f7e46d4aa..49fddd794 100644
--- a/baselines/run_baseline_parallel.py
+++ b/baselines/run_baseline_parallel.py
@@ -1,6 +1,7 @@
 import uuid
 from os.path import exists
 from pathlib import Path
+from typing import Callable
 
 from stable_baselines3 import PPO
 from stable_baselines3.common.callbacks import CheckpointCallback
@@ -10,7 +11,7 @@ from red_gym_env import RedGymEnv, RedGymEnvConfig
 
 
-def make_env(rank: int, env_conf: RedGymEnvConfig, seed: int = 0):
+def make_env(rank: int, env_conf: RedGymEnvConfig, seed: int = 0) -> Callable[[], RedGymEnv]:
     """
     Utility function for multiprocessed env.
 
     :param rank: (int) index of the subprocess

From 5e8fe5cef01f297c7d863dfebd05d3dc367827e5 Mon Sep 17 00:00:00 2001
From: Bartosz Prusak
Date: Thu, 30 Nov 2023 19:00:22 +0100
Subject: [PATCH 5/5] Revert the default num of cores, ooops

---
 baselines/run_baseline_parallel_fast.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/baselines/run_baseline_parallel_fast.py b/baselines/run_baseline_parallel_fast.py
index 7ab4a1247..9c0586103 100644
--- a/baselines/run_baseline_parallel_fast.py
+++ b/baselines/run_baseline_parallel_fast.py
@@ -45,7 +45,7 @@ def _init():
 
     print(env_config)
 
-    num_cpu = 4 # 16 # Also sets the number of episodes per training iteration
+    num_cpu = 16 # Also sets the number of episodes per training iteration
     env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)])
 
     checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=str(sess_path),