From f5bea8e00db60de4dbeba33953ab54b0863dc7b8 Mon Sep 17 00:00:00 2001 From: Sourcery AI Date: Fri, 26 Jan 2024 10:57:09 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- agents/agent.py | 10 ++++------ agents/exploration_bonus.py | 8 ++------ agents/replay_memory.py | 6 +++--- agents/training.py | 7 +++---- atari/atari.py | 2 +- main.py | 6 ++---- networks/dqn.py | 4 ++-- networks/factory.py | 2 +- networks/inputs.py | 37 +++++++++++++++---------------------- networks/reward_scaling.py | 5 +---- test/test_replay_memory.py | 5 ++--- util/summary.py | 5 +---- util/util.py | 16 ++++++++-------- 13 files changed, 45 insertions(+), 68 deletions(-) diff --git a/agents/agent.py b/agents/agent.py index d07d3f0..58e7973 100644 --- a/agents/agent.py +++ b/agents/agent.py @@ -22,14 +22,12 @@ def new_game(self): return observation, reward, done def action(self, session, step, observation): - # Epsilon greedy exploration/exploitation even for bootstrapped DQN if np.random.rand() < self.epsilon(step): return self.atari.sample_action() - else: - [action] = session.run( - self.policy_network.choose_action, - {self.policy_network.inputs.observations: [observation]}) - return action + [action] = session.run( + self.policy_network.choose_action, + {self.policy_network.inputs.observations: [observation]}) + return action def epsilon(self, step): """Epsilon is linearly annealed from an initial exploration value diff --git a/agents/exploration_bonus.py b/agents/exploration_bonus.py index a3b830f..1c5ceaa 100644 --- a/agents/exploration_bonus.py +++ b/agents/exploration_bonus.py @@ -20,12 +20,8 @@ def bonus(self, observation): prob = self.update_density_model(frame) recoding_prob = self.density_model_probability(frame) pseudo_count = prob * (1 - recoding_prob) / (recoding_prob - prob) - if pseudo_count < 0: - pseudo_count = 0 # Occasionally happens at start of training - - # Return exploration bonus - exploration_bonus = self.beta / math.sqrt(pseudo_count + 0.01) - return exploration_bonus + pseudo_count = max(pseudo_count, 0) + return self.beta / math.sqrt(pseudo_count + 0.01) def update_density_model(self, frame): return self.sum_pixel_probabilities(frame, self.density_model.update) diff --git a/agents/replay_memory.py b/agents/replay_memory.py index 88ac16f..2483397 100644 --- a/agents/replay_memory.py +++ b/agents/replay_memory.py @@ -44,7 +44,7 @@ def __init__(self, config): elif config.replay_priorities == 'proportional': self.priorities = ProportionalPriorities(config) else: - raise Exception('Unknown replay_priorities: ' + config.replay_priorities) + raise Exception(f'Unknown replay_priorities: {config.replay_priorities}') def store_new_episode(self, observation): for frame in observation: @@ -133,7 +133,7 @@ def valid_indices(self, new_indices, input_range, indices=None): return np.unique(np.append(valid_indices, indices)) def save(self): - name = self.run_dir + 'replay_' + threading.current_thread().name + '.hdf' + name = f'{self.run_dir}replay_{threading.current_thread().name}.hdf' with h5py.File(name, 'w') as h5f: for key, value in self.__dict__.items(): if key == 'priorities': @@ -144,7 +144,7 @@ def save(self): h5f.create_dataset(key, data=value) def load(self): - name = self.run_dir + 'replay_' + threading.current_thread().name + '.hdf' + name = f'{self.run_dir}replay_{threading.current_thread().name}.hdf' with h5py.File(name, 'r') as h5f: for key in self.__dict__.keys(): if key == 'priorities': diff --git a/agents/training.py b/agents/training.py index 3cbb3c1..13440bb 100644 --- a/agents/training.py +++ b/agents/training.py @@ -81,15 +81,14 @@ def train_agent(self, session, agent): agent.replay_memory.save() def reset_target_network(self, session, step): - if self.reset_op: - if step > 0 and step % self.config.target_network_update_period == 0: + if step > 0 and step % self.config.target_network_update_period == 0: + if self.reset_op: session.run(self.reset_op) def train_batch(self, session, replay_memory, step): fetches = [self.global_step, self.train_op] + self.summary.operation(step) - batch = replay_memory.sample_batch(fetches, self.config.batch_size) - if batch: + if batch := replay_memory.sample_batch(fetches, self.config.batch_size): step, priorities, summary = session.run(fetches, batch.feed_dict()) batch.update_priorities(priorities) self.summary.add_summary(summary, step) diff --git a/atari/atari.py b/atari/atari.py index 7a4c06f..2504646 100644 --- a/atari/atari.py +++ b/atari/atari.py @@ -44,7 +44,7 @@ def reset(self): if self.render: self.env.render() self.frames = [] - for i in range(np.random.randint(self.input_frames, self.max_noops + 1)): + for _ in range(np.random.randint(self.input_frames, self.max_noops + 1)): frame, reward_, done, _ = self.env.step(0) if self.render: self.env.render() diff --git a/main.py b/main.py index 15384db..5bd8b47 100644 --- a/main.py +++ b/main.py @@ -155,12 +155,10 @@ def create_config(): if config.async == 'one_step': config.batch_size = config.train_period - elif config.async == 'n_step': - config.batch_size = 1 - elif config.async == 'a3c': + elif config. async in ['n_step', 'a3c']: config.batch_size = 1 else: - raise Exception('Unknown asynchronous algorithm: ' + config.async) + raise Exception(f'Unknown asynchronous algorithm: {config.async}') config.n_step = config.async == 'n_step' config.actor_critic = config.async == 'a3c' diff --git a/networks/dqn.py b/networks/dqn.py index a826a85..b585624 100644 --- a/networks/dqn.py +++ b/networks/dqn.py @@ -144,8 +144,8 @@ def variables(self): def activation_summary(self, tensor): if self.write_summaries: tensor_name = tensor.op.name - tf.summary.histogram(tensor_name + '/activations', tensor) - tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(tensor)) + tf.summary.histogram(f'{tensor_name}/activations', tensor) + tf.summary.scalar(f'{tensor_name}/sparsity', tf.nn.zero_fraction(tensor)) class ActionValueHead(object): diff --git a/networks/factory.py b/networks/factory.py index 737786b..a9cfa88 100644 --- a/networks/factory.py +++ b/networks/factory.py @@ -108,7 +108,7 @@ def create_summary_ops(self, loss, variables, gradients): for grad, var in gradients: if grad is not None: - tf.summary.histogram('gradient/' + var.name, grad) + tf.summary.histogram(f'gradient/{var.name}', grad) self.summary.create_summary_op() diff --git a/networks/inputs.py b/networks/inputs.py index 05236cf..0bd41fb 100644 --- a/networks/inputs.py +++ b/networks/inputs.py @@ -66,9 +66,7 @@ def offset_data(t, name): input_len = shape[0] if not hasattr(placeholder, 'zero_offset'): placeholder.zero_offset = tf.placeholder_with_default( - input_len - 1, # If no zero_offset is given assume that t = 0 - (), - name + '/zero_offset') + input_len - 1, (), f'{name}/zero_offset') end = t + 1 start = end - input_len @@ -100,11 +98,7 @@ def __init__(self, inputs, t): class RequiredFeeds(object): def __init__(self, placeholder=None, time_offsets=0, feeds=None): - if feeds: - self.feeds = feeds - else: - self.feeds = {} - + self.feeds = feeds if feeds else {} if placeholder is None: return @@ -153,21 +147,20 @@ def required_feeds(cls, tensor): if hasattr(tensor, 'required_feeds'): # Return cached result return tensor.required_feeds + # Get feeds required by all inputs + if isinstance(tensor, list): + input_tensors = tensor else: - # Get feeds required by all inputs - if isinstance(tensor, list): - input_tensors = tensor - else: - op = tensor if isinstance(tensor, tf.Operation) else tensor.op - input_tensors = list(op.inputs) + list(op.control_inputs) + op = tensor if isinstance(tensor, tf.Operation) else tensor.op + input_tensors = list(op.inputs) + list(op.control_inputs) - from networks import inputs - feeds = inputs.RequiredFeeds() - for input_tensor in input_tensors: - feeds = feeds.merge(cls.required_feeds(input_tensor)) + from networks import inputs + feeds = inputs.RequiredFeeds() + for input_tensor in input_tensors: + feeds = feeds.merge(cls.required_feeds(input_tensor)) - # Cache results - if not isinstance(tensor, list): - tensor.required_feeds = feeds + # Cache results + if not isinstance(tensor, list): + tensor.required_feeds = feeds - return feeds + return feeds diff --git a/networks/reward_scaling.py b/networks/reward_scaling.py index a35cf56..8b31078 100644 --- a/networks/reward_scaling.py +++ b/networks/reward_scaling.py @@ -31,10 +31,7 @@ def batch_sigma_squared(self, batch): self.v = (1 - self.beta) * self.v + self.beta * average_square_reward sigma_squared = (self.v - self.mu**2) / self.variance - if sigma_squared > 0: - return sigma_squared - else: - return 1.0 + return sigma_squared if sigma_squared > 0 else 1.0 def unnormalize_output(self, output): return output * self.scale_weight + self.scale_bias diff --git a/test/test_replay_memory.py b/test/test_replay_memory.py index 825fced..c197c63 100644 --- a/test/test_replay_memory.py +++ b/test/test_replay_memory.py @@ -47,9 +47,8 @@ def test_replay_memory(self): self.assertAllEqual(feed_dict[inputs.alives], [[True, True], [True, False]]) - discounted_reward = sum([ - reward * config.discount_rate**(reward - 4) for reward in range(4, 11) - ]) + discounted_reward = sum(reward * config.discount_rate**(reward - 4) + for reward in range(4, 11)) self.assertNear( feed_dict[inputs.discounted_rewards][0], discounted_reward, err=0.0001) diff --git a/util/summary.py b/util/summary.py index a0f1999..521881f 100644 --- a/util/summary.py +++ b/util/summary.py @@ -27,10 +27,7 @@ def episode(self, step, score, steps, duration): self.summary_writer.add_summary(summary, step) def operation(self, step): - if self.run_summary(step): - return [self.summary_op] - else: - return [self.dummy_summary_op] + return [self.summary_op] if self.run_summary(step) else [self.dummy_summary_op] def add_summary(self, summary, step): if summary: diff --git a/util/util.py b/util/util.py index 929c4ff..0033bf7 100644 --- a/util/util.py +++ b/util/util.py @@ -13,18 +13,18 @@ def find_previous_run(dir): if os.path.isdir(dir): runs = [child[4:] for child in os.listdir(dir) if child[:4] == 'run_'] if runs: - return max([int(run) for run in runs]) + return max(int(run) for run in runs) return 0 if config.run_dir == 'latest': - parent_dir = 'runs/%s/' % config.game + parent_dir = f'runs/{config.game}/' previous_run = find_previous_run(parent_dir) run_dir = parent_dir + ('run_%d' % previous_run) elif config.run_dir: run_dir = config.run_dir else: - parent_dir = 'runs/%s/' % config.game + parent_dir = f'runs/{config.game}/' previous_run = find_previous_run(parent_dir) run_dir = parent_dir + ('run_%d' % (previous_run + 1)) @@ -34,18 +34,18 @@ def find_previous_run(dir): if not os.path.isdir(run_dir): os.makedirs(run_dir) - log('Checkpoint and summary directory is %s' % run_dir) + log(f'Checkpoint and summary directory is {run_dir}') return run_dir def format_offset(prefix, t): if t > 0: - return prefix + '_t_plus_' + str(t) + return f'{prefix}_t_plus_{str(t)}' elif t == 0: - return prefix + '_t' + return f'{prefix}_t' else: - return prefix + '_t_minus_' + str(-t) + return f'{prefix}_t_minus_{str(-t)}' def add_loss_summaries(total_loss): @@ -91,7 +91,7 @@ def log(message): import threading thread_id = threading.current_thread().name now = datetime.strftime(datetime.now(), '%F %X') - print('%s %s: %s' % (now, thread_id, message)) + print(f'{now} {thread_id}: {message}') def memoize(f):