-
Notifications
You must be signed in to change notification settings - Fork 34
Sourcery Starbot ⭐ refactored brendanator/atari-rl #27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,12 +20,8 @@ def bonus(self, observation): | |
| prob = self.update_density_model(frame) | ||
| recoding_prob = self.density_model_probability(frame) | ||
| pseudo_count = prob * (1 - recoding_prob) / (recoding_prob - prob) | ||
| if pseudo_count < 0: | ||
| pseudo_count = 0 # Occasionally happens at start of training | ||
|
|
||
| # Return exploration bonus | ||
| exploration_bonus = self.beta / math.sqrt(pseudo_count + 0.01) | ||
| return exploration_bonus | ||
| pseudo_count = max(pseudo_count, 0) | ||
| return self.beta / math.sqrt(pseudo_count + 0.01) | ||
|
Comment on lines
-23
to
+24
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments ( why? ): |
||
|
|
||
| def update_density_model(self, frame): | ||
| return self.sum_pixel_probabilities(frame, self.density_model.update) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,7 +44,7 @@ def __init__(self, config): | |
| elif config.replay_priorities == 'proportional': | ||
| self.priorities = ProportionalPriorities(config) | ||
| else: | ||
| raise Exception('Unknown replay_priorities: ' + config.replay_priorities) | ||
| raise Exception(f'Unknown replay_priorities: {config.replay_priorities}') | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

issue (code-quality): Raise a specific error instead of the general `Exception` or `BaseException`.

Explanation: If a piece of code raises a specific exception type rather than the generic [`BaseException`](https://docs.python.org/3/library/exceptions.html#BaseException) or [`Exception`](https://docs.python.org/3/library/exceptions.html#Exception), the calling code can:

- get more information about what type of error it is
- define specific exception handling for it

This way, callers of the code can handle the error appropriately.

How can you solve this?

So instead of having code raising

```python
if incorrect_input(value):
    raise Exception("The input is incorrect")
```

you can have code raising a specific error like

```python
if incorrect_input(value):
    raise ValueError("The input is incorrect")
```

or

```python
class IncorrectInputError(Exception):
    pass

if incorrect_input(value):
    raise IncorrectInputError("The input is incorrect")
```
|
||
|
|
||
| def store_new_episode(self, observation): | ||
| for frame in observation: | ||
|
|
@@ -133,7 +133,7 @@ def valid_indices(self, new_indices, input_range, indices=None): | |
| return np.unique(np.append(valid_indices, indices)) | ||
|
|
||
| def save(self): | ||
| name = self.run_dir + 'replay_' + threading.current_thread().name + '.hdf' | ||
| name = f'{self.run_dir}replay_{threading.current_thread().name}.hdf' | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| with h5py.File(name, 'w') as h5f: | ||
| for key, value in self.__dict__.items(): | ||
| if key == 'priorities': | ||
|
|
@@ -144,7 +144,7 @@ def save(self): | |
| h5f.create_dataset(key, data=value) | ||
|
|
||
| def load(self): | ||
| name = self.run_dir + 'replay_' + threading.current_thread().name + '.hdf' | ||
| name = f'{self.run_dir}replay_{threading.current_thread().name}.hdf' | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| with h5py.File(name, 'r') as h5f: | ||
| for key in self.__dict__.keys(): | ||
| if key == 'priorities': | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -81,15 +81,14 @@ def train_agent(self, session, agent): | |||||||||||||
| agent.replay_memory.save() | ||||||||||||||
|
|
||||||||||||||
| def reset_target_network(self, session, step): | ||||||||||||||
| if self.reset_op: | ||||||||||||||
| if step > 0 and step % self.config.target_network_update_period == 0: | ||||||||||||||
| if step > 0 and step % self.config.target_network_update_period == 0: | ||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. question (llm): The refactoring to combine the conditions into a single if statement is good for readability, but ensure that the logic is equivalent and that self.reset_op is always defined when needed. |
||||||||||||||
| if self.reset_op: | ||||||||||||||
|
Comment on lines
-84
to
+85
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||||||||||||||
| session.run(self.reset_op) | ||||||||||||||
|
Comment on lines
+84
to
86
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. suggestion (code-quality): Merge nested if conditions (`merge-nested-ifs`)
Suggested change
Explanation: Too much nesting can make code difficult to understand, and this is especially true in Python, where there are no brackets to help out with the delineation of different nesting levels. Reading deeply nested code is confusing, since you have to keep track of which conditions relate to which levels. |
||||||||||||||
|
|
||||||||||||||
| def train_batch(self, session, replay_memory, step): | ||||||||||||||
| fetches = [self.global_step, self.train_op] + self.summary.operation(step) | ||||||||||||||
|
|
||||||||||||||
| batch = replay_memory.sample_batch(fetches, self.config.batch_size) | ||||||||||||||
| if batch: | ||||||||||||||
| if batch := replay_memory.sample_batch(fetches, self.config.batch_size): | ||||||||||||||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||||||||||||||
| step, priorities, summary = session.run(fetches, batch.feed_dict()) | ||||||||||||||
| batch.update_priorities(priorities) | ||||||||||||||
| self.summary.add_summary(summary, step) | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,7 +44,7 @@ def reset(self): | |
| if self.render: self.env.render() | ||
| self.frames = [] | ||
|
|
||
| for i in range(np.random.randint(self.input_frames, self.max_noops + 1)): | ||
| for _ in range(np.random.randint(self.input_frames, self.max_noops + 1)): | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| frame, reward_, done, _ = self.env.step(0) | ||
| if self.render: self.env.render() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -155,12 +155,10 @@ def create_config(): | |
|
|
||
| if config.async == 'one_step': | ||
| config.batch_size = config.train_period | ||
| elif config.async == 'n_step': | ||
| config.batch_size = 1 | ||
| elif config.async == 'a3c': | ||
| elif config. async in ['n_step', 'a3c']: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (llm): There's an extra space between 'config.' and 'async' which could lead to a syntax error. This should be corrected.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): There is an unintended space between 'config.' and 'async'. The space between 'config.' and 'async' will cause a syntax error; please remove it so it reads 'config.async'. |
||
| config.batch_size = 1 | ||
| else: | ||
| raise Exception('Unknown asynchronous algorithm: ' + config.async) | ||
| raise Exception(f'Unknown asynchronous algorithm: {config.async}') | ||
|
Comment on lines
-158
to
+161
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

issue (code-quality): Raise a specific error instead of the general `Exception` or `BaseException`.

Explanation: If a piece of code raises a specific exception type rather than the generic [`BaseException`](https://docs.python.org/3/library/exceptions.html#BaseException) or [`Exception`](https://docs.python.org/3/library/exceptions.html#Exception), the calling code can:

- get more information about what type of error it is
- define specific exception handling for it

This way, callers of the code can handle the error appropriately.

How can you solve this?

So instead of having code raising

```python
if incorrect_input(value):
    raise Exception("The input is incorrect")
```

you can have code raising a specific error like

```python
if incorrect_input(value):
    raise ValueError("The input is incorrect")
```

or

```python
class IncorrectInputError(Exception):
    pass

if incorrect_input(value):
    raise IncorrectInputError("The input is incorrect")
```
|
||
| config.n_step = config.async == 'n_step' | ||
| config.actor_critic = config.async == 'a3c' | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -144,8 +144,8 @@ def variables(self): | |
| def activation_summary(self, tensor): | ||
| if self.write_summaries: | ||
| tensor_name = tensor.op.name | ||
| tf.summary.histogram(tensor_name + '/activations', tensor) | ||
| tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(tensor)) | ||
| tf.summary.histogram(f'{tensor_name}/activations', tensor) | ||
| tf.summary.scalar(f'{tensor_name}/sparsity', tf.nn.zero_fraction(tensor)) | ||
|
Comment on lines
-147
to
+148
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| class ActionValueHead(object): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,7 +108,7 @@ def create_summary_ops(self, loss, variables, gradients): | |
|
|
||
| for grad, var in gradients: | ||
| if grad is not None: | ||
| tf.summary.histogram('gradient/' + var.name, grad) | ||
| tf.summary.histogram(f'gradient/{var.name}', grad) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| self.summary.create_summary_op() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -66,9 +66,7 @@ def offset_data(t, name): | |||||||||
| input_len = shape[0] | ||||||||||
| if not hasattr(placeholder, 'zero_offset'): | ||||||||||
| placeholder.zero_offset = tf.placeholder_with_default( | ||||||||||
| input_len - 1, # If no zero_offset is given assume that t = 0 | ||||||||||
| (), | ||||||||||
| name + '/zero_offset') | ||||||||||
| input_len - 1, (), f'{name}/zero_offset') | ||||||||||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments ( why? ): |
||||||||||
|
|
||||||||||
| end = t + 1 | ||||||||||
| start = end - input_len | ||||||||||
|
|
@@ -100,11 +98,7 @@ def __init__(self, inputs, t): | |||||||||
|
|
||||||||||
| class RequiredFeeds(object): | ||||||||||
| def __init__(self, placeholder=None, time_offsets=0, feeds=None): | ||||||||||
| if feeds: | ||||||||||
| self.feeds = feeds | ||||||||||
| else: | ||||||||||
| self.feeds = {} | ||||||||||
|
|
||||||||||
| self.feeds = feeds if feeds else {} | ||||||||||
|
Comment on lines
-103
to
+101
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: The conditional assignment to self.feeds may not handle falsy but valid feeds. 'feeds if feeds else {}' replaces any falsy value, including valid empty containers, with an empty dict. Use 'feeds if feeds is not None else {}' to preserve valid empty containers.
Suggested change
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. suggestion (code-quality): Replace if-expression with `or` (`or-if-exp-identity`)
Suggested change
Explanation: Here we find ourselves setting a value if it evaluates to `True`, and otherwise using a default. The 'After' case is a bit easier to read and avoids the duplication of the name being tested. It works because the left-hand side is evaluated first. If it evaluates to |
||||||||||
| if placeholder is None: | ||||||||||
| return | ||||||||||
|
|
||||||||||
|
|
@@ -153,21 +147,20 @@ def required_feeds(cls, tensor): | |||||||||
| if hasattr(tensor, 'required_feeds'): | ||||||||||
| # Return cached result | ||||||||||
| return tensor.required_feeds | ||||||||||
| # Get feeds required by all inputs | ||||||||||
| if isinstance(tensor, list): | ||||||||||
| input_tensors = tensor | ||||||||||
| else: | ||||||||||
| # Get feeds required by all inputs | ||||||||||
| if isinstance(tensor, list): | ||||||||||
| input_tensors = tensor | ||||||||||
| else: | ||||||||||
| op = tensor if isinstance(tensor, tf.Operation) else tensor.op | ||||||||||
| input_tensors = list(op.inputs) + list(op.control_inputs) | ||||||||||
| op = tensor if isinstance(tensor, tf.Operation) else tensor.op | ||||||||||
| input_tensors = list(op.inputs) + list(op.control_inputs) | ||||||||||
|
|
||||||||||
| from networks import inputs | ||||||||||
| feeds = inputs.RequiredFeeds() | ||||||||||
| for input_tensor in input_tensors: | ||||||||||
| feeds = feeds.merge(cls.required_feeds(input_tensor)) | ||||||||||
| from networks import inputs | ||||||||||
| feeds = inputs.RequiredFeeds() | ||||||||||
| for input_tensor in input_tensors: | ||||||||||
| feeds = feeds.merge(cls.required_feeds(input_tensor)) | ||||||||||
|
|
||||||||||
| # Cache results | ||||||||||
| if not isinstance(tensor, list): | ||||||||||
| tensor.required_feeds = feeds | ||||||||||
| # Cache results | ||||||||||
| if not isinstance(tensor, list): | ||||||||||
| tensor.required_feeds = feeds | ||||||||||
|
|
||||||||||
| return feeds | ||||||||||
| return feeds | ||||||||||
|
Comment on lines
+150
to
+166
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,10 +31,7 @@ def batch_sigma_squared(self, batch): | |
| self.v = (1 - self.beta) * self.v + self.beta * average_square_reward | ||
|
|
||
| sigma_squared = (self.v - self.mu**2) / self.variance | ||
| if sigma_squared > 0: | ||
| return sigma_squared | ||
| else: | ||
| return 1.0 | ||
| return sigma_squared if sigma_squared > 0 else 1.0 | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| def unnormalize_output(self, output): | ||
| return output * self.scale_weight + self.scale_bias | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,9 +47,8 @@ def test_replay_memory(self): | |
| self.assertAllEqual(feed_dict[inputs.alives], | ||
| [[True, True], [True, False]]) | ||
|
|
||
| discounted_reward = sum([ | ||
| reward * config.discount_rate**(reward - 4) for reward in range(4, 11) | ||
| ]) | ||
| discounted_reward = sum(reward * config.discount_rate**(reward - 4) | ||
| for reward in range(4, 11)) | ||
|
Comment on lines
-50
to
+51
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| self.assertNear( | ||
| feed_dict[inputs.discounted_rewards][0], discounted_reward, err=0.0001) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,10 +27,7 @@ def episode(self, step, score, steps, duration): | |
| self.summary_writer.add_summary(summary, step) | ||
|
|
||
| def operation(self, step): | ||
| if self.run_summary(step): | ||
| return [self.summary_op] | ||
| else: | ||
| return [self.dummy_summary_op] | ||
| return [self.summary_op] if self.run_summary(step) else [self.dummy_summary_op] | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| def add_summary(self, summary, step): | ||
| if summary: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,18 +13,18 @@ def find_previous_run(dir): | |
| if os.path.isdir(dir): | ||
| runs = [child[4:] for child in os.listdir(dir) if child[:4] == 'run_'] | ||
| if runs: | ||
| return max([int(run) for run in runs]) | ||
| return max(int(run) for run in runs) | ||
|
|
||
| return 0 | ||
|
|
||
| if config.run_dir == 'latest': | ||
| parent_dir = 'runs/%s/' % config.game | ||
| parent_dir = f'runs/{config.game}/' | ||
| previous_run = find_previous_run(parent_dir) | ||
| run_dir = parent_dir + ('run_%d' % previous_run) | ||
| elif config.run_dir: | ||
| run_dir = config.run_dir | ||
| else: | ||
| parent_dir = 'runs/%s/' % config.game | ||
| parent_dir = f'runs/{config.game}/' | ||
|
Comment on lines
-16
to
+27
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| previous_run = find_previous_run(parent_dir) | ||
| run_dir = parent_dir + ('run_%d' % (previous_run + 1)) | ||
|
|
||
|
|
@@ -34,18 +34,18 @@ def find_previous_run(dir): | |
| if not os.path.isdir(run_dir): | ||
| os.makedirs(run_dir) | ||
|
|
||
| log('Checkpoint and summary directory is %s' % run_dir) | ||
| log(f'Checkpoint and summary directory is {run_dir}') | ||
|
|
||
| return run_dir | ||
|
|
||
|
|
||
| def format_offset(prefix, t): | ||
| if t > 0: | ||
| return prefix + '_t_plus_' + str(t) | ||
| return f'{prefix}_t_plus_{str(t)}' | ||
| elif t == 0: | ||
| return prefix + '_t' | ||
| return f'{prefix}_t' | ||
| else: | ||
| return prefix + '_t_minus_' + str(-t) | ||
| return f'{prefix}_t_minus_{str(-t)}' | ||
|
Comment on lines
-44
to
+48
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| def add_loss_summaries(total_loss): | ||
|
|
@@ -91,7 +91,7 @@ def log(message): | |
| import threading | ||
| thread_id = threading.current_thread().name | ||
| now = datetime.strftime(datetime.now(), '%F %X') | ||
| print('%s %s: %s' % (now, thread_id, message)) | ||
| print(f'{now} {thread_id}: {message}') | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| def memoize(f): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function `Agent.action` refactored with the following changes:
- Remove unnecessary else after guard condition (`remove-unnecessary-else`)

This removes the following comments (why?):