-
Notifications
You must be signed in to change notification settings - Fork 40
Reinforcement Learning Module, Part 2 #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bb23237
f8ac4de
72241eb
7ec2f9d
989a423
5c619d2
e55bd0a
75232f6
812565f
5694dcb
6a92d8d
c0b5679
7f8d416
111b6e2
4e0dd09
c15c274
f831a1b
e698c6d
0bf3927
c1b4322
1712eb0
420aad0
18fbf48
9f4b4a2
a802100
68954e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,11 +20,11 @@ | |
|
|
||
| import collections | ||
| import copy | ||
| import logging as log | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's prefer to remove the alias unless there is a namespace collision. |
||
| import os | ||
| import time | ||
| from typing import Final, Mapping, NewType, Optional, Sequence, Tuple | ||
| from typing import Final, Mapping, NewType, Optional, Sequence, Tuple, Union | ||
|
|
||
| from absl import logging | ||
| import bidict | ||
| import gin | ||
| import numpy as np | ||
|
|
@@ -86,8 +86,13 @@ | |
| DeviceActionTuple = Tuple[DeviceCode, Setpoint] | ||
| DeviceMeasurementTuple = Tuple[DeviceCode, MeasurementName] | ||
|
|
||
| logger = log.getLogger(__name__) | ||
|
|
||
| def all_actions_accepted(action_response: ActionResponse) -> bool: | ||
| logger = log.getLogger(__name__) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this a duplicate logger from line 84? |
||
|
|
||
| def all_actions_accepted( | ||
| action_response: smart_control_building_pb2.ActionResponse, | ||
| ) -> bool: | ||
| """Returns true if all single action requests have response code ACCEPTED.""" | ||
|
|
||
| return all( | ||
|
|
@@ -374,7 +379,6 @@ def __init__( | |
| image_generator: ( | ||
| building_image_generator.BuildingImageGenerator | None | ||
| ) = None, | ||
| step_interval: pd.Timedelta = pd.Timedelta(5, unit="minutes"), | ||
| writer_factory: writer_lib.BaseWriterFactory | None = None, | ||
| ) -> None: | ||
| """Environment constructor. | ||
|
|
@@ -427,10 +431,12 @@ def __init__( | |
| self._end_timestamp: pd.Timestamp = self._start_timestamp + pd.Timedelta( | ||
| num_days_in_episode, unit="days" | ||
| ) | ||
| self._step_interval = step_interval | ||
| self._step_interval = pd.Timedelta(self.building.time_step_sec, unit="s") | ||
| logger.info("Step Interval: %s", self._step_interval) | ||
| self._num_timesteps_in_episode = int( | ||
| (self._end_timestamp - self._start_timestamp) / self._step_interval | ||
| ) | ||
| logger.info("Num Timesteps in Episode: %s", self._num_timesteps_in_episode) | ||
| self._metrics = plot_utils.init_metrics() | ||
| logging.info( | ||
| "Episode starts at %s and ends at %s; % d timesteps.", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| """DDPG Agent implementation. | ||
|
|
||
| This module provides a function to create a DDPG agent with customizable parameters. | ||
| """ | ||
|
|
||
| from typing import Optional, Sequence | ||
|
|
||
| import tensorflow as tf | ||
| from tf_agents.agents import tf_agent | ||
| from tf_agents.agents.ddpg import ddpg_agent | ||
| from tf_agents.networks import network | ||
| from tf_agents.typing import types | ||
|
|
||
| from smart_control.reinforcement_learning.agents.networks.ddpg_networks import create_sequential_actor_network | ||
| from smart_control.reinforcement_learning.agents.networks.ddpg_networks import create_sequential_critic_network | ||
|
|
||
|
|
||
def create_ddpg_agent(
    time_step_spec: types.TimeStep,
    action_spec: types.NestedTensorSpec,
    # Actor network parameters
    actor_fc_layers: Sequence[int] = (128, 128),
    actor_network: Optional[network.Network] = None,
    # Critic network parameters
    critic_obs_fc_layers: Sequence[int] = (128, 64),
    critic_action_fc_layers: Sequence[int] = (128, 64),
    critic_joint_fc_layers: Sequence[int] = (128, 64),
    critic_network: Optional[network.Network] = None,
    # Optimizer parameters
    actor_learning_rate: float = 3e-4,
    critic_learning_rate: float = 3e-4,
    # Agent parameters
    ou_stddev: float = 1.0,
    ou_damping: float = 1.0,
    gamma: float = 0.99,
    target_update_tau: float = 0.005,
    target_update_period: int = 1,
    reward_scale_factor: float = 1.0,
    # Training parameters
    gradient_clipping: Optional[float] = None,
    debug_summaries: bool = False,
    summarize_grads_and_vars: bool = False,
    train_step_counter: Optional[tf.Variable] = None,
) -> tf_agent.TFAgent:
  """Creates a DDPG agent.

  Builds default sequential actor/critic networks when custom ones are not
  supplied, wires them into a `ddpg_agent.DdpgAgent` with Adam optimizers,
  and initializes the agent before returning it.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A nest of BoundedTensorSpec representing the actions.
    actor_fc_layers: Iterable of fully connected layer units for the actor
      network.
    actor_network: Optional custom actor network to use.
    critic_obs_fc_layers: Iterable of fully connected layer units for the
      critic observation network.
    critic_action_fc_layers: Iterable of fully connected layer units for the
      critic action network.
    critic_joint_fc_layers: Iterable of fully connected layer units for the
      joint part of the critic network.
    critic_network: Optional custom critic network to use.
    actor_learning_rate: Actor network learning rate.
    critic_learning_rate: Critic network learning rate.
    ou_stddev: Standard deviation for the Ornstein-Uhlenbeck (OU) noise added
      for exploration.
    ou_damping: Damping factor for the OU noise.
    gamma: Discount factor for future rewards.
    target_update_tau: Factor for soft update of target networks.
    target_update_period: Period for soft update of target networks.
    reward_scale_factor: Multiplicative scale for the reward.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: Whether to emit debug summaries.
    summarize_grads_and_vars: Whether to summarize gradients and variables.
    train_step_counter: An optional counter to increment every time the train
      op is run. Defaults to a fresh zero-initialized int64 variable.

  Returns:
    An initialized `tf_agent.TFAgent` instance implementing DDPG.
  """
  # Create a train step counter if the caller did not supply one.
  if train_step_counter is None:
    train_step_counter = tf.Variable(0, trainable=False, dtype=tf.int64)

  # Build default networks only when custom ones are not provided.
  if actor_network is None:
    actor_network = create_sequential_actor_network(
        actor_fc_layers=actor_fc_layers,
        action_tensor_spec=action_spec,
    )

  if critic_network is None:
    critic_network = create_sequential_critic_network(
        obs_fc_layer_units=critic_obs_fc_layers,
        action_fc_layer_units=critic_action_fc_layers,
        joint_fc_layer_units=critic_joint_fc_layers,
    )

  # NOTE: the local is named `agent` (not `tf_agent`) so it does not shadow
  # the imported `tf_agent` module used in this function's return annotation.
  agent = ddpg_agent.DdpgAgent(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=actor_network,
      critic_network=critic_network,
      actor_optimizer=tf.keras.optimizers.Adam(
          learning_rate=actor_learning_rate
      ),
      critic_optimizer=tf.keras.optimizers.Adam(
          learning_rate=critic_learning_rate
      ),
      ou_stddev=ou_stddev,
      ou_damping=ou_damping,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=gradient_clipping,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter,
  )

  # Initialize the agent so its variables exist before training begins.
  agent.initialize()

  return agent
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might need to re-add tqdm to the toml.