diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte new file mode 100644 index 00000000..1170b2ca Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz new file mode 100644 index 00000000..5ace8ea9 Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte new file mode 100644 index 00000000..d1c3a970 Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz new file mode 100644 index 00000000..a7e14154 Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte b/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte new file mode 100644 index 00000000..bbce2765 Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte.gz b/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte.gz new file mode 100644 index 00000000..b50e4b6b Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/train-images-idx3-ubyte.gz differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte b/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 00000000..d6b4c5db Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte differ diff --git a/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte.gz b/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte.gz new file mode 100644 index 00000000..707a576b Binary files /dev/null and b/ML/Pytorch/Basics/dataset/MNIST/raw/train-labels-idx1-ubyte.gz differ diff --git a/ML/Pytorch/Basics/pytorch_simple_fullynet.py b/ML/Pytorch/Basics/pytorch_simple_fullynet.py index 36a399d8..d2da0153 100644 --- a/ML/Pytorch/Basics/pytorch_simple_fullynet.py +++ b/ML/Pytorch/Basics/pytorch_simple_fullynet.py @@ -95,6 +95,8 @@ def forward(self, x): for epoch in range(num_epochs): for batch_idx, (data, targets) in enumerate(tqdm(train_loader)): # Get data to cuda if possible + print(data.shape) + print(targets.shape) data = data.to(device=device) targets = targets.to(device=device) @@ -102,9 +104,9 @@ def forward(self, x): data = data.reshape(data.shape[0], -1) # Forward - scores = model(data) + scores = model.forward(data)  # equivalent here, though calling model(data) is generally preferred loss = criterion(scores, targets) - + print(f"Loss at epoch {epoch}, batch {batch_idx}: {loss.item()}") # Backward optimizer.zero_grad() loss.backward() @@ -131,7 +133,7 @@ def check_accuracy(loader, model): num_correct = 0 num_samples = 0 - model.eval() + model.eval()  # evaluation mode; this turns off dropout, etc. # We don't need to keep track of gradients here so we wrap it in torch.no_grad() with torch.no_grad(): diff --git a/ML/Pytorch/Basics/pytorch_tensorbasics.py b/ML/Pytorch/Basics/pytorch_tensorbasics.py index 3686dfa6..c1ac6a27 100644 --- a/ML/Pytorch/Basics/pytorch_tensorbasics.py +++ b/ML/Pytorch/Basics/pytorch_tensorbasics.py @@ -163,11 +163,12 @@ values,
indices = torch.min(x, dim=0) # Can also do x.min(dim=0) abs_x = torch.abs(x) # Returns x where abs function has been applied to every element z = torch.argmax(x, dim=0) # Gets index of the maximum value -z = torch.argmin(x, dim=0) # Gets index of the minimum value +z = torch.argmin(x, dim=0) +print(z)  # Gets index of the minimum value mean_x = torch.mean(x.float(), dim=0) # mean requires x to be float z = torch.eq(x, y) # Element wise comparison, in this case z = [False, False, False] sorted_y, indices = torch.sort(y, dim=0, descending=False) - +print(indices) z = torch.clamp(x, min=0) # All values < 0 set to 0 and values > 0 unchanged (this is exactly ReLU function) # If you want to values over max_val to be clamped, do torch.clamp(x, min=min_val, max=max_val) @@ -207,7 +208,7 @@ rows = torch.tensor([1, 0]) cols = torch.tensor([4, 0]) print(x[rows, cols]) # Gets second row fifth column and first row first column - +# which is the same as doing: [x[1,4], x[0,0]]  (advanced indexing) # More advanced indexing x = torch.arange(10) print(x[(x < 2) | (x > 8)]) # will be [0, 1, 9] @@ -216,7 +217,9 @@ # Useful operations for indexing print( torch.where(x > 5, x, x * 2) -) # gives [0, 2, 4, 6, 8, 10, 6, 7, 8, 9], all values x > 5 yield x, else x*2 +) +# where the condition holds, take the second argument; otherwise take the third +# gives [0, 2, 4, 6, 8, 10, 6, 7, 8, 9], all values x > 5 yield x, else x*2 x = torch.tensor([0, 0, 1, 2, 2, 3, 4]).unique() # x = [0, 1, 2, 3, 4] print( x.ndimension() @@ -231,7 +234,7 @@ # ============================================================= # x = torch.arange(9) - +print(x.shape) # Shape is [9] # Let's say we want to reshape it to be 3x3 x_3x3 = x.view(3, 3) @@ -256,7 +259,7 @@ # using pointers to construct these matrices). This is a bit complicated and I need to explore this more # as well, at least you know it's a problem to be cautious of!
A solution is to do the following print(y.contiguous().view(9)) # Calling .contiguous() before view and it works - +# the underlying memory is strided (non-contiguous), hence the copy # Moving on to another operation, let's say we want to add two tensors dimensions togethor x1 = torch.rand(2, 5) x2 = torch.rand(2, 5) @@ -284,7 +287,7 @@ z = torch.chunk(x, chunks=2, dim=1) print(z[0].shape) print(z[1].shape) - +# chunk splits the tensor into several sub-tensors # Let's say we want to add an additional dimension x = torch.arange( 10 diff --git a/Test/tensor.py b/Test/tensor.py new file mode 100644 index 00000000..eaa05df1 --- /dev/null +++ b/Test/tensor.py @@ -0,0 +1,37 @@ +import torch +import torch.nn as nn + +my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) +x=torch.empty(size=(3,3)).uniform_(0,1) +y=torch.diag(torch.ones(3)) +z=torch.ones(3) +print(x) +print(y) +print(z) +import numpy as np +a = np.array([1, 2, 3]) +b = torch.from_numpy(a)  # convert the numpy array to a tensor +print(a) +print(b) +c=b.numpy()  # convert the tensor back to a numpy array +print(c.dtype) +import torch +x = torch.tensor([1, 2, 3]) +print(torch.diag(x)) +# Output: +# tensor([[1, 0, 0], +# [0, 2, 0], +# [0, 0, 3]]) +A = torch.tensor([[1, 2], [3, 4]]) +print(torch.diag(A)) +# Output: tensor([1, 4]) +p=torch.rand(3, 4) +print(p) +q=torch.eye(4) +print(q) +z=torch.empty(3,4).normal_(mean=0,std=1) +print(z) +j=torch.arange(1,10,2) +print(j) +k=torch.empty(3,4) +print(k) \ No newline at end of file diff --git a/_downloads/c195adbae0504b6504c93e0fd18235ce/mario_rl_tutorial.ipynb b/_downloads/c195adbae0504b6504c93e0fd18235ce/mario_rl_tutorial.ipynb new file mode 100644 index 00000000..5151c183 --- /dev/null +++ b/_downloads/c195adbae0504b6504c93e0fd18235ce/mario_rl_tutorial.ipynb @@ -0,0 +1,1178 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 63, + "metadata": { + "id": "-uf9wy7_D2qF" + }, + "outputs": [], + "source": [ + "# For tips on running notebooks in Google Colab, see\n", + "# https://docs.pytorch.org/tutorials/beginner/colab\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2LYmEix3D2qH" + }, + "source": [ + "Train a Mario-playing RL Agent\n", + "==============================\n", + "\n", + "**Authors:** [Yuansong Feng](https://github.com/YuansongFeng), [Suraj\n", + "Subramanian](https://github.com/suraj813), [Howard\n", + "Wang](https://github.com/hw26), [Steven\n", + "Guo](https://github.com/GuoYuzhang).\n", + "\n", + "This tutorial walks you through the fundamentals of Deep Reinforcement\n", + "Learning. At the end, you will implement an AI-powered Mario (using\n", + "[Double Deep Q-Networks](https://arxiv.org/pdf/1509.06461.pdf)) that can\n", + "play the game by itself.\n", + "\n", + "Although no prior knowledge of RL is necessary for this tutorial, you\n", + "can familiarize yourself with these RL\n", + "[concepts](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html),\n", + "and have this handy\n", + "[cheatsheet](https://colab.research.google.com/drive/1eN33dPVtdPViiS1njTW_-r-IYCDTFU7N)\n", + "as your companion. 
The full code is available\n", + "[here](https://github.com/yuansongFeng/MadMario/).\n", + "\n", + "![](https://pytorch.org/tutorials/_static/img/mario.gif)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gIa9lvU7D2qJ" + }, + "source": [ + "``` {.bash}\n", + "%%bash\n", + "pip install gym-super-mario-bros==7.4.0\n", + "pip install tensordict==0.3.0\n", + "pip install torchrl==0.3.0\n", + "```\n" + ] + }, + { + "cell_type": "code", + "source": [ + "%%bash\n", + "pip install gym==0.26.2 gym-super-mario-bros==7.3.0 nes-py==8.1.0\n", + "pip install tensordict\n", + "pip install torchrl" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aYk3GjmsEmAJ", + "outputId": "ddd2971f-f92f-4fae-f9e9-f20015cb8fff" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: gym==0.25.2 in /usr/local/lib/python3.12/dist-packages (0.25.2)\n", + "Collecting gym-super-mario-bros==7.3.0\n", + " Downloading gym_super_mario_bros-7.3.0-py2.py3-none-any.whl.metadata (9.4 kB)\n", + "Collecting nes-py==8.1.0\n", + " Downloading nes_py-8.1.0.tar.gz (73 kB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 73.1/73.1 kB 3.6 MB/s eta 0:00:00\n", + " Preparing metadata (setup.py): started\n", + " Preparing metadata (setup.py): finished with status 'done'\n", + "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.12/dist-packages (from gym==0.25.2) (2.0.2)\n", + "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from gym==0.25.2) (3.1.1)\n", + "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.12/dist-packages (from gym==0.25.2) (0.1.0)\n", + "Collecting pyglet>=1.3.2 (from nes-py==8.1.0)\n", + " Downloading pyglet-2.1.9-py3-none-any.whl.metadata (7.7 kB)\n", + "Requirement already satisfied: tqdm>=4.19.5 in /usr/local/lib/python3.12/dist-packages (from nes-py==8.1.0) (4.67.1)\n", + "Downloading gym_super_mario_bros-7.3.0-py2.py3-none-any.whl (198 kB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 198.6/198.6 kB 10.6 MB/s eta 0:00:00\n", + "Downloading pyglet-2.1.9-py3-none-any.whl (1.0 MB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.0/1.0 MB 36.1 MB/s eta 0:00:00\n", + "Building wheels for collected packages: nes-py\n", + " Building wheel for nes-py (setup.py): started\n", + " Building wheel for nes-py (setup.py): finished with status 'done'\n", + " Created wheel for nes-py: filename=nes_py-8.1.0-cp312-cp312-linux_x86_64.whl size=504502 sha256=7d20998ab7c44177d003f2459ee1bb2136527fcd190496ab3e06445e79e88fd5\n", + " Stored in directory: /root/.cache/pip/wheels/a7/83/d9/f251e11d21aa7223824a74c79b52f63a3f5175ac20e9bac221\n", + "Successfully built nes-py\n", + "Installing collected packages: pyglet, nes-py, gym-super-mario-bros\n", + "Successfully installed gym-super-mario-bros-7.3.0 nes-py-8.1.0 pyglet-2.1.9\n", + "Collecting tensordict\n", + " Downloading tensordict-0.10.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (9.3 kB)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (from tensordict) (2.8.0+cu126)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from tensordict) (2.0.2)\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.12/dist-packages (from tensordict) (3.1.1)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from tensordict) (25.0)\n", + 
"Requirement already satisfied: importlib_metadata in /usr/local/lib/python3.12/dist-packages (from tensordict) (8.7.0)\n", + "Requirement already satisfied: orjson in /usr/local/lib/python3.12/dist-packages (from tensordict) (3.11.3)\n", + "Collecting pyvers<0.2.0,>=0.1.0 (from tensordict)\n", + " Downloading pyvers-0.1.0-py3-none-any.whl.metadata (5.4 kB)\n", + "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.12/dist-packages (from importlib_metadata->tensordict) (3.23.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (3.20.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (4.15.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (75.2.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (1.13.3)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (3.5)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (3.1.6)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (2025.3.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.80)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (9.10.2.21)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.4.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (11.3.0.4)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (10.3.7.77)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (11.7.1.2)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.5.4.2)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (0.7.1)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (2.27.3)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.77)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (12.6.85)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) (1.11.1.6)\n", + "Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch->tensordict) 
(3.4.0)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch->tensordict) (1.3.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch->tensordict) (3.0.3)\n", + "Downloading tensordict-0.10.0-cp312-cp312-manylinux_2_28_x86_64.whl (449 kB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 450.0/450.0 kB 12.5 MB/s eta 0:00:00\n", + "Downloading pyvers-0.1.0-py3-none-any.whl (10 kB)\n", + "Installing collected packages: pyvers, tensordict\n", + "Successfully installed pyvers-0.1.0 tensordict-0.10.0\n", + "Collecting torchrl\n", + " Downloading torchrl-0.10.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (48 kB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 48.4/48.4 kB 2.8 MB/s eta 0:00:00\n", + "Requirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.12/dist-packages (from torchrl) (2.8.0+cu126)\n", + "Requirement already satisfied: pyvers in /usr/local/lib/python3.12/dist-packages (from torchrl) (0.1.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from torchrl) (2.0.2)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from torchrl) (25.0)\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.12/dist-packages (from torchrl) (3.1.1)\n", + "Requirement already satisfied: tensordict<0.11.0,>=0.10.0 in /usr/local/lib/python3.12/dist-packages (from torchrl) (0.10.0)\n", + "Requirement already satisfied: importlib_metadata in /usr/local/lib/python3.12/dist-packages (from tensordict<0.11.0,>=0.10.0->torchrl) (8.7.0)\n", + "Requirement already satisfied: orjson in /usr/local/lib/python3.12/dist-packages (from tensordict<0.11.0,>=0.10.0->torchrl) (3.11.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (3.20.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (4.15.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (75.2.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (1.13.3)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (3.5)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (3.1.6)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (2025.3.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.80)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (9.10.2.21)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.4.1)\n", + "Requirement already satisfied: 
nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (11.3.0.4)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (10.3.7.77)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (11.7.1.2)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.5.4.2)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (0.7.1)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (2.27.3)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.77)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (12.6.85)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (1.11.1.6)\n", + "Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.1.0->torchrl) (3.4.0)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=2.1.0->torchrl) (1.3.0)\n", + "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.12/dist-packages (from importlib_metadata->tensordict<0.11.0,>=0.10.0->torchrl) (3.23.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=2.1.0->torchrl) (3.0.3)\n", + "Downloading torchrl-0.10.0-cp312-cp312-manylinux_2_28_x86_64.whl (1.8 MB)\n", + " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 42.0 MB/s eta 0:00:00\n", + "Installing collected packages: torchrl\n", + "Successfully installed torchrl-0.10.0\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Vp25M-ClD2qJ", + "outputId": "462e5ecc-ef04-4258-9ca6-5bb923c0027b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.\n", + "Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.\n", + "See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.\n", + "/usr/local/lib/python3.12/dist-packages/jupyter_client/session.py:203: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. 
Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " return datetime.utcnow().replace(tzinfo=utc)\n" + ] + } + ], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torchvision import transforms as T\n", + "from PIL import Image\n", + "import numpy as np\n", + "from pathlib import Path\n", + "from collections import deque\n", + "import random, datetime, os\n", + "\n", + "# Gym is an OpenAI toolkit for RL\n", + "import gym\n", + "from gym.spaces import Box\n", + "from gym.wrappers import FrameStack\n", + "\n", + "# NES Emulator for OpenAI Gym\n", + "from nes_py.wrappers import JoypadSpace\n", + "\n", + "# Super Mario environment for OpenAI Gym\n", + "import gym_super_mario_bros\n", + "\n", + "from tensordict import TensorDict\n", + "from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_c-_S9p_D2qJ" + }, + "source": [ + "RL Definitions\n", + "==============\n", + "\n", + "**Environment** The world that an agent interacts with and learns from.\n", + "\n", + "**Action** $a$ : How the Agent responds to the Environment. The set of\n", + "all possible Actions is called *action-space*.\n", + "\n", + "**State** $s$ : The current characteristic of the Environment. The set\n", + "of all possible States the Environment can be in is called\n", + "*state-space*.\n", + "\n", + "**Reward** $r$ : Reward is the key feedback from Environment to Agent.\n", + "It is what drives the Agent to learn and to change its future action. An\n", + "aggregation of rewards over multiple time steps is called **Return**.\n", + "\n", + "**Optimal Action-Value function** $Q^*(s,a)$ : Gives the expected return\n", + "if you start in state $s$, take an arbitrary action $a$, and then for\n", + "each future time step take the action that maximizes returns. $Q$ can be\n", + "said to stand for the \"quality\" of the action in a state. 
We try to\n", + "approximate this function.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s1MRtEQkD2qJ" + }, + "source": [ + "Environment\n", + "===========\n", + "\n", + "Initialize Environment\n", + "----------------------\n", + "\n", + "In Mario, the environment consists of tubes, mushrooms and other\n", + "components.\n", + "\n", + "When Mario makes an action, the environment responds with the changed\n", + "(next) state, reward and other info.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 530 + }, + "id": "ptVynUw2D2qK", + "outputId": "355001d4-10e8-4f4b-d01d-5ac97dd815d5" + }, + "outputs": [ + { + "output_type": "error", + "ename": "TypeError", + "evalue": "SuperMarioBrosEnv.__init__() got an unexpected keyword argument 'render_mode'", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipython-input-4228824977.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mgym_super_mario_bros\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactions\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mactions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m env = gym_super_mario_bros.make(\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;34m\"SuperMarioBros-1-1-v0\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mrender_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'rgb_array'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# Changed render_mode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id, max_episode_steps, autoreset, new_step_api, disable_env_checker, **kwargs)\u001b[0m\n\u001b[1;32m 672\u001b[0m )\n\u001b[1;32m 673\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 674\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 675\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 676\u001b[0m \u001b[0;31m# Copies the environment creation specification and kwargs to add to the environment specification details\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id, max_episode_steps, autoreset, new_step_api, disable_env_checker, **kwargs)\u001b[0m\n\u001b[1;32m 660\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 661\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 662\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv_creator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0m_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 663\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 664\u001b[0m 
if (\n", + "\u001b[0;31mTypeError\u001b[0m: SuperMarioBrosEnv.__init__() got an unexpected keyword argument 'render_mode'" + ] + } + ], + "source": [ + "# Initialize Super Mario environment\n", + "import gym_super_mario_bros\n", + "from nes_py.wrappers import JoypadSpace\n", + "import gym_super_mario_bros.actions as actions\n", + "\n", + "env = gym_super_mario_bros.make(\n", + " \"SuperMarioBros-1-1-v0\",\n", + " render_mode='rgb_array', # Changed render_mode\n", + " apply_api_compatibility=True\n", + ")\n", + "\n", + "# Limit the action-space to\n", + "# 0. walk right\n", + "# 1. jump right\n", + "env = JoypadSpace(env, actions.SIMPLE_MOVEMENT)\n", + "\n", + "env.reset()\n", + "next_state, reward, done, trunc, info = env.step(action=0)\n", + "print(f\"{next_state.shape},\\n {reward},\\n {done},\\n {info}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_1tsdX8ND2qK" + }, + "source": [ + "Preprocess Environment\n", + "======================\n", + "\n", + "Environment data is returned to the agent in `next_state`. As you saw\n", + "above, each state is represented by a `[3, 240, 256]` size array. Often\n", + "that is more information than our agent needs; for instance, Mario's\n", + "actions do not depend on the color of the pipes or the sky!\n", + "\n", + "We use **Wrappers** to preprocess environment data before sending it to\n", + "the agent.\n", + "\n", + "`GrayScaleObservation` is a common wrapper to transform an RGB image to\n", + "grayscale; doing so reduces the size of the state representation without\n", + "losing useful information. Now the size of each state: `[1, 240, 256]`\n", + "\n", + "`ResizeObservation` downsamples each observation into a square image.\n", + "New size: `[1, 84, 84]`\n", + "\n", + "`SkipFrame` is a custom wrapper that inherits from `gym.Wrapper` and\n", + "implements the `step()` function. Because consecutive frames don't vary\n", + "much, we can skip n-intermediate frames without losing much information.\n", + "The n-th frame aggregates rewards accumulated over each skipped frame.\n", + "\n", + "`FrameStack` is a wrapper that allows us to squash consecutive frames of\n", + "the environment into a single observation point to feed to our learning\n", + "model. 
This way, we can identify if Mario was landing or jumping based\n", + "on the direction of his movement in the previous several frames.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 211 + }, + "id": "OTdrBI4XD2qK", + "outputId": "ffc7e4d9-c2e3-4108-e8d6-a0957523b62f" + }, + "outputs": [ + { + "output_type": "error", + "ename": "NameError", + "evalue": "name 'env' is not defined", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipython-input-3910090054.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;31m# Apply Wrappers to environment\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSkipFrame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mskip\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGrayScaleObservation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mResizeObservation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m84\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'env' is not defined" + ] + } + ], + "source": [ + "class SkipFrame(gym.Wrapper):\n", + " def __init__(self, env, skip):\n", + " \"\"\"Return only every `skip`-th frame\"\"\"\n", + " super().__init__(env)\n", + " self._skip = skip\n", + "\n", + " def step(self, action):\n", + " \"\"\"Repeat action, and sum reward\"\"\"\n", + " total_reward = 0.0\n", + " for i in range(self._skip):\n", + " # Accumulate reward and repeat the same action\n", + " obs, reward, done, trunk, info = self.env.step(action)\n", + " total_reward += reward\n", + " if done:\n", + " break\n", + " return obs, total_reward, done, trunk, info\n", + "\n", + "\n", + "class GrayScaleObservation(gym.ObservationWrapper):\n", + " def __init__(self, env):\n", + " super().__init__(env)\n", + " obs_shape = self.observation_space.shape[:2]\n", + " self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)\n", + "\n", + " def permute_orientation(self, observation):\n", + " # permute [H, W, C] array to [C, H, W] tensor\n", + " observation = np.transpose(observation, (2, 0, 1))\n", + " observation = torch.tensor(observation.copy(), dtype=torch.float)\n", + " return observation\n", + "\n", + " def observation(self, observation):\n", + " observation = self.permute_orientation(observation)\n", + " transform = T.Grayscale()\n", + " observation = transform(observation)\n", + " return observation\n", + "\n", + "\n", + "class ResizeObservation(gym.ObservationWrapper):\n", + " def __init__(self, env, shape):\n", + " super().__init__(env)\n", + " if isinstance(shape, int):\n", + " self.shape = (shape, shape)\n", + " else:\n", + " self.shape = tuple(shape)\n", + "\n", + " obs_shape = self.shape + 
self.observation_space.shape[2:]\n", + " self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)\n", + "\n", + " def observation(self, observation):\n", + " transforms = T.Compose(\n", + " [T.Resize(self.shape, antialias=True), T.Normalize(0, 255)]\n", + " )\n", + " observation = transforms(observation).squeeze(0)\n", + " return observation\n", + "\n", + "\n", + "# Apply Wrappers to environment\n", + "env = SkipFrame(env, skip=4)\n", + "env = GrayScaleObservation(env)\n", + "env = ResizeObservation(env, shape=84)\n", + "if gym.__version__ < '0.26':\n", + " env = FrameStack(env, num_stack=4, new_step_api=True)\n", + "else:\n", + " env = FrameStack(env, num_stack=4)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zdFsFd9OD2qL" + }, + "source": [ + "After applying the above wrappers to the environment, the final wrapped\n", + "state consists of 4 gray-scaled consecutive frames stacked together, as\n", + "shown above in the image on the left. Each time Mario makes an action,\n", + "the environment responds with a state of this structure. The structure\n", + "is represented by a 3-D array of size `[4, 84, 84]`.\n", + "\n", + "![](https://pytorch.org/tutorials/_static/img/mario_env.png)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nWAooKxLD2qL" + }, + "source": [ + "Agent\n", + "=====\n", + "\n", + "We create a class `Mario` to represent our agent in the game. Mario\n", + "should be able to:\n", + "\n", + "- **Act** according to the optimal action policy based on the current\n", + " state (of the environment).\n", + "- **Remember** experiences. Experience = (current state, current\n", + " action, reward, next state). Mario *caches* and later *recalls* his\n", + " experiences to update his action policy.\n", + "- **Learn** a better action policy over time\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pi2mirDYD2qL" + }, + "outputs": [], + "source": [ + "class Mario:\n", + " def __init__():\n", + " pass\n", + "\n", + " def act(self, state):\n", + " \"\"\"Given a state, choose an epsilon-greedy action\"\"\"\n", + " pass\n", + "\n", + " def cache(self, experience):\n", + " \"\"\"Add the experience to memory\"\"\"\n", + " pass\n", + "\n", + " def recall(self):\n", + " \"\"\"Sample experiences from memory\"\"\"\n", + " pass\n", + "\n", + " def learn(self):\n", + " \"\"\"Update online action value (Q) function with a batch of experiences\"\"\"\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PMIFF0aBD2qL" + }, + "source": [ + "In the following sections, we will populate Mario's parameters and\n", + "define his functions.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jklk8voTD2qL" + }, + "source": [ + "Act\n", + "===\n", + "\n", + "For any given state, an agent can choose to do the most optimal action\n", + "(**exploit**) or a random action (**explore**).\n", + "\n", + "Mario randomly explores with a chance of `self.exploration_rate`; when\n", + "he chooses to exploit, he relies on `MarioNet` (implemented in `Learn`\n", + "section) to provide the most optimal action.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-kYLEcR8D2qL" + }, + "outputs": [], + "source": [ + "class Mario:\n", + " def __init__(self, state_dim, action_dim, save_dir):\n", + " self.state_dim = state_dim\n", + " self.action_dim = action_dim\n", + " self.save_dir = save_dir\n", + "\n", + " self.device = \"cuda\" if 
torch.cuda.is_available() else \"cpu\"\n", + "\n", + " # Mario's DNN to predict the most optimal action - we implement this in the Learn section\n", + " self.net = MarioNet(self.state_dim, self.action_dim).float()\n", + " self.net = self.net.to(device=self.device)\n", + "\n", + " self.exploration_rate = 1\n", + " self.exploration_rate_decay = 0.99999975\n", + " self.exploration_rate_min = 0.1\n", + " self.curr_step = 0\n", + "\n", + " self.save_every = 5e5 # no. of experiences between saving Mario Net\n", + "\n", + " def act(self, state):\n", + " \"\"\"\n", + " Given a state, choose an epsilon-greedy action and update value of step.\n", + "\n", + " Inputs:\n", + " state(``LazyFrame``): A single observation of the current state, dimension is (state_dim)\n", + " Outputs:\n", + " ``action_idx`` (``int``): An integer representing which action Mario will perform\n", + " \"\"\"\n", + " # EXPLORE\n", + " if np.random.rand() < self.exploration_rate:\n", + " action_idx = np.random.randint(self.action_dim)\n", + "\n", + " # EXPLOIT\n", + " else:\n", + " state = state[0].__array__() if isinstance(state, tuple) else state.__array__()\n", + " state = torch.tensor(state, device=self.device).unsqueeze(0)\n", + " action_values = self.net(state, model=\"online\")\n", + " action_idx = torch.argmax(action_values, axis=1).item()\n", + "\n", + " # decrease exploration_rate\n", + " self.exploration_rate *= self.exploration_rate_decay\n", + " self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)\n", + "\n", + " # increment step\n", + " self.curr_step += 1\n", + " return action_idx" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z_ud-4bvD2qM" + }, + "source": [ + "Cache and Recall\n", + "================\n", + "\n", + "These two functions serve as Mario's \"memory\" process.\n", + "\n", + "`cache()`: Each time Mario performs an action, he stores the\n", + "`experience` to his memory. 
His experience includes the current *state*,\n", + "*action* performed, *reward* from the action, the *next state*, and\n", + "whether the game is *done*.\n", + "\n", + "`recall()`: Mario randomly samples a batch of experiences from his\n", + "memory, and uses that to learn the game.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YHeLPsd4D2qM" + }, + "outputs": [], + "source": [ + "class Mario(Mario): # subclassing for continuity\n", + " def __init__(self, state_dim, action_dim, save_dir):\n", + " super().__init__(state_dim, action_dim, save_dir)\n", + " self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000, device=torch.device(\"cpu\")))\n", + " self.batch_size = 32\n", + "\n", + " def cache(self, state, next_state, action, reward, done):\n", + " \"\"\"\n", + " Store the experience to self.memory (replay buffer)\n", + "\n", + " Inputs:\n", + " state (``LazyFrame``),\n", + " next_state (``LazyFrame``),\n", + " action (``int``),\n", + " reward (``float``),\n", + " done(``bool``))\n", + " \"\"\"\n", + " def first_if_tuple(x):\n", + " return x[0] if isinstance(x, tuple) else x\n", + " state = first_if_tuple(state).__array__()\n", + " next_state = first_if_tuple(next_state).__array__()\n", + "\n", + " state = torch.tensor(state)\n", + " next_state = torch.tensor(next_state)\n", + " action = torch.tensor([action])\n", + " reward = torch.tensor([reward])\n", + " done = torch.tensor([done])\n", + "\n", + " # self.memory.append((state, next_state, action, reward, done,))\n", + " self.memory.add(TensorDict({\"state\": state, \"next_state\": next_state, \"action\": action, \"reward\": reward, \"done\": done}, batch_size=[]))\n", + "\n", + " def recall(self):\n", + " \"\"\"\n", + " Retrieve a batch of experiences from memory\n", + " \"\"\"\n", + " batch = self.memory.sample(self.batch_size).to(self.device)\n", + " state, next_state, action, reward, done = (batch.get(key) for key in (\"state\", \"next_state\", \"action\", \"reward\", \"done\"))\n", + " return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qrnlw-hYD2qM" + }, + "source": [ + "Learn\n", + "=====\n", + "\n", + "Mario uses the [DDQN algorithm](https://arxiv.org/pdf/1509.06461) under\n", + "the hood. DDQN uses two ConvNets - $Q_{online}$ and $Q_{target}$ - that\n", + "independently approximate the optimal action-value function.\n", + "\n", + "In our implementation, we share feature generator `features` across\n", + "$Q_{online}$ and $Q_{target}$, but maintain separate FC classifiers for\n", + "each. $\\theta_{target}$ (the parameters of $Q_{target}$) is frozen to\n", + "prevent updating by backprop. 
Instead, it is periodically synced with\n", + "$\\theta_{online}$ (more on this later).\n", + "\n", + "Neural Network\n", + "--------------\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9_aBKXrRD2qM" + }, + "outputs": [], + "source": [ + "class MarioNet(nn.Module):\n", + " \"\"\"mini CNN structure\n", + " input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, output_dim):\n", + " super().__init__()\n", + " c, h, w = input_dim\n", + "\n", + " if h != 84:\n", + " raise ValueError(f\"Expecting input height: 84, got: {h}\")\n", + " if w != 84:\n", + " raise ValueError(f\"Expecting input width: 84, got: {w}\")\n", + "\n", + " self.online = self.__build_cnn(c, output_dim)\n", + "\n", + " self.target = self.__build_cnn(c, output_dim)\n", + " self.target.load_state_dict(self.online.state_dict())\n", + "\n", + " # Q_target parameters are frozen.\n", + " for p in self.target.parameters():\n", + " p.requires_grad = False\n", + "\n", + " def forward(self, input, model):\n", + " if model == \"online\":\n", + " return self.online(input)\n", + " elif model == \"target\":\n", + " return self.target(input)\n", + "\n", + " def __build_cnn(self, c, output_dim):\n", + " return nn.Sequential(\n", + " nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),\n", + " nn.ReLU(),\n", + " nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),\n", + " nn.ReLU(),\n", + " nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),\n", + " nn.ReLU(),\n", + " nn.Flatten(),\n", + " nn.Linear(3136, 512),\n", + " nn.ReLU(),\n", + " nn.Linear(512, output_dim),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3givTueiD2qM" + }, + "source": [ + "TD Estimate & TD Target\n", + "=======================\n", + "\n", + "Two values are involved in learning:\n", + "\n", + "**TD Estimate** - the predicted optimal $Q^*$ for a given state $s$\n", + "\n", + "$${TD}_e = Q_{online}^*(s,a)$$\n", + "\n", + "**TD Target** - aggregation of current reward and the estimated $Q^*$ in\n", + "the next state $s'$\n", + "\n", + "$$a' = argmax_{a} Q_{online}(s', a)$$\n", + "\n", + "$${TD}_t = r + \\gamma Q_{target}^*(s',a')$$\n", + "\n", + "Because we don't know what next action $a'$ will be, we use the action\n", + "$a'$ maximizes $Q_{online}$ in the next state $s'$.\n", + "\n", + "Notice we use the\n", + "[\\@torch.no\\_grad()](https://pytorch.org/docs/stable/generated/torch.no_grad.html#no-grad)\n", + "decorator on `td_target()` to disable gradient calculations here\n", + "(because we don't need to backpropagate on $\\theta_{target}$).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-2dL1LkWD2qN" + }, + "outputs": [], + "source": [ + "class Mario(Mario):\n", + " def __init__(self, state_dim, action_dim, save_dir):\n", + " super().__init__(state_dim, action_dim, save_dir)\n", + " self.gamma = 0.9\n", + "\n", + " def td_estimate(self, state, action):\n", + " current_Q = self.net(state, model=\"online\")[\n", + " np.arange(0, self.batch_size), action\n", + " ] # Q_online(s,a)\n", + " return current_Q\n", + "\n", + " @torch.no_grad()\n", + " def td_target(self, reward, next_state, done):\n", + " next_state_Q = self.net(next_state, model=\"online\")\n", + " best_action = torch.argmax(next_state_Q, axis=1)\n", + " next_Q = self.net(next_state, model=\"target\")[\n", + " np.arange(0, self.batch_size), best_action\n", + " ]\n", + " 
return (reward + (1 - done.float()) * self.gamma * next_Q).float()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8u_-CcTID2qN" + }, + "source": [ + "Updating the model\n", + "==================\n", + "\n", + "As Mario samples inputs from his replay buffer, we compute $TD_t$ and\n", + "$TD_e$ and backpropagate this loss down $Q_{online}$ to update its\n", + "parameters $\\theta_{online}$ ($\\alpha$ is the learning rate `lr` passed\n", + "to the `optimizer`)\n", + "\n", + "$$\\theta_{online} \\leftarrow \\theta_{online} + \\alpha \\nabla(TD_e - TD_t)$$\n", + "\n", + "$\\theta_{target}$ does not update through backpropagation. Instead, we\n", + "periodically copy $\\theta_{online}$ to $\\theta_{target}$\n", + "\n", + "$$\\theta_{target} \\leftarrow \\theta_{online}$$\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UKttGHnvD2qN" + }, + "outputs": [], + "source": [ + "class Mario(Mario):\n", + " def __init__(self, state_dim, action_dim, save_dir):\n", + " super().__init__(state_dim, action_dim, save_dir)\n", + " self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)\n", + " self.loss_fn = torch.nn.SmoothL1Loss()\n", + "\n", + " def update_Q_online(self, td_estimate, td_target):\n", + " loss = self.loss_fn(td_estimate, td_target)\n", + " self.optimizer.zero_grad()\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " return loss.item()\n", + "\n", + " def sync_Q_target(self):\n", + " self.net.target.load_state_dict(self.net.online.state_dict())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Mj0D6D1TD2qN" + }, + "source": [ + "Save checkpoint\n", + "===============\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-20sIwqQD2qN" + }, + "outputs": [], + "source": [ + "class Mario(Mario):\n", + " def save(self):\n", + " save_path = (\n", + " self.save_dir / f\"mario_net_{int(self.curr_step // self.save_every)}.chkpt\"\n", + " )\n", + " torch.save(\n", + " dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),\n", + " save_path,\n", + " )\n", + " print(f\"MarioNet saved to {save_path} at step {self.curr_step}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AM-1bAUkD2qN" + }, + "source": [ + "Putting it all together\n", + "=======================\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qXqnlMixD2qN" + }, + "outputs": [], + "source": [ + "class Mario(Mario):\n", + " def __init__(self, state_dim, action_dim, save_dir):\n", + " super().__init__(state_dim, action_dim, save_dir)\n", + " self.burnin = 1e4 # min. experiences before training\n", + " self.learn_every = 3 # no. of experiences between updates to Q_online\n", + " self.sync_every = 1e4 # no. 
of experiences between Q_target & Q_online sync\n", + "\n", + " def learn(self):\n", + " if self.curr_step % self.sync_every == 0:\n", + " self.sync_Q_target()\n", + "\n", + " if self.curr_step % self.save_every == 0:\n", + " self.save()\n", + "\n", + " if self.curr_step < self.burnin:\n", + " return None, None\n", + "\n", + " if self.curr_step % self.learn_every != 0:\n", + " return None, None\n", + "\n", + " # Sample from memory\n", + " state, next_state, action, reward, done = self.recall()\n", + "\n", + " # Get TD Estimate\n", + " td_est = self.td_estimate(state, action)\n", + "\n", + " # Get TD Target\n", + " td_tgt = self.td_target(reward, next_state, done)\n", + "\n", + " # Backpropagate loss through Q_online\n", + " loss = self.update_Q_online(td_est, td_tgt)\n", + "\n", + " return (td_est.mean().item(), loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nPxSO3n3D2qN" + }, + "source": [ + "Logging\n", + "=======\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5Vhp225CD2qN" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import time, datetime\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "class MetricLogger:\n", + " def __init__(self, save_dir):\n", + " self.save_log = save_dir / \"log\"\n", + " with open(self.save_log, \"w\") as f:\n", + " f.write(\n", + " f\"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}\"\n", + " f\"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}\"\n", + " f\"{'TimeDelta':>15}{'Time':>20}\\n\"\n", + " )\n", + " self.ep_rewards_plot = save_dir / \"reward_plot.jpg\"\n", + " self.ep_lengths_plot = save_dir / \"length_plot.jpg\"\n", + " self.ep_avg_losses_plot = save_dir / \"loss_plot.jpg\"\n", + " self.ep_avg_qs_plot = save_dir / \"q_plot.jpg\"\n", + "\n", + " # History metrics\n", + " self.ep_rewards = []\n", + " self.ep_lengths = []\n", + " self.ep_avg_losses = []\n", + " self.ep_avg_qs = []\n", + "\n", + " # Moving averages, added for every call to record()\n", + " self.moving_avg_ep_rewards = []\n", + " self.moving_avg_ep_lengths = []\n", + " self.moving_avg_ep_avg_losses = []\n", + " self.moving_avg_ep_avg_qs = []\n", + "\n", + " # Current episode metric\n", + " self.init_episode()\n", + "\n", + " # Timing\n", + " self.record_time = time.time()\n", + "\n", + " def log_step(self, reward, loss, q):\n", + " self.curr_ep_reward += reward\n", + " self.curr_ep_length += 1\n", + " if loss:\n", + " self.curr_ep_loss += loss\n", + " self.curr_ep_q += q\n", + " self.curr_ep_loss_length += 1\n", + "\n", + " def log_episode(self):\n", + " \"Mark end of episode\"\n", + " self.ep_rewards.append(self.curr_ep_reward)\n", + " self.ep_lengths.append(self.curr_ep_length)\n", + " if self.curr_ep_loss_length == 0:\n", + " ep_avg_loss = 0\n", + " ep_avg_q = 0\n", + " else:\n", + " ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)\n", + " ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)\n", + " self.ep_avg_losses.append(ep_avg_loss)\n", + " self.ep_avg_qs.append(ep_avg_q)\n", + "\n", + " self.init_episode()\n", + "\n", + " def init_episode(self):\n", + " self.curr_ep_reward = 0.0\n", + " self.curr_ep_length = 0\n", + " self.curr_ep_loss = 0.0\n", + " self.curr_ep_q = 0.0\n", + " self.curr_ep_loss_length = 0\n", + "\n", + " def record(self, episode, epsilon, step):\n", + " mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)\n", + " mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)\n", + " mean_ep_loss = 
np.round(np.mean(self.ep_avg_losses[-100:]), 3)\n", + " mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)\n", + " self.moving_avg_ep_rewards.append(mean_ep_reward)\n", + " self.moving_avg_ep_lengths.append(mean_ep_length)\n", + " self.moving_avg_ep_avg_losses.append(mean_ep_loss)\n", + " self.moving_avg_ep_avg_qs.append(mean_ep_q)\n", + "\n", + " last_record_time = self.record_time\n", + " self.record_time = time.time()\n", + " time_since_last_record = np.round(self.record_time - last_record_time, 3)\n", + "\n", + " print(\n", + " f\"Episode {episode} - \"\n", + " f\"Step {step} - \"\n", + " f\"Epsilon {epsilon} - \"\n", + " f\"Mean Reward {mean_ep_reward} - \"\n", + " f\"Mean Length {mean_ep_length} - \"\n", + " f\"Mean Loss {mean_ep_loss} - \"\n", + " f\"Mean Q Value {mean_ep_q} - \"\n", + " f\"Time Delta {time_since_last_record} - \"\n", + " f\"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}\"\n", + " )\n", + "\n", + " with open(self.save_log, \"a\") as f:\n", + " f.write(\n", + " f\"{episode:8d}{step:8d}{epsilon:10.3f}\"\n", + " f\"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}\"\n", + " f\"{time_since_last_record:15.3f}\"\n", + " f\"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\\n\"\n", + " )\n", + "\n", + " for metric in [\"ep_lengths\", \"ep_avg_losses\", \"ep_avg_qs\", \"ep_rewards\"]:\n", + " plt.clf()\n", + " plt.plot(getattr(self, f\"moving_avg_{metric}\"), label=f\"moving_avg_{metric}\")\n", + " plt.legend()\n", + " plt.savefig(getattr(self, f\"{metric}_plot\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N_qG20ecD2qO" + }, + "source": [ + "Let's play!\n", + "===========\n", + "\n", + "In this example we run the training loop for 40 episodes, but for Mario\n", + "to truly learn the ways of his world, we suggest running the loop for at\n", + "least 40,000 episodes!\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wtG81GdfD2qO" + }, + "outputs": [], + "source": [ + "use_cuda = torch.cuda.is_available()\n", + "print(f\"Using CUDA: {use_cuda}\")\n", + "print()\n", + "\n", + "save_dir = Path(\"checkpoints\") / datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\")\n", + "save_dir.mkdir(parents=True)\n", + "\n", + "mario = Mario(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir)\n", + "\n", + "logger = MetricLogger(save_dir)\n", + "\n", + "episodes = 40\n", + "for e in range(episodes):\n", + "\n", + " state = env.reset()\n", + "\n", + " # Play the game!\n", + " while True:\n", + "\n", + " # Run agent on the state\n", + " action = mario.act(state)\n", + "\n", + " # Agent performs action\n", + " next_state, reward, done, trunc, info = env.step(action)\n", + "\n", + " # Remember\n", + " mario.cache(state, next_state, action, reward, done)\n", + "\n", + " # Learn\n", + " q, loss = mario.learn()\n", + "\n", + " # Logging\n", + " logger.log_step(reward, loss, q)\n", + "\n", + " # Update state\n", + " state = next_state\n", + "\n", + " # Check if end of game\n", + " if done or info[\"flag_get\"]:\n", + " break\n", + "\n", + " logger.log_episode()\n", + "\n", + " if (e % 20 == 0) or (e == episodes - 1):\n", + " logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "piaFDm7vD2qO" + }, + "source": [ + "Conclusion\n", + "==========\n", + "\n", + "In this tutorial, we saw how we can use PyTorch to train a game-playing\n", + "AI. 
You can use the same methods to train an AI to play any of the games\n", + "at the [OpenAI gym](https://gym.openai.com/). Hope you enjoyed this\n", + "tutorial, feel free to reach us at [our\n", + "github](https://github.com/yuansongFeng/MadMario/)!\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file
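A note on the `scores = model.forward(data)` change in `pytorch_simple_fullynet.py`: calling the module instance, `model(data)`, goes through `nn.Module.__call__`, which also dispatches any registered hooks, whereas calling `forward()` directly skips them. A minimal standalone sketch (not part of the patch) of the difference:

```python
# Minimal sketch: model(x) vs model.forward(x)
import torch
import torch.nn as nn

net = nn.Linear(4, 2)
net.register_forward_hook(lambda mod, inp, out: print("hook fired"))

x = torch.randn(1, 4)
net(x)          # prints "hook fired" -- __call__ runs forward() plus hooks
net.forward(x)  # same output, but the hook is silently skipped
```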
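On the `.contiguous()` comment in `pytorch_tensorbasics.py`: a transpose returns a strided view over the same storage, so `.view()` cannot reinterpret it without first copying the data into contiguous memory (`.reshape()` does that copy only when needed). A minimal standalone sketch:

```python
# Minimal sketch of the non-contiguous view pitfall
import torch

x = torch.arange(9)
y = x.view(3, 3).t()           # transpose -> strided, non-contiguous view
print(y.is_contiguous())       # False
try:
    y.view(9)                  # raises RuntimeError: incompatible size/stride
except RuntimeError as e:
    print("view failed:", e)
print(y.contiguous().view(9))  # works: data copied into contiguous memory
print(y.reshape(9))            # equivalent; copies only when it has to
```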
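Finally, on the TD target in the Mario notebook: the "double" in Double DQN means the next action is selected by the online network but evaluated by the target network. A minimal sketch of that computation, assuming hypothetical `(batch, n_actions)` Q-value tensors `q_online_next` and `q_target_next` (names not taken from the notebook):

```python
# Minimal sketch of the Double-DQN target
import torch

def ddqn_td_target(reward, q_online_next, q_target_next, done, gamma=0.9):
    # action chosen by the online net ...
    best_action = torch.argmax(q_online_next, dim=1)
    # ... but evaluated by the target net (the decoupling that makes it "double")
    next_q = q_target_next[torch.arange(q_target_next.shape[0]), best_action]
    return reward + (1 - done.float()) * gamma * next_q

reward = torch.tensor([1.0, 0.0])
done = torch.tensor([False, True])
q_online_next = torch.tensor([[0.2, 0.8], [0.5, 0.1]])
q_target_next = torch.tensor([[0.3, 0.6], [0.4, 0.2]])
print(ddqn_td_target(reward, q_online_next, q_target_next, done))
# tensor([1.5400, 0.0000])  -> 1 + 0.9 * 0.6 = 1.54; terminal state contributes 0
```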