diff --git a/README.md b/README.md index 30a795d..7d2c558 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ This assumes you have Python 3.6+ installed. 1. Create a virtual env: `python3 -m venv venv` 1. `source venv/bin/activate` 1. Install packages: `pip3 install -r requirements.txt` -1. You can run an example app with this command `python -m rl_bakery.example.cartpole` +1. You can run an example app with this command `python -m rl_bakery.example.cartpole_dqn` To run this on a Mac, you must install java to run locally with pyspark. https://www.java.com/en/download/mac_download.jsp Add 'export JAVA_HOME=/Library/Internet\ Plug-Ins/JavaAppletPlugin.plugin/Contents/Home' to your .bash_profile diff --git a/rl_bakery/applications/simulation_runner.py b/rl_bakery/applications/simulation_runner.py index e5c48c1..8517f68 100644 --- a/rl_bakery/applications/simulation_runner.py +++ b/rl_bakery/applications/simulation_runner.py @@ -96,7 +96,7 @@ def make_batch_tfenv(make_env, config, start_dt, training_interval, spark_sessio :param make_env: A function that returns an Environment :param config: An ApplicationConfig :param start_dt: A datetime being used to simulate the first action - :param training_interval: A datetime indicating the lag between when an observation is generated and when it can + :param training_interval: A timedelta indicating the lag between when an observation is generated and when it can be used for training. This simulates real world environments where there's a delay between data collection and Agent updates. :param spark_session: A Spark session diff --git a/rl_bakery/applications/tfenv/test_tf_env_rl_application.py b/rl_bakery/applications/tfenv/test_tf_env_rl_application.py index 5def5ca..cc98fae 100644 --- a/rl_bakery/applications/tfenv/test_tf_env_rl_application.py +++ b/rl_bakery/applications/tfenv/test_tf_env_rl_application.py @@ -1,8 +1,9 @@ -from datetime import datetime +from datetime import datetime, timedelta from unittest import TestCase from tf_agents.environments import tf_py_environment, suite_gym from rl_bakery.applications.tfenv.indexed_tf_env import IndexedTFEnv from rl_bakery.applications.tfenv.tf_env_rl_application import TFEnvRLApplication +from rl_bakery.spark_utilities import get_spark_session from unittest.mock import patch @@ -26,5 +27,8 @@ def test_init_application(self, mock_dm): steps_num_per_run = 3 - app = TFEnvRLApplication(envs, training_config, steps_num_per_run, datetime.now(), 2) + spark_session = get_spark_session() + app = TFEnvRLApplication(envs, spark_session, training_config, steps_num_per_run, + engine_start_dt=datetime.now(), engine_training_interval=timedelta(days=1), + num_partitions=2) self.assertListEqual(app.obs_cols, ['ob_0', 'ob_1', 'ob_2', 'ob_3'])