def get_config(path):
    """Load a serialized task configuration from ``path``.

    Args:
        path: Location of a file that ``dill`` can deserialize (the notebook
            points this at the ``fourrooms_*_config.pl`` files).

    Returns:
        Whatever object was stored in the file. The callers in this notebook
        index it with the keys 'agent positions', 'goal positions',
        'topologies' and 'agent directions'.
    """
    # NOTE(review): dill.load executes arbitrary code embedded in the file;
    # only use it on configuration files from a trusted source.
    # The context manager closes the file, so no explicit close() is needed.
    with open(path, 'rb') as file:
        return dill.load(file)
def get_dataset_from_config(config, policy=0, step_limit=False, num_steps=1000):
    """Generate a transition dataset from the tasks specified in ``config``.

    The environment is stepped until as many episodes have ended as there are
    entries in ``config['agent positions']``, or, when ``step_limit`` is true,
    until ``num_steps`` transitions have been recorded, whichever comes first.
    The dataset size therefore depends on the number of tasks and on the
    quality of the policy.

    Args:
        config: Mapping with the keys 'agent positions', 'goal positions',
            'topologies' and 'agent directions'.
        policy: 0 selects the expert (arg-max over the values returned by
            ``find_all_action_values``), 1 selects a uniformly random policy.
            Any other value currently also falls back to the random policy.
        step_limit: When true, stop after ``num_steps`` transitions.
        num_steps: Maximum number of transitions when ``step_limit`` is true.

    Returns:
        Dict of numpy arrays with the keys 'observations',
        'next_observations', 'actions', 'rewards', 'terminals', 'timeouts'
        and 'infos'; index ``i`` of every array describes the same transition.
    """
    env = gym_wrapper(gym.make('MiniGrid-FourRooms-v1',
                               agent_pos=config['agent positions'],
                               goal_pos=config['goal positions'],
                               doors_pos=config['topologies'],
                               agent_dir=config['agent directions']))

    num_of_tasks = len(config['agent positions'])

    dataset = {'observations': [], 'next_observations': [], 'actions': [], 'rewards': [],
               'terminals': [], 'timeouts': [], 'infos': []}

    # try/finally so the environment is released on every exit path,
    # including the step-limit exit and exceptions raised while stepping.
    try:
        observation, info = env.reset()
        tasks_seen = 1

        steps_taken = 0
        # Check the limit *before* taking a step so that exactly num_steps
        # transitions are recorded (the previous condition allowed one extra).
        while not (step_limit and steps_taken >= num_steps):
            steps_taken += 1

            if policy == 0:
                state = obs_to_state(observation)
                # NOTE(review): 0.99 is presumably the discount factor used by
                # the shortest-path solver -- confirm against four_room.shortest_path.
                q_values = find_all_action_values(state[:2], state[2], state[3:5], state[5:], 0.99)
                action = np.argmax(q_values)
            else:
                # policy == 1 is the random policy; unknown policy ids
                # deliberately share this fallback for now.
                action = env.action_space.sample()

            last_observation = observation
            observation, reward, terminated, truncated, info = env.step(action)

            dataset['observations'].append(np.array(last_observation).flatten())
            dataset['next_observations'].append(np.array(observation).flatten())
            dataset['actions'].append(np.array([action]))
            dataset['rewards'].append(reward)
            dataset['terminals'].append(terminated)
            dataset['timeouts'].append(truncated)
            dataset['infos'].append(info)

            if terminated or truncated:
                if tasks_seen == num_of_tasks:
                    break
                observation, info = env.reset()
                tasks_seen += 1
    finally:
        env.close()

    return {key: np.array(values) for key, values in dataset.items()}
rewards=reachable_test_dataset.get(\"rewards\"),\n", + " terminals=reachable_test_dataset.get(\"terminals\"),\n", + ")\n", + "\n", + "unreachable_test_config = get_config(unreachable_test_config_path)\n", + "unreachable_test_dataset = get_expert_dataset_from_config(unreachable_test_config)\n", + "unreachable_test_dataset = d3rlpy.dataset.MDPDataset(\n", + " observations=unreachable_test_dataset.get(\"observations\"),\n", + " actions=unreachable_test_dataset.get(\"actions\"),\n", + " rewards=unreachable_test_dataset.get(\"rewards\"),\n", + " terminals=unreachable_test_dataset.get(\"terminals\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# config = get_config(train_config_path)\n", + "\n", + "# total_tasks = 0\n", + "# total_steps = 0\n", + "\n", + "# for i in range(10):\n", + "# dataset = get_random_dataset_from_config(config)\n", + "\n", + "# num_tasks = len(config['agent positions'])\n", + "# num_steps = len(dataset['observations'])\n", + "\n", + "# total_tasks += num_tasks\n", + "# total_steps += num_steps\n", + "\n", + "# print('#tasks = ', total_tasks)\n", + "# print('#steps = ', total_steps)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### Setup and Train algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "# setup algorithm\n", + "cql = d3rlpy.algos.DiscreteCQLConfig(learning_rate=1e-4).create()\n", + "\n", + "# start offline training\n", + "cql.fit(train_dataset_mixed,\n", + " # evaluators={\"env_eval\": env_eval},\n", + " n_steps=10000,\n", + " n_steps_per_epoch=500,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### Evaluate the trained algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], 
def evaluate_on_environment(model, env, max_steps, num_tasks, verbose=0):
    """Roll out ``model`` in ``env`` and tally rewards and episode outcomes.

    The rollout stops as soon as ``num_tasks`` episodes have ended or
    ``max_steps`` environment steps have been taken. The environment is reset
    after every finished episode except the last one.

    Args:
        model: Object whose ``predict(observation)`` returns an indexable
            result; element 0 is used as the action.
        env: Environment exposing ``reset()`` -> (observation, info) and
            ``step(action)`` -> (observation, reward, terminated, truncated, info).
        max_steps: Upper bound on the total number of environment steps.
        num_tasks: Number of episodes to run.
        verbose: When truthy, print progress and a final summary.

    Returns:
        Tuple ``(total_reward, steps_done, tasks_done, num_terminated,
        num_truncated)``.
    """
    def _batched(raw_observation):
        # The model is fed the flattened observation with a leading axis of size 1.
        return np.expand_dims(raw_observation.flatten(), axis=0)

    reward_sum = 0
    step_count = 0
    finished_tasks = 0
    terminated_count = 0
    truncated_count = 0

    first_observation, _ = env.reset()
    observation = _batched(first_observation)

    while finished_tasks < num_tasks and step_count < max_steps:
        action = model.predict(observation)[0]
        raw_observation, reward, terminated, truncated, info = env.step(action)
        observation = _batched(raw_observation)
        step_count += 1
        reward_sum += reward

        # Nothing else to do while the episode is still running.
        if not (terminated or truncated):
            continue

        if verbose:
            print(f"Tasks done: {finished_tasks+1}/{num_tasks}", end="\r")
        if terminated:
            terminated_count += 1
        if truncated:
            truncated_count += 1
        finished_tasks += 1
        if finished_tasks == num_tasks:
            break
        fresh_observation, _ = env.reset()
        observation = _batched(fresh_observation)

    if verbose:
        print("\nTotal reward: ", reward_sum,
              "\nTotal steps: ", step_count,
              "\nTasks done: ", finished_tasks,
              "\nTerminated: ", terminated_count,
              "\nTruncated: ", truncated_count)
    return reward_sum, step_count, finished_tasks, terminated_count, truncated_count
def get_training_and_test_performance_over_time(log_path, eval_every_n, max_steps, train_config_path,
                                                reachable_test_config_path, unreachable_test_config_path,
                                                num_tasks=40, max_eval_steps=1000000):
    """Evaluate every saved checkpoint on the train and both test task sets.

    Checkpoints are read from ``{log_path}/model_{i}.d3`` for
    ``i = eval_every_n, 2 * eval_every_n, ...`` strictly below ``max_steps``.

    Args:
        log_path: Directory containing the ``model_{i}.d3`` checkpoints.
        eval_every_n: Step distance between consecutive checkpoints.
        max_steps: Exclusive upper bound on the checkpoint step.
        train_config_path: Config file of the training tasks.
        reachable_test_config_path: Config file of the reachable test tasks.
        unreachable_test_config_path: Config file of the unreachable test tasks.
        num_tasks: Number of episodes per evaluation (previously fixed at 40).
        max_eval_steps: Step budget per evaluation (previously fixed at 1000000).

    Returns:
        Three dicts (train, reachable, unreachable) mapping ``str(i)`` to the
        tuple returned by ``evaluate_on_environment``:
        ``(total_reward, steps_done, tasks_done, num_terminated, num_truncated)``.
    """
    def _make_env(config):
        # Build a wrapped FourRooms environment for the tasks in `config`.
        return gym_wrapper(gym.make('MiniGrid-FourRooms-v1',
                                    agent_pos=config['agent positions'],
                                    goal_pos=config['goal positions'],
                                    doors_pos=config['topologies'],
                                    agent_dir=config['agent directions']))

    # The configs do not change between checkpoints, so read them once
    # instead of re-reading three files for every checkpoint.
    configs = (get_config(train_config_path),
               get_config(reachable_test_config_path),
               get_config(unreachable_test_config_path))
    performances_train = {}
    performances_reachable = {}
    performances_unreachable = {}
    results_per_config = (performances_train, performances_reachable, performances_unreachable)

    # NOTE(review): range() excludes max_steps itself, so a checkpoint saved at
    # exactly max_steps is never evaluated -- confirm this is intended.
    for i in range(eval_every_n, max_steps, eval_every_n):
        model = d3rlpy.load_learnable(f"{log_path}/model_{i}.d3")
        for config, results in zip(configs, results_per_config):
            # A new environment per checkpoint, as before; presumably this
            # restarts the task sequence from the first task (TODO confirm).
            env = _make_env(config)
            try:
                results[str(i)] = evaluate_on_environment(model, env, max_eval_steps, num_tasks, verbose=0)
            finally:
                # Release the environment; previously three were leaked per checkpoint.
                env.close()
        print(f"Progress: {i}/{max_steps}", end="\r")
    return performances_train, performances_reachable, performances_unreachable
def plot_performances(performances_train, performances_reachable, performances_unreachable):
    """Plot total reward against training step for the three task sets.

    Each argument maps a training step (a key convertible with ``int``) to a
    sequence whose first element is the total reward at that step. One line
    per task set is drawn on a shared axis and the figure is shown.
    """
    labelled_results = (
        ('Train', performances_train),
        ('Reachable', performances_reachable),
        ('Unreachable', performances_unreachable),
    )

    # Extract every curve first, then draw, so nothing is plotted if any of
    # the inputs turns out to be malformed.
    curves = []
    for label, performances in labelled_results:
        steps = [int(step) for step in performances.keys()]
        rewards = [scores[0] for scores in performances.values()]
        curves.append((steps, rewards, label))

    for steps, rewards, label in curves:
        plt.plot(steps, rewards, label=label)

    plt.xlabel('Training Steps')
    plt.ylabel('Total Reward')
    plt.title('Total Reward over Training Steps')
    plt.legend()
    plt.grid(True)
    plt.show()
def train_with_hyperparameters(learning_rate, batch_size, discount_factor, alpha, verbose=0):
    """Train a DiscreteCQL model with the given hyperparameters.

    Trains on the module-level ``train_dataset`` for 20000 steps with 500
    steps per epoch.

    Args:
        learning_rate: Passed to ``DiscreteCQLConfig`` as ``learning_rate``.
        batch_size: Passed to ``DiscreteCQLConfig`` as ``batch_size``.
        discount_factor: Passed to ``DiscreteCQLConfig`` as ``gamma``.
        alpha: Passed to ``DiscreteCQLConfig`` as ``alpha``.
        verbose: Currently unused; kept so existing callers keep working.

    Returns:
        The trained algorithm object.
    """
    # Define the CQL algorithm with suggested hyperparameters
    cql_config = DiscreteCQLConfig(
        learning_rate=learning_rate,
        batch_size=batch_size,
        gamma=discount_factor,
        alpha=alpha
    )
    cql = cql_config.create()

    # Train the model
    cql.fit(train_dataset,
            n_steps=20000,
            n_steps_per_epoch=500)

    return cql


def hyperparameter_tuning(learning_rates, batch_sizes, discount_factors, alphas, hyperparameters_scores=None, verbose=0):
    """Grid-search the CQL hyperparameters and score each combination.

    Every combination is trained with ``train_with_hyperparameters`` and
    evaluated on a fresh training environment built from the module-level
    ``train_config``. Scores are written into ``hyperparameters_scores`` as
    they are obtained, so a caller that passes its own dict keeps the results
    gathered so far if the run is interrupted.

    Args:
        learning_rates: Sized collection of learning rates to try.
        batch_sizes: Sized collection of batch sizes to try.
        discount_factors: Sized collection of discount factors to try.
        alphas: Sized collection of alpha values to try.
        hyperparameters_scores: Optional dict to fill in place; a new dict is
            created when omitted.
        verbose: When truthy, print progress messages.

    Returns:
        The scores dict, mapping ``(learning_rate, batch_size,
        discount_factor, alpha)`` to ``(total_reward, steps_done, tasks_done,
        num_terminated, num_truncated)``.
    """
    from itertools import product

    if hyperparameters_scores is None:
        hyperparameters_scores = {}

    total_combinations = len(learning_rates) * len(batch_sizes) * len(discount_factors) * len(alphas)

    if verbose:
        print("Starting hyperparameter tuning...")
    for learning_rate, batch_size, discount_factor, alpha in product(
            learning_rates, batch_sizes, discount_factors, alphas):
        train_env = gym_wrapper(gym.make('MiniGrid-FourRooms-v1',
                                         agent_pos=train_config['agent positions'],
                                         goal_pos=train_config['goal positions'],
                                         doors_pos=train_config['topologies'],
                                         agent_dir=train_config['agent directions']))
        try:
            cql = train_with_hyperparameters(learning_rate, batch_size, discount_factor, alpha)
            scores = evaluate_on_environment(cql, train_env, 1000000, 40, verbose)
        finally:
            # One environment is created per combination; release it again.
            train_env.close()
        hyperparameters_scores[(learning_rate, batch_size, discount_factor, alpha)] = scores
        if verbose:
            print(f"Progress: {len(hyperparameters_scores)}/{total_combinations}", end="\r")
    if verbose:
        print("\nHyperparameter tuning done.")

    if hyperparameters_scores:
        # max() keeps the first entry among ties, matching the previous
        # strict ">" comparison.
        best_hyperparameters = max(hyperparameters_scores,
                                   key=lambda params: hyperparameters_scores[params][0])
        print(f"Best hyperparameters (based on total reward): {best_hyperparameters}\n",
              f"Scores: {hyperparameters_scores[best_hyperparameters]}")
    else:
        # Nothing was scored (empty grid and no prior scores); previously this
        # crashed with KeyError(None) while formatting the summary.
        print("No hyperparameter scores available.")

    return hyperparameters_scores
When the learning rate is 1e-04 it achieves optimal performance (same as expert policy) on the training set regardless of the other hyperparameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hyperparameters_scores = {\n", + " (1e-06, 64, 0.98, 0.1): (0, 4000, 40, 0, 40),\n", + " (1e-06, 64, 0.98, 1.0): (0, 4000, 40, 0, 40), (1e-06, 64, 0.98, 10.0): (0, 4000, 40, 0, 40), (1e-06, 64, 0.99, 0.1): (0, 4000, 40, 0, 40), (1e-06, 64, 0.99, 1.0): (0, 4000, 40, 0, 40), (1e-06, 64, 0.99, 10.0): (0, 4000, 40, 0, 40), (1e-06, 64, 0.999, 0.1): (0, 4000, 40, 0, 40), (1e-06, 64, 0.999, 1.0): (0, 4000, 40, 0, 40), (1e-06, 64, 0.999, 10.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.98, 0.1): (1, 3906, 40, 1, 39), (1e-06, 128, 0.98, 1.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.98, 10.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.99, 0.1): (1, 3906, 40, 1, 39), (1e-06, 128, 0.99, 1.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.99, 10.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.999, 0.1): (1, 3906, 40, 1, 39), (1e-06, 128, 0.999, 1.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.999, 10.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.98, 0.1): (0, 4000, 40, 0, 40), (1e-06, 256, 0.98, 1.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.98, 10.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.99, 0.1): (0, 4000, 40, 0, 40), (1e-06, 256, 0.99, 1.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.99, 10.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.999, 0.1): (0, 4000, 40, 0, 40), (1e-06, 256, 0.999, 1.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.999, 10.0): (0, 4000, 40, 0, 40), (1e-05, 64, 0.98, 0.1): (35, 838, 40, 35, 5), (1e-05, 64, 0.98, 1.0): (40, 372, 40, 40, 0), (1e-05, 64, 0.98, 10.0): (40, 372, 40, 40, 0), (1e-05, 64, 0.99, 0.1): (33, 1008, 40, 33, 7), (1e-05, 64, 0.99, 1.0): (40, 372, 40, 40, 0), (1e-05, 64, 0.99, 10.0): (40, 372, 40, 40, 0), (1e-05, 64, 0.999, 0.1): (38, 556, 40, 38, 2), (1e-05, 64, 0.999, 1.0): (40, 372, 40, 40, 0), (1e-05, 64, 0.999, 10.0): (40, 372, 40, 
40, 0), (1e-05, 128, 0.98, 0.1): (40, 372, 40, 40, 0), (1e-05, 128, 0.98, 1.0): (40, 372, 40, 40, 0), (1e-05, 128, 0.98, 10.0): (40, 372, 40, 40, 0), (1e-05, 128, 0.99, 0.1): (39, 460, 40, 39, 1), (1e-05, 128, 0.99, 1.0): (40, 372, 40, 40, 0), (1e-05, 128, 0.99, 10.0): (40, 372, 40, 40, 0), (1e-05, 128, 0.999, 0.1): (39, 460, 40, 39, 1), (1e-05, 128, 0.999, 1.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.98, 0.1): (40, 372, 40, 40, 0), (1e-05, 256, 0.98, 1.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.98, 10.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.99, 0.1): (40, 372, 40, 40, 0), (1e-05, 256, 0.99, 1.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.99, 10.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.999, 0.1): (40, 372, 40, 40, 0), (1e-05, 256, 0.999, 1.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.999, 10.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.98, 0.1): (40, 372, 40, 40, 0), (0.0001, 64, 0.98, 1.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.98, 10.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.99, 0.1): (40, 372, 40, 40, 0), (0.0001, 64, 0.99, 1.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.99, 10.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.999, 0.1): (40, 372, 40, 40, 0), (0.0001, 64, 0.999, 1.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.999, 10.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.98, 0.1): (40, 372, 40, 40, 0), (0.0001, 128, 0.98, 1.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.98, 10.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.99, 0.1): (40, 372, 40, 40, 0), (0.0001, 128, 0.99, 1.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.99, 10.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.999, 0.1): (40, 372, 40, 40, 0), (0.0001, 128, 0.999, 1.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.999, 10.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.98, 0.1): (40, 372, 40, 40, 0), (0.0001, 256, 0.98, 1.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.98, 10.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.99, 0.1): (40, 372, 40, 40, 0), (0.0001, 256, 0.99, 1.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.99, 10.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.999, 
0.1): (40, 372, 40, 40, 0), (0.0001, 256, 0.999, 1.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.999, 10.0): (40, 372, 40, 40, 0)}\n", + "optimal_params = []\n", + "\n", + "learning_rate_e6 = 0\n", + "learning_rate_e5 = 0\n", + "learning_rate_e4 = 0\n", + "\n", + "batch_size_64 = 0\n", + "batch_size_128 = 0\n", + "batch_size_256 = 0\n", + "\n", + "discount_factor_98 = 0\n", + "discount_factor_99 = 0\n", + "discount_factor_999 = 0\n", + "\n", + "alpha_01 = 0\n", + "alpha_1 = 0\n", + "alpha_10 = 0\n", + "\n", + "for a,b,c,d in hyperparameters_scores:\n", + " e,f,g,h,i = hyperparameters_scores[(a,b,c,d)]\n", + " least_steps = 372\n", + " # if not a == 1e-04:\n", + " # continue\n", + " if e == 40 and f == 372:\n", + " if a == 1e-04: learning_rate_e4 += 1\n", + " if a == 1e-05: learning_rate_e5 += 1\n", + " if a == 1e-06: learning_rate_e6 += 1\n", + "\n", + " if b == 64: batch_size_64 += 1\n", + " if b == 128: batch_size_128 += 1\n", + " if b == 256: batch_size_256 += 1\n", + "\n", + " if c == 0.98: discount_factor_98 += 1\n", + " if c == 0.99: discount_factor_99 += 1\n", + " if c == 0.999: discount_factor_999 += 1\n", + "\n", + " if d == 0.1: alpha_01 += 1\n", + " if d == 1.0: alpha_1 += 1\n", + " if d == 10.0: alpha_10 += 1\n", + "\n", + " optimal_params.append((a,b,c,d))\n", + "print(\"learning_rate_e6: \", learning_rate_e6)\n", + "print(\"learning_rate_e5: \", learning_rate_e5)\n", + "print(\"learning_rate_e4: \", learning_rate_e4)\n", + "\n", + "print(\"batch_size_64: \", batch_size_64)\n", + "print(\"batch_size_128: \", batch_size_128)\n", + "print(\"batch_size_256: \", batch_size_256)\n", + "\n", + "print(\"discount_factor_98: \", discount_factor_98)\n", + "print(\"discount_factor_99: \", discount_factor_99)\n", + "print(\"discount_factor_999: \", discount_factor_999)\n", + "\n", + "print(\"alpha_01: \", alpha_01)\n", + "print(\"alpha_1: \", alpha_1)\n", + "print(\"alpha_10: \", alpha_10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": 
{}, + "outputs": [], + "source": [ + "for a,b,c,d in hyperparameters_scores:\n", + " e,f,g,h,i = hyperparameters_scores[(a,b,c,d)]\n", + " # total_reward, steps_done, tasks_done, num_terminated, num_truncated\n", + " # print(a, b, c, d, \"\\t\", e, f, h, i)\n", + "print(hyperparameters_scores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = get_expert_dataset_from_config(train_config)\n", + "train_dataset = d3rlpy.dataset.MDPDataset(\n", + " observations=train_dataset.get(\"observations\"),\n", + " actions=train_dataset.get(\"actions\"),\n", + " rewards=train_dataset.get(\"rewards\"),\n", + " terminals=train_dataset.get(\"terminals\"),\n", + ")\n", + "train_dataset.dataset_info" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Experiments on mixed datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_config = get_config(train_config_path)\n", + "train_dataset_expert = get_expert_dataset_from_config(train_config)\n", + "train_dataset_expert = d3rlpy.dataset.MDPDataset(\n", + " observations=train_dataset_expert.get(\"observations\"),\n", + " actions=train_dataset_expert.get(\"actions\"),\n", + " rewards=train_dataset_expert.get(\"rewards\"),\n", + " terminals=train_dataset_expert.get(\"terminals\"),\n", + ")\n", + "\n", + "dataset_path = \"../four_room_extensions/datasets/dataset_from_models_300000_350000_390000_450000_470000.pkl\"\n", + "with open(dataset_path, \"rb\") as f:\n", + " train_dataset_mixed = pickle.load(f)\n", + "train_dataset_mixed = d3rlpy.dataset.MDPDataset(\n", + " observations=train_dataset_mixed['observations'],\n", + " actions=train_dataset_mixed['actions'],\n", + " rewards=train_dataset_mixed['rewards'],\n", + " terminals=train_dataset_mixed['terminals'],\n", + ")\n", + "\n", + "# setup algorithm\n", + "cql = 
d3rlpy.algos.DiscreteCQLConfig(learning_rate=1e-4).create()\n", + "\n", + "# start offline training\n", + "cql.fit(train_dataset_expert,\n", + " n_steps=20000,\n", + " n_steps_per_epoch=500,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "log_path = \"d3rlpy_logs/DiscreteCQL_20240604150844\"\n", + "eval_every_n = 500\n", + "max_steps = 20000\n", + "performances_train, performances_reachable, performances_unreachable = get_training_and_test_performance_over_time(log_path, eval_every_n, max_steps, train_config_path, reachable_test_config_path, unreachable_test_config_path)\n", + "plot_performances(performances_train, performances_reachable, performances_unreachable)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "log_path = \"d3rlpy_logs/DiscreteCQL_20240604162814\"\n", + "eval_every_n = 500\n", + "max_steps = 20000\n", + "performances_train, performances_reachable, performances_unreachable = get_training_and_test_performance_over_time(log_path, eval_every_n, max_steps, train_config_path, reachable_test_config_path, unreachable_test_config_path)\n", + "plot_performances(performances_train, performances_reachable, performances_unreachable)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/cql/gen_dataset.py b/cql/gen_dataset.py new file mode 100644 index 0000000..a52ee1a --- /dev/null +++ b/cql/gen_dataset.py @@ -0,0 +1,91 @@ +import gymnasium as gym +import dill +import numpy as np +import imageio +from typing import Dict, Any +from 
def get_config(path):
    """Load a serialized task configuration from ``path``.

    Args:
        path: Location of a file that ``dill`` can deserialize.

    Returns:
        Whatever object was stored in the file. This script indexes it with
        the keys 'agent positions', 'goal positions', 'topologies' and
        'agent directions'.
    """
    # NOTE(review): dill.load executes arbitrary code embedded in the file;
    # only use it on configuration files from a trusted source.
    # The context manager closes the file, so no explicit close() is needed.
    with open(path, 'rb') as file:
        return dill.load(file)
def get_mixed_dataset_from_config(config, models=(300000, 350000, 390000, 450000, 470000), render=False, render_name="") -> tuple[Dict[str, Any], gym.Env]:
    """Generate a dataset by running several pretrained policies on the tasks in ``config``.

    For every entry of ``models`` the checkpoint
    ``four_room_extensions/DQN_models/DQN_{m}.zip`` is loaded and run for
    ``len(config["topologies"])`` episodes; every transition is recorded. The
    dataset size therefore depends on the number of tasks and on the quality
    of the policies.

    Args:
        config: Mapping with the keys 'agent positions', 'goal positions',
            'topologies' and 'agent directions'.
        models: Checkpoint identifiers used to build the model file names.
        render: When true, capture a frame before every step and save all
            frames as ``rendered_episodes/{render_name}.gif``.
        render_name: Base name of the GIF; defaults to
            ``rendered_episode_`` followed by the model ids joined with '_'.

    Returns:
        Tuple ``(dataset, env)``. ``dataset`` is a dict of numpy arrays with
        the keys 'observations', 'next_observations', 'actions', 'rewards',
        'terminals', 'timeouts' and 'infos'. ``env`` is the environment that
        was used; it is left open for the caller.
    """
    # NOTE(review): the module already registers this id at import time, so
    # this call looks redundant -- confirm before removing it.
    gym.register('MiniGrid-FourRooms-v1', FourRoomsEnv)
    env = gym_wrapper(gym.make('MiniGrid-FourRooms-v1',
                               agent_pos=config['agent positions'],
                               goal_pos=config['goal positions'],
                               doors_pos=config['topologies'],
                               agent_dir=config['agent directions'],
                               render_mode="rgb_array"))

    dataset = {'observations': [], 'next_observations': [], 'actions': [], 'rewards': [],
               'terminals': [], 'timeouts': [], 'infos': []}

    imgs = []
    for idx, m in enumerate(models):
        model = DQN.load(f"four_room_extensions/DQN_models/DQN_{m}.zip")
        for _ in range(len(config["topologies"])):
            observation, _ = env.reset()
            done = False
            while not done:
                if render:
                    imgs.append(env.render())

                action, _ = model.predict(observation)

                last_observation = observation
                observation, reward, terminated, truncated, info = env.step(action)

                dataset['observations'].append(np.array(last_observation).flatten())
                dataset['next_observations'].append(np.array(observation).flatten())
                dataset['actions'].append(np.array([action]))
                dataset['rewards'].append(reward)
                dataset['terminals'].append(terminated)
                dataset['timeouts'].append(truncated)
                dataset['infos'].append(info)

                done = terminated or truncated
        print(f"progress: {idx+1}/{len(models)}")

    for key in dataset:
        dataset[key] = np.array(dataset[key])

    if render:
        render_name_extension = '_'.join(map(str, models))
        render_name = f"{render_name}" if render_name else f'rendered_episode_{render_name_extension}'
        # Create the output directory up front so saving does not fail on a
        # fresh checkout.
        os.makedirs('rendered_episodes', exist_ok=True)
        imageio.mimsave(f'rendered_episodes/{render_name}.gif', [np.array(img) for img in imgs], duration=200)

    return dataset, env
"four_room_extensions/datasets/dataset_from_models_" + '_'.join(map(str, models)) + ".pkl" +with open(dataset_file_name, 'wb') as f: + pickle.dump(dataset, f) \ No newline at end of file diff --git a/cql/hyperparameter_tuning_attained_scores.txt b/cql/hyperparameter_tuning_attained_scores.txt new file mode 100644 index 0000000..6dee7f3 --- /dev/null +++ b/cql/hyperparameter_tuning_attained_scores.txt @@ -0,0 +1 @@ +{(1e-06, 64, 0.98, 0.1): (0, 4000, 40, 0, 40), (1e-06, 64, 0.98, 1.0): (0, 4000, 40, 0, 40), (1e-06, 64, 0.98, 10.0): (0, 4000, 40, 0, 40), (1e-06, 64, 0.99, 0.1): (0, 4000, 40, 0, 40), (1e-06, 64, 0.99, 1.0): (0, 4000, 40, 0, 40), (1e-06, 64, 0.99, 10.0): (0, 4000, 40, 0, 40), (1e-06, 64, 0.999, 0.1): (0, 4000, 40, 0, 40), (1e-06, 64, 0.999, 1.0): (0, 4000, 40, 0, 40), (1e-06, 64, 0.999, 10.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.98, 0.1): (1, 3906, 40, 1, 39), (1e-06, 128, 0.98, 1.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.98, 10.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.99, 0.1): (1, 3906, 40, 1, 39), (1e-06, 128, 0.99, 1.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.99, 10.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.999, 0.1): (1, 3906, 40, 1, 39), (1e-06, 128, 0.999, 1.0): (0, 4000, 40, 0, 40), (1e-06, 128, 0.999, 10.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.98, 0.1): (0, 4000, 40, 0, 40), (1e-06, 256, 0.98, 1.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.98, 10.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.99, 0.1): (0, 4000, 40, 0, 40), (1e-06, 256, 0.99, 1.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.99, 10.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.999, 0.1): (0, 4000, 40, 0, 40), (1e-06, 256, 0.999, 1.0): (0, 4000, 40, 0, 40), (1e-06, 256, 0.999, 10.0): (0, 4000, 40, 0, 40), (1e-05, 64, 0.98, 0.1): (35, 838, 40, 35, 5), (1e-05, 64, 0.98, 1.0): (40, 372, 40, 40, 0), (1e-05, 64, 0.98, 10.0): (40, 372, 40, 40, 0), (1e-05, 64, 0.99, 0.1): (33, 1008, 40, 33, 7), (1e-05, 64, 0.99, 1.0): (40, 372, 40, 40, 0), (1e-05, 64, 0.99, 10.0): (40, 372, 40, 40, 0), (1e-05, 64, 0.999, 
0.1): (38, 556, 40, 38, 2), (1e-05, 64, 0.999, 1.0): (40, 372, 40, 40, 0), (1e-05, 64, 0.999, 10.0): (40, 372, 40, 40, 0), (1e-05, 128, 0.98, 0.1): (40, 372, 40, 40, 0), (1e-05, 128, 0.98, 1.0): (40, 372, 40, 40, 0), (1e-05, 128, 0.98, 10.0): (40, 372, 40, 40, 0), (1e-05, 128, 0.99, 0.1): (39, 460, 40, 39, 1), (1e-05, 128, 0.99, 1.0): (40, 372, 40, 40, 0), (1e-05, 128, 0.99, 10.0): (40, 372, 40, 40, 0), (1e-05, 128, 0.999, 0.1): (39, 460, 40, 39, 1), (1e-05, 128, 0.999, 1.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.98, 0.1): (40, 372, 40, 40, 0), (1e-05, 256, 0.98, 1.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.98, 10.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.99, 0.1): (40, 372, 40, 40, 0), (1e-05, 256, 0.99, 1.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.99, 10.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.999, 0.1): (40, 372, 40, 40, 0), (1e-05, 256, 0.999, 1.0): (40, 372, 40, 40, 0), (1e-05, 256, 0.999, 10.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.98, 0.1): (40, 372, 40, 40, 0), (0.0001, 64, 0.98, 1.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.98, 10.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.99, 0.1): (40, 372, 40, 40, 0), (0.0001, 64, 0.99, 1.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.99, 10.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.999, 0.1): (40, 372, 40, 40, 0), (0.0001, 64, 0.999, 1.0): (40, 372, 40, 40, 0), (0.0001, 64, 0.999, 10.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.98, 0.1): (40, 372, 40, 40, 0), (0.0001, 128, 0.98, 1.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.98, 10.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.99, 0.1): (40, 372, 40, 40, 0), (0.0001, 128, 0.99, 1.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.99, 10.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.999, 0.1): (40, 372, 40, 40, 0), (0.0001, 128, 0.999, 1.0): (40, 372, 40, 40, 0), (0.0001, 128, 0.999, 10.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.98, 0.1): (40, 372, 40, 40, 0), (0.0001, 256, 0.98, 1.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.98, 10.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.99, 0.1): (40, 372, 40, 40, 0), 
(0.0001, 256, 0.99, 1.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.99, 10.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.999, 0.1): (40, 372, 40, 40, 0), (0.0001, 256, 0.999, 1.0): (40, 372, 40, 40, 0), (0.0001, 256, 0.999, 10.0): (40, 372, 40, 40, 0)} \ No newline at end of file diff --git a/four_room_extensions/DQN_models/DQN_10000.zip b/four_room_extensions/DQN_models/DQN_10000.zip new file mode 100644 index 0000000..88301b7 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_10000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_100000.zip b/four_room_extensions/DQN_models/DQN_100000.zip new file mode 100644 index 0000000..7bcafea Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_100000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_110000.zip b/four_room_extensions/DQN_models/DQN_110000.zip new file mode 100644 index 0000000..ae2ab21 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_110000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_120000.zip b/four_room_extensions/DQN_models/DQN_120000.zip new file mode 100644 index 0000000..56b30cd Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_120000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_130000.zip b/four_room_extensions/DQN_models/DQN_130000.zip new file mode 100644 index 0000000..d24eaee Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_130000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_140000.zip b/four_room_extensions/DQN_models/DQN_140000.zip new file mode 100644 index 0000000..29838a3 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_140000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_150000.zip b/four_room_extensions/DQN_models/DQN_150000.zip new file mode 100644 index 0000000..322e4e9 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_150000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_160000.zip 
b/four_room_extensions/DQN_models/DQN_160000.zip new file mode 100644 index 0000000..79a65e3 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_160000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_170000.zip b/four_room_extensions/DQN_models/DQN_170000.zip new file mode 100644 index 0000000..96fd7d6 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_170000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_180000.zip b/four_room_extensions/DQN_models/DQN_180000.zip new file mode 100644 index 0000000..f1535ea Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_180000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_190000.zip b/four_room_extensions/DQN_models/DQN_190000.zip new file mode 100644 index 0000000..d5f8a63 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_190000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_20000.zip b/four_room_extensions/DQN_models/DQN_20000.zip new file mode 100644 index 0000000..c75ddf0 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_20000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_200000.zip b/four_room_extensions/DQN_models/DQN_200000.zip new file mode 100644 index 0000000..7e0297d Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_200000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_210000.zip b/four_room_extensions/DQN_models/DQN_210000.zip new file mode 100644 index 0000000..8a436ae Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_210000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_220000.zip b/four_room_extensions/DQN_models/DQN_220000.zip new file mode 100644 index 0000000..742ee5f Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_220000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_230000.zip b/four_room_extensions/DQN_models/DQN_230000.zip new file mode 100644 index 0000000..65fcb67 Binary files 
/dev/null and b/four_room_extensions/DQN_models/DQN_230000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_240000.zip b/four_room_extensions/DQN_models/DQN_240000.zip new file mode 100644 index 0000000..503b41d Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_240000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_250000.zip b/four_room_extensions/DQN_models/DQN_250000.zip new file mode 100644 index 0000000..384c1a8 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_250000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_260000.zip b/four_room_extensions/DQN_models/DQN_260000.zip new file mode 100644 index 0000000..236a085 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_260000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_270000.zip b/four_room_extensions/DQN_models/DQN_270000.zip new file mode 100644 index 0000000..7d2e56d Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_270000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_280000.zip b/four_room_extensions/DQN_models/DQN_280000.zip new file mode 100644 index 0000000..d04a6ce Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_280000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_290000.zip b/four_room_extensions/DQN_models/DQN_290000.zip new file mode 100644 index 0000000..f509d98 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_290000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_30000.zip b/four_room_extensions/DQN_models/DQN_30000.zip new file mode 100644 index 0000000..7200652 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_30000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_300000.zip b/four_room_extensions/DQN_models/DQN_300000.zip new file mode 100644 index 0000000..1415fa4 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_300000.zip differ diff --git 
a/four_room_extensions/DQN_models/DQN_310000.zip b/four_room_extensions/DQN_models/DQN_310000.zip new file mode 100644 index 0000000..dea12db Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_310000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_320000.zip b/four_room_extensions/DQN_models/DQN_320000.zip new file mode 100644 index 0000000..566f163 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_320000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_330000.zip b/four_room_extensions/DQN_models/DQN_330000.zip new file mode 100644 index 0000000..3a0fef1 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_330000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_340000.zip b/four_room_extensions/DQN_models/DQN_340000.zip new file mode 100644 index 0000000..039bd13 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_340000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_350000.zip b/four_room_extensions/DQN_models/DQN_350000.zip new file mode 100644 index 0000000..b619713 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_350000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_360000.zip b/four_room_extensions/DQN_models/DQN_360000.zip new file mode 100644 index 0000000..93b80d8 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_360000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_370000.zip b/four_room_extensions/DQN_models/DQN_370000.zip new file mode 100644 index 0000000..63d5878 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_370000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_380000.zip b/four_room_extensions/DQN_models/DQN_380000.zip new file mode 100644 index 0000000..6e9d2f1 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_380000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_390000.zip b/four_room_extensions/DQN_models/DQN_390000.zip new file mode 
100644 index 0000000..6c31102 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_390000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_40000.zip b/four_room_extensions/DQN_models/DQN_40000.zip new file mode 100644 index 0000000..2992590 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_40000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_400000.zip b/four_room_extensions/DQN_models/DQN_400000.zip new file mode 100644 index 0000000..14ac80d Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_400000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_410000.zip b/four_room_extensions/DQN_models/DQN_410000.zip new file mode 100644 index 0000000..4f346a0 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_410000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_420000.zip b/four_room_extensions/DQN_models/DQN_420000.zip new file mode 100644 index 0000000..badf24c Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_420000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_430000.zip b/four_room_extensions/DQN_models/DQN_430000.zip new file mode 100644 index 0000000..6cb4191 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_430000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_440000.zip b/four_room_extensions/DQN_models/DQN_440000.zip new file mode 100644 index 0000000..31e0bfa Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_440000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_450000.zip b/four_room_extensions/DQN_models/DQN_450000.zip new file mode 100644 index 0000000..cdad46f Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_450000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_460000.zip b/four_room_extensions/DQN_models/DQN_460000.zip new file mode 100644 index 0000000..361c971 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_460000.zip differ 
diff --git a/four_room_extensions/DQN_models/DQN_470000.zip b/four_room_extensions/DQN_models/DQN_470000.zip new file mode 100644 index 0000000..af0fbb2 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_470000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_480000.zip b/four_room_extensions/DQN_models/DQN_480000.zip new file mode 100644 index 0000000..7d2743a Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_480000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_490000.zip b/four_room_extensions/DQN_models/DQN_490000.zip new file mode 100644 index 0000000..934de04 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_490000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_50000.zip b/four_room_extensions/DQN_models/DQN_50000.zip new file mode 100644 index 0000000..64f930e Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_50000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_500000.zip b/four_room_extensions/DQN_models/DQN_500000.zip new file mode 100644 index 0000000..3ee6683 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_500000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_60000.zip b/four_room_extensions/DQN_models/DQN_60000.zip new file mode 100644 index 0000000..2c9ce34 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_60000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_70000.zip b/four_room_extensions/DQN_models/DQN_70000.zip new file mode 100644 index 0000000..9971b88 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_70000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_80000.zip b/four_room_extensions/DQN_models/DQN_80000.zip new file mode 100644 index 0000000..8fd3ebf Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_80000.zip differ diff --git a/four_room_extensions/DQN_models/DQN_90000.zip b/four_room_extensions/DQN_models/DQN_90000.zip new file mode 
100644 index 0000000..d40c9f2 Binary files /dev/null and b/four_room_extensions/DQN_models/DQN_90000.zip differ diff --git a/four_room_extensions/DQN_models/performance_per_model.txt b/four_room_extensions/DQN_models/performance_per_model.txt new file mode 100644 index 0000000..924d4a5 --- /dev/null +++ b/four_room_extensions/DQN_models/performance_per_model.txt @@ -0,0 +1,50 @@ +10_000 - Episode length: 97.65, Std Dev: 14.68 +20_000 - Episode length: 92.90, Std Dev: 24.94 +30_000 - Episode length: 85.62, Std Dev: 34.22 +40_000 - Episode length: 73.97, Std Dev: 42.27 +50_000 - Episode length: 65.05, Std Dev: 45.15 +60_000 - Episode length: 74.03, Std Dev: 42.19 +70_000 - Episode length: 74.17, Std Dev: 41.95 +80_000 - Episode length: 79.22, Std Dev: 38.58 +90_000 - Episode length: 88.47, Std Dev: 30.51 +100_000 - Episode length: 79.30, Std Dev: 38.44 +110_000 - Episode length: 75.10, Std Dev: 40.46 +120_000 - Episode length: 83.88, Std Dev: 35.05 +130_000 - Episode length: 70.42, Std Dev: 42.67 +140_000 - Episode length: 86.40, Std Dev: 32.41 +150_000 - Episode length: 86.35, Std Dev: 32.52 +160_000 - Episode length: 88.40, Std Dev: 30.72 +170_000 - Episode length: 81.75, Std Dev: 36.53 +180_000 - Episode length: 88.58, Std Dev: 30.26 +190_000 - Episode length: 86.50, Std Dev: 32.17 +200_000 - Episode length: 86.50, Std Dev: 32.16 +210_000 - Episode length: 82.20, Std Dev: 35.62 +220_000 - Episode length: 79.33, Std Dev: 38.40 +230_000 - Episode length: 88.50, Std Dev: 30.45 +240_000 - Episode length: 79.58, Std Dev: 37.95 +250_000 - Episode length: 88.50, Std Dev: 30.45 +260_000 - Episode length: 70.05, Std Dev: 43.20 +270_000 - Episode length: 74.80, Std Dev: 40.95 +280_000 - Episode length: 63.33, Std Dev: 44.97 +290_000 - Episode length: 56.73, Std Dev: 45.57 +300_000 - Episode length: 54.48, Std Dev: 45.60 +310_000 - Episode length: 54.02, Std Dev: 46.02 +320_000 - Episode length: 49.62, Std Dev: 45.63 +330_000 - Episode length: 49.65, Std Dev: 45.61 +340_000 
- Episode length: 36.35, Std Dev: 41.79 +350_000 - Episode length: 38.30, Std Dev: 42.91 +360_000 - Episode length: 33.70, Std Dev: 40.94 +370_000 - Episode length: 31.52, Std Dev: 39.66 +380_000 - Episode length: 29.12, Std Dev: 38.30 +390_000 - Episode length: 26.98, Std Dev: 36.64 +400_000 - Episode length: 22.48, Std Dev: 32.71 +410_000 - Episode length: 22.48, Std Dev: 32.71 +420_000 - Episode length: 18.15, Std Dev: 27.48 +430_000 - Episode length: 20.38, Std Dev: 30.27 +440_000 - Episode length: 18.15, Std Dev: 27.48 +450_000 - Episode length: 13.68, Std Dev: 20.08 +460_000 - Episode length: 15.85, Std Dev: 24.18 +470_000 - Episode length: 9.30, Std Dev: 3.38 +480_000 - Episode length: 9.30, Std Dev: 3.38 +490_000 - Episode length: 9.30, Std Dev: 3.38 +500_000 - Episode length: 9.30, Std Dev: 3.38 \ No newline at end of file diff --git a/four_room_extensions/datasets/dataset_from_models_300000_350000_390000_450000_470000.pkl b/four_room_extensions/datasets/dataset_from_models_300000_350000_390000_450000_470000.pkl new file mode 100644 index 0000000..e03154e Binary files /dev/null and b/four_room_extensions/datasets/dataset_from_models_300000_350000_390000_450000_470000.pkl differ diff --git a/four_room_extensions/fourrooms_dataset_gen.py b/four_room_extensions/fourrooms_dataset_gen.py index c1fce86..689f6c3 100644 --- a/four_room_extensions/fourrooms_dataset_gen.py +++ b/four_room_extensions/fourrooms_dataset_gen.py @@ -1,7 +1,6 @@ import gymnasium as gym from typing import Any, Dict, List, Union import numpy as np - from four_room.env import FourRoomsEnv from four_room.shortest_path import find_all_action_values from four_room.utils import obs_to_state