diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..a93af385 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.npy filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index f1658105..76e5303f 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ __pycache__/ *$py.class .idea/ +/**/Tiny* diff --git a/dlai/00_intro.ipynb b/dlai/00_intro.ipynb new file mode 100644 index 00000000..632292f2 --- /dev/null +++ b/dlai/00_intro.ipynb @@ -0,0 +1,248 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "dbfb9335", + "metadata": {}, + "source": [ + "# Introduction to W&B\n", + "\n", + "We will add `wandb` to sprite classification model training, so that we can track and visualize important metrics, gain insights into our model's behavior and make informed decisions for model improvements. We will also see how to compare and analyze different experiments, collaborate with team members, and reproduce results effectively." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9ba792c-2baa-4c19-a132-2ed82a759e79", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "from pathlib import Path\n", + "from types import SimpleNamespace\n", + "from tqdm.auto import tqdm\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.optim import Adam\n", + "from utilities import get_dataloaders\n", + "\n", + "import wandb" + ] + }, + { + "cell_type": "markdown", + "id": "2e0bfcc9", + "metadata": {}, + "source": [ + "### Sprite classification\n", + "\n", + "We will build a simple model to classify sprites. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d51a9f7f", + "metadata": {}, + "outputs": [], + "source": [ + "INPUT_SIZE = 3 * 16 * 16\n", + "OUTPUT_SIZE = 5\n", + "HIDDEN_SIZE = 256\n", + "NUM_WORKERS = 2\n", + "CLASSES = [\"hero\", \"non-hero\", \"food\", \"spell\", \"side-facing\"]\n", + "DATA_DIR = Path('./data/')\n", + "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "def get_model(dropout):\n", + " \"Simple MLP with Dropout\"\n", + " return nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(INPUT_SIZE, HIDDEN_SIZE),\n", + " nn.BatchNorm1d(HIDDEN_SIZE),\n", + " nn.ReLU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(HIDDEN_SIZE, OUTPUT_SIZE)\n", + " ).to(DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f33f739c-d7ef-4954-ae87-d5bdd6bf25ee", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's define a config object to store our hyperparameters\n", + "config = SimpleNamespace(\n", + " epochs = 2,\n", + " batch_size = 128,\n", + " lr = 1e-5,\n", + " dropout = 0.5,\n", + " slice_size = 10_000,\n", + " valid_pct = 0.2,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5492ebb-2dfa-44ce-af6c-24655e45a2ed", + "metadata": {}, + "outputs": [], + "source": [ + "def train_model(config):\n", + " \"Train a model with a given config\"\n", + " # Start a wandb run\n", + " wandb.init(\n", + " project=\"dlai-intro\",\n", + " config=config,\n", + " )\n", + " # Get the data\n", + " train_dl, valid_dl = get_dataloaders(DATA_DIR, \n", + " config.batch_size, \n", + " config.slice_size, \n", + " config.valid_pct)\n", + " n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)\n", + "\n", + " # A simple MLP model\n", + " model = get_model(config.dropout)\n", + "\n", + " # Make the loss and optimizer\n", + " loss_func = nn.CrossEntropyLoss()\n", + " optimizer = Adam(model.parameters(), lr=config.lr)\n", + "\n", + " example_ct 
= 0\n", + "\n", + " for epoch in tqdm(range(config.epochs), total=config.epochs):\n", + " model.train()\n", + "\n", + " for step, (images, labels) in enumerate(train_dl):\n", + " images, labels = images.to(DEVICE), labels.to(DEVICE)\n", + "\n", + " outputs = model(images)\n", + " train_loss = loss_func(outputs, labels)\n", + " optimizer.zero_grad()\n", + " train_loss.backward()\n", + " optimizer.step()\n", + "\n", + " example_ct += len(images)\n", + " metrics = {\n", + " \"train/train_loss\": train_loss,\n", + " \"train/epoch\": epoch + 1,\n", + " \"train/example_ct\": example_ct\n", + " }\n", + " # log training metrics to wandb\n", + " wandb.log(metrics)\n", + " \n", + " # Compute validation metrics, log images on last epoch\n", + " val_loss, accuracy = validate_model(model, valid_dl, loss_func)\n", + " # Compute train and validation metrics\n", + " val_metrics = {\n", + " \"val/val_loss\": val_loss,\n", + " \"val/val_accuracy\": accuracy\n", + " }\n", + " # log validation metrics to wandb\n", + " wandb.log(val_metrics)\n", + " \n", + " wandb.finish()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8401cf96", + "metadata": {}, + "outputs": [], + "source": [ + "def validate_model(model, valid_dl, loss_func):\n", + " \"Compute the performance of the model on the validation dataset\"\n", + " model.eval()\n", + " val_loss = 0.0\n", + " correct = 0\n", + "\n", + " with torch.inference_mode():\n", + " for i, (images, labels) in enumerate(valid_dl):\n", + " images, labels = images.to(DEVICE), labels.to(DEVICE)\n", + "\n", + " # Forward pass\n", + " outputs = model(images)\n", + " val_loss += loss_func(outputs, labels) * labels.size(0)\n", + "\n", + " # Compute accuracy and accumulate\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " correct += (predicted == labels).sum().item()\n", + " \n", + " return val_loss / len(valid_dl.dataset), correct / len(valid_dl.dataset)\n" + ] + }, + { + "cell_type": "markdown", + "id": "c4cac7d2", + 
"metadata": {}, + "source": [ + "### W&B account\n", + "[Sign up](https://wandb.ai/site) for a free account at https://wandb.ai/site and then login to your wandb account to store the results of your experiments and use advanced W&B features. You can also continue to learn in anonymous mode. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "803c37e2-7ff5-46a6-afb7-b80cb69f7501", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.login(anonymous=\"allow\")" + ] + }, + { + "cell_type": "markdown", + "id": "b3df2485", + "metadata": {}, + "source": [ + "### Train model\n", + "Let's train the model with default config and check how it's doing in W&B. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9423c964-f7e3-4d3b-8a24-e70f7f4414c6", + "metadata": {}, + "outputs": [], + "source": [ + "train_model(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e7c186f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/dlai/01_diffusion_training-instructor.ipynb b/dlai/01_diffusion_training-instructor.ipynb new file mode 100644 index 00000000..83760618 --- /dev/null +++ b/dlai/01_diffusion_training-instructor.ipynb @@ -0,0 +1,304 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "958524a2-cb56-439e-850e-032dd10478f2", + "metadata": {}, + "source": [ + "# Training a Diffusion Model with Weights and Biases (W&B)\n", + "\n", + "In this notebooks we will instrument the training of a diffusion model with W&B. 
We will use the Lab3 notebook from the [\"How diffusion models work\"](https://www.deeplearning.ai/short-courses/how-diffusion-models-work/) course. \n", + "We will add:\n", + "- Logging of the training loss and metrics\n", + "- Sampling from the model during training and uploading the samples to W&B\n", + "- Saving the model checkpoints to W&B" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "700e687c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from types import SimpleNamespace\n", + "from pathlib import Path\n", + "from tqdm.notebook import tqdm\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader\n", + "import numpy as np\n", + "from utilities import *\n", + "\n", + "import wandb" + ] + }, + { + "cell_type": "markdown", + "id": "8969ab86-bd9b-475d-96e2-b913b42dec14", + "metadata": {}, + "source": [ + "We encourage you to create an account to get the full user experience from W&B" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b88f9513", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.login(anonymous=\"allow\")" + ] + }, + { + "cell_type": "markdown", + "id": "7c0d229a", + "metadata": {}, + "source": [ + "## Setting Things Up" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d76c167-7122-4f88-9c9f-5ded96684fa5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# we are storing the parameters to be logged to wandb\n", + "DATA_DIR = Path('./data/')\n", + "SAVE_DIR = Path('./data/weights/')\n", + "SAVE_DIR.mkdir(exist_ok=True, parents=True)\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "config = SimpleNamespace(\n", + " # hyperparameters\n", + " num_samples = 30,\n", + "\n", + " # diffusion hyperparameters\n", + " timesteps = 500,\n", + " beta1 = 1e-4,\n", + " beta2 = 0.02,\n", + "\n", + " # network hyperparameters\n", + " n_feat = 64, # 64 hidden 
dimension feature\n", + " n_cfeat = 5, # context vector is of size 5\n", + " height = 16, # 16x16 image\n", + " \n", + " # training hyperparameters\n", + " batch_size = 100,\n", + " n_epoch = 32,\n", + " lrate = 1e-3,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9c99dea4", + "metadata": {}, + "source": [ + "### Setup DDPM noise scheduler and sampler (same as in the Diffusion course). \n", + "- perturb_input: Adds noise to the input image at the corresponding timestep on the schedule\n", + "- sample_ddpm_context: Generate images using the DDPM sampler, we will use this function during training to sample from the model regularly and see how our training is progressing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c642e1d", + "metadata": {}, + "outputs": [], + "source": [ + "# setup ddpm sampler functions\n", + "perturb_input, sample_ddpm_context = setup_ddpm(config.beta1, \n", + " config.beta2, \n", + " config.timesteps, \n", + " DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6bc9001e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# construct model\n", + "nn_model = ContextUnet(in_channels=3, \n", + " n_feat=config.n_feat, \n", + " n_cfeat=config.n_cfeat, \n", + " height=config.height).to(DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76c63b85", + "metadata": {}, + "outputs": [], + "source": [ + "# load dataset and construct optimizer\n", + "dataset = CustomDataset.from_np(path=DATA_DIR)\n", + "dataloader = DataLoader(dataset, \n", + " batch_size=config.batch_size, \n", + " shuffle=True, \n", + " num_workers=1)\n", + "optim = torch.optim.Adam(nn_model.parameters(), lr=config.lrate)" + ] + }, + { + "cell_type": "markdown", + "id": "d9ed46d7", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "markdown", + "id": "00b9ef16-1848-476d-a9dd-09175b8f0e3c", + "metadata": {}, + "source": [ + "We choose a fixed context 
vector with 6 samples of each class to guide our diffusion" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d88afdba", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Noise vector\n", + "# x_T ~ N(0, 1), sample initial noise\n", + "noises = torch.randn(config.num_samples, 3, \n", + " config.height, config.height).to(DEVICE) \n", + "\n", + "# A fixed context vector to sample from\n", + "ctx_vector = F.one_hot(torch.tensor([0,0,0,0,0,0, # hero\n", + " 1,1,1,1,1,1, # non-hero\n", + " 2,2,2,2,2,2, # food\n", + " 3,3,3,3,3,3, # spell\n", + " 4,4,4,4,4,4]), # side-facing \n", + " 5).to(DEVICE).float()" + ] + }, + { + "cell_type": "markdown", + "id": "26765a7e-4ddc-449e-95c3-54c58a564738", + "metadata": {}, + "source": [ + "The following training cell takes very long to run on CPU, we have already trained the model for you on a GPU equipped machine.\n", + "\n", + "### You can visit the result of this >> [training here](https://wandb.ai/deeplearning-ai-temp/dlai_sprite_diffusion/runs/gwm91gsw) <<" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5f4af69", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# create a wandb run\n", + "run = wandb.init(project=\"dlai_sprite_diffusion\", \n", + " job_type=\"train\", \n", + " config=config)\n", + "\n", + "# we pass the config back from W&B\n", + "config = wandb.config\n", + "\n", + "for ep in tqdm(range(config.n_epoch), leave=True, total=config.n_epoch):\n", + " # set into train mode\n", + " nn_model.train()\n", + " optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n", + " \n", + " pbar = tqdm(dataloader, leave=False)\n", + " for x, c in pbar: # x: images c: context\n", + " optim.zero_grad()\n", + " x = x.to(DEVICE)\n", + " c = c.to(DEVICE) \n", + " context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.8).to(DEVICE)\n", + " c = c * context_mask.unsqueeze(-1) \n", + " noise = torch.randn_like(x)\n", + " t = 
torch.randint(1, config.timesteps + 1, (x.shape[0],)).to(DEVICE) \n", + " x_pert = perturb_input(x, t, noise) \n", + " pred_noise = nn_model(x_pert, t / config.timesteps, c=c) \n", + " loss = F.mse_loss(pred_noise, noise)\n", + " loss.backward() \n", + " optim.step()\n", + "\n", + " # we log the relevant metrics to the workspace\n", + " wandb.log({\"loss\": loss.item(),\n", + " \"lr\": optim.param_groups[0]['lr'],\n", + " \"epoch\": ep})\n", + "\n", + " # save model periodically\n", + " if ep%4==0 or ep == int(config.n_epoch-1):\n", + " nn_model.eval()\n", + " ckpt_file = SAVE_DIR/f\"context_model.pth\"\n", + " torch.save(nn_model.state_dict(), ckpt_file)\n", + " \n", + " \n", + " ###########################################################\n", + " ### COPY TO DEMO NB #######################################\n", + " \n", + " # save model to wandb as an Artifact\n", + " artifact_name = f\"{wandb.run.id}_context_model\"\n", + " at = wandb.Artifact(artifact_name, type=\"model\")\n", + " at.add_file(ckpt_file)\n", + " wandb.log_artifact(at, aliases=[f\"epoch_{ep}\"])\n", + " \n", + " ###########################################################\n", + " ### COPY TO DEMO NB #######################################\n", + " \n", + " # sample the model and log the images to W&B\n", + " samples, _ = sample_ddpm_context(nn_model, \n", + " noises, \n", + " ctx_vector[:config.num_samples])\n", + " wandb.log({\n", + " \"train_samples\": [\n", + " wandb.Image(img) for img in samples.split(1)\n", + " ]})\n", + " \n", + " ###########################################################\n", + " ###########################################################\n", + " \n", + "# finish W&B run\n", + "wandb.finish()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + 
"name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/dlai/01_diffusion_training.ipynb b/dlai/01_diffusion_training.ipynb new file mode 100644 index 00000000..b15b820d --- /dev/null +++ b/dlai/01_diffusion_training.ipynb @@ -0,0 +1,294 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e57c5e2c-04f8-40b7-9b47-e5e05505cb2c", + "metadata": {}, + "source": [ + "# Training a Diffusion Model with Weights and Biases (W&B)\n", + "\n", + "In this notebooks we will instrument the training of a diffusion model with W&B. We will use the Lab3 notebook from the [\"How diffusion models work\"](https://www.deeplearning.ai/short-courses/how-diffusion-models-work/) course. \n", + "We will add:\n", + "- Logging of the training loss and metrics\n", + "- Sampling from the model during training and uploading the samples to W&B\n", + "- Saving the model checkpoints to W&B" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4a34666-2281-49e3-8574-93d57c72771b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from types import SimpleNamespace\n", + "from pathlib import Path\n", + "from tqdm.notebook import tqdm\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader\n", + "import numpy as np\n", + "from utilities import *\n", + "\n", + "import wandb" + ] + }, + { + "cell_type": "markdown", + "id": "2b4dd4a3-b05e-4a7f-811e-a715573761e9", + "metadata": {}, + "source": [ + "We encourage you to create an account to get the full user experience from W&B" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "904d68fe-7435-48a3-b8af-c4be8675311c", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.login(anonymous=\"allow\")" + ] + }, + { + "cell_type": "markdown", + "id": "02e2b5b2-82e4-4535-aa98-34ae64a808e8", + "metadata": {}, + "source": [ + "## Setting 
Things Up" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4918eda7-6d6b-4f9f-8650-c347ed4a5d1c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# we are storing the parameters to be logged to wandb\n", + "DATA_DIR = Path('./data/')\n", + "SAVE_DIR = Path('./data/weights/')\n", + "SAVE_DIR.mkdir(exist_ok=True, parents=True)\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "config = SimpleNamespace(\n", + " # hyperparameters\n", + " num_samples = 30,\n", + "\n", + " # diffusion hyperparameters\n", + " timesteps = 500,\n", + " beta1 = 1e-4,\n", + " beta2 = 0.02,\n", + "\n", + " # network hyperparameters\n", + " n_feat = 64, # 64 hidden dimension feature\n", + " n_cfeat = 5, # context vector is of size 5\n", + " height = 16, # 16x16 image\n", + " \n", + " # training hyperparameters\n", + " batch_size = 100,\n", + " n_epoch = 32,\n", + " lrate = 1e-3,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "1ed92a7b-b6a3-4c0c-a35d-154ec26ed923", + "metadata": {}, + "source": [ + "### Setup DDPM noise scheduler and sampler (same as in the Diffusion course). 
\n", + "- perturb_input: Adds noise to the input image at the corresponding timestep on the schedule\n", + "- sample_ddpm_context: Generate images using the DDPM sampler, we will use this function during training to sample from the model regularly and see how our training is progressing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ba81b76-6521-4c7c-80bd-bacde0361a34", + "metadata": {}, + "outputs": [], + "source": [ + "# setup ddpm sampler functions\n", + "perturb_input, sample_ddpm_context = setup_ddpm(config.beta1, \n", + " config.beta2, \n", + " config.timesteps, \n", + " DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c83bd768-f709-410a-8062-703bde7997d8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# construct model\n", + "nn_model = ContextUnet(in_channels=3, \n", + " n_feat=config.n_feat, \n", + " n_cfeat=config.n_cfeat, \n", + " height=config.height).to(DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf98a114-f7aa-4cbd-b08c-d56ad628da21", + "metadata": {}, + "outputs": [], + "source": [ + "# load dataset and construct optimizer\n", + "dataset = CustomDataset.from_np(path=DATA_DIR)\n", + "dataloader = DataLoader(dataset, \n", + " batch_size=config.batch_size, \n", + " shuffle=True, \n", + " num_workers=1)\n", + "optim = torch.optim.Adam(nn_model.parameters(), lr=config.lrate)" + ] + }, + { + "cell_type": "markdown", + "id": "bdccd6e0-850a-41ed-89e7-db629f838770", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "markdown", + "id": "2338bec6-319c-4603-8ae6-0e1fcbdd3a4e", + "metadata": {}, + "source": [ + "We choose a fixed context vector with 6 samples of each class to guide our diffusion" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56bfcd32-1a9c-4d0e-8237-77da217f41ae", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Noise vector\n", + "# x_T ~ N(0, 1), sample 
initial noise\n", + "noises = torch.randn(config.num_samples, 3, \n", + " config.height, config.height).to(DEVICE) \n", + "\n", + "# A fixed context vector to sample from\n", + "ctx_vector = F.one_hot(torch.tensor([0,0,0,0,0,0, # hero\n", + " 1,1,1,1,1,1, # non-hero\n", + " 2,2,2,2,2,2, # food\n", + " 3,3,3,3,3,3, # spell\n", + " 4,4,4,4,4,4]), # side-facing \n", + " 5).to(DEVICE).float()" + ] + }, + { + "cell_type": "markdown", + "id": "e854b7c7-fa0d-4413-8642-f824449d6763", + "metadata": {}, + "source": [ + "The following training cell takes very long to run on CPU, we have already trained the model for you on a GPU equipped machine.\n", + "\n", + "### You can visit the result of this >> [training here](https://wandb.ai/deeplearning-ai-temp/dlai_sprite_diffusion/runs/gwm91gsw) <<" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c87ca8f-2c09-487f-a8bc-7030c2b76492", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# create a wandb run\n", + "run = wandb.init(project=\"dlai_sprite_diffusion\", \n", + " job_type=\"train\", \n", + " config=config)\n", + "\n", + "# we pass the config back from W&B\n", + "config = wandb.config\n", + "\n", + "for ep in tqdm(range(config.n_epoch), leave=True, total=config.n_epoch):\n", + " # set into train mode\n", + " nn_model.train()\n", + " optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n", + " \n", + " pbar = tqdm(dataloader, leave=False)\n", + " for x, c in pbar: # x: images c: context\n", + " optim.zero_grad()\n", + " x = x.to(DEVICE)\n", + " c = c.to(DEVICE) \n", + " context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.8).to(DEVICE)\n", + " c = c * context_mask.unsqueeze(-1) \n", + " noise = torch.randn_like(x)\n", + " t = torch.randint(1, config.timesteps + 1, (x.shape[0],)).to(DEVICE) \n", + " x_pert = perturb_input(x, t, noise) \n", + " pred_noise = nn_model(x_pert, t / config.timesteps, c=c) \n", + " loss = F.mse_loss(pred_noise, noise)\n", + " 
loss.backward()  \n",
+    "        optim.step()\n",
+    "\n",
+    "        # we log the relevant metrics to the workspace\n",
+    "        wandb.log({\"loss\": loss.item(),\n",
+    "                   \"lr\": optim.param_groups[0]['lr'],\n",
+    "                   \"epoch\": ep})\n",
+    "\n",
+    "    # save model periodically\n",
+    "    if ep%4==0 or ep == int(config.n_epoch-1):\n",
+    "        nn_model.eval()\n",
+    "        ckpt_file = SAVE_DIR/f\"context_model.pth\"\n",
+    "        torch.save(nn_model.state_dict(), ckpt_file)\n",
+    "        \n",
+    "        # save model to wandb as an Artifact\n",
+    "        artifact_name = f\"{wandb.run.id}_context_model\"\n",
+    "        at = wandb.Artifact(artifact_name, type=\"model\")\n",
+    "        at.add_file(ckpt_file)\n",
+    "        wandb.log_artifact(at, aliases=[f\"epoch_{ep}\"])\n",
+    "        \n",
+    "        # sample the model and log the images to W&B\n",
+    "        samples, _ = sample_ddpm_context(nn_model, \n",
+    "                                         noises, \n",
+    "                                         ctx_vector[:config.num_samples])\n",
+    "        wandb.log({\n",
+    "            \"train_samples\": [\n",
+    "                wandb.Image(img) for img in samples.split(1)\n",
+    "            ]})\n",
+    "    \n",
+    "# finish W&B run\n",
+    "wandb.finish()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/dlai/02_diffusion_sampling.ipynb b/dlai/02_diffusion_sampling.ipynb
new file mode 100644
index 00000000..cbc1e181
--- /dev/null
+++ b/dlai/02_diffusion_sampling.ipynb
@@ -0,0 +1,534 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "958524a2-cb56-439e-850e-032dd10478f2",
+   "metadata": {},
+   "source": [
+    "# Sampling from a diffusion model\n",
+    "In this notebook we will sample from the previously trained diffusion model.\n",
+    "- We are going to compare the samples from DDPM and DDIM samplers\n",
+ "- Visualize mixing samples with conditional diffusion models" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "700e687c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from types import SimpleNamespace\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "from utilities import *\n", + "\n", + "import wandb" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dcaf7a29-782c-4735-991f-4408f5ec6128", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: (1) Private W&B dashboard, no account required\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: (2) Use an existing W&B account\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "wandb: Enter your choice: 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: You chose 'Private W&B dashboard, no account required'\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /Users/tcapelle/.netrc\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "wandb.login(anonymous=\"allow\")" + ] + }, + { + "cell_type": "markdown", + "id": "7c0d229a", + "metadata": {}, + "source": [ + "# Setting Things Up" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "54c3a942", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Wandb Params\n", + "MODEL_ARTIFACT = \"deeplearning-ai-temp/model-registry/SpriteGen:staging\" \n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "config = SimpleNamespace(\n", + " # hyperparameters\n", + " num_samples = 30,\n", + " \n", + " # ddpm sampler hyperparameters\n", + " timesteps = 500,\n", + " beta1 = 1e-4,\n", + 
" beta2 = 0.02,\n", + " \n", + " # ddim sampler hp\n", + " ddim_n = 25,\n", + " \n", + " # network hyperparameters\n", + " height = 16,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bb43f98f", + "metadata": {}, + "source": [ + "In the previous notebook we saved the best model as a wandb Artifact (our way of storing files during runs). We will now load the model from wandb and set up the sampling loop." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8ab66255", + "metadata": {}, + "outputs": [], + "source": [ + "def load_model(model_artifact_name):\n", + " \"Load the model from wandb artifacts\"\n", + " api = wandb.Api()\n", + " artifact = api.artifact(model_artifact_name, type=\"model\")\n", + " model_path = Path(artifact.download())\n", + "\n", + " # recover model info from the registry\n", + " producer_run = artifact.logged_by()\n", + "\n", + " # load the weights dictionary\n", + " model_weights = torch.load(model_path/\"context_model.pth\", \n", + " map_location=\"cpu\")\n", + "\n", + " # create the model\n", + " model = ContextUnet(in_channels=3, \n", + " n_feat=producer_run.config[\"n_feat\"], \n", + " n_cfeat=producer_run.config[\"n_cfeat\"], \n", + " height=producer_run.config[\"height\"])\n", + " \n", + " # load the weights into the model\n", + " model.load_state_dict(model_weights)\n", + "\n", + " # set the model to eval mode\n", + " model.eval()\n", + " return model.to(DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "b47633e2", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. 
\n" + ] + } + ], + "source": [ + "nn_model = load_model(MODEL_ARTIFACT)" + ] + }, + { + "cell_type": "markdown", + "id": "fe8eb277", + "metadata": {}, + "source": [ + "## Sampling" + ] + }, + { + "cell_type": "markdown", + "id": "45d92c52-8a11-450c-bc78-ffa221af2fa3", + "metadata": {}, + "source": [ + "We will sample and log the generated samples to wandb." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "146424d3", + "metadata": {}, + "outputs": [], + "source": [ + "_, sample_ddpm_context = setup_ddpm(config.beta1, \n", + " config.beta2, \n", + " config.timesteps, \n", + " DEVICE)" + ] + }, + { + "cell_type": "markdown", + "id": "00b9ef16-1848-476d-a9dd-09175b8f0e3c", + "metadata": {}, + "source": [ + "Let's define a set of noises and a context vector to condition on." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d88afdba", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Noise vector\n", + "# x_T ~ N(0, 1), sample initial noise\n", + "noises = torch.randn(config.num_samples, 3, \n", + " config.height, config.height).to(DEVICE) \n", + "\n", + "# A fixed context vector to sample from\n", + "ctx_vector = F.one_hot(torch.tensor([0,0,0,0,0,0, # hero\n", + " 1,1,1,1,1,1, # non-hero\n", + " 2,2,2,2,2,2, # food\n", + " 3,3,3,3,3,3, # spell\n", + " 4,4,4,4,4,4]), # side-facing \n", + " 5).to(DEVICE).float()" + ] + }, + { + "cell_type": "markdown", + "id": "1cbf9ef8-619a-4052-a138-a88c0f0f8b0b", + "metadata": {}, + "source": [ + "Let's bring that faster DDIM sampler from the diffusion course." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9c1a945d", + "metadata": {}, + "outputs": [], + "source": [ + "sample_ddim_context = setup_ddim(config.beta1, \n", + " config.beta2, \n", + " config.timesteps, \n", + " DEVICE)" + ] + }, + { + "cell_type": "markdown", + "id": "90b838be-8fa1-4c12-9c4f-e40dfacc08e1", + "metadata": {}, + "source": [ + "### Sampling:\n", + "let's compute ddpm samples as before" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "89e24210-4885-4559-92e1-db10566ef5ea", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/500 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.15.4" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /Users/tcapelle/work/edu/dlai/wandb/run-20230719_144552-50ekio0x" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run fresh-plasma-6 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/deeplearning-ai-temp/dlai_sprite_diffusion" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/deeplearning-ai-temp/dlai_sprite_diffusion/runs/50ekio0x" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Waiting for W&B process to finish... (success)." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run fresh-plasma-6 at: https://wandb.ai/deeplearning-ai-temp/dlai_sprite_diffusion/runs/50ekio0x
Synced 6 W&B file(s), 1 media file(s), 94 artifact file(s) and 1 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20230719_144552-50ekio0x/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with wandb.init(project=\"dlai_sprite_diffusion\", \n", + " job_type=\"samplers_battle\", \n", + " config=config):\n", + " \n", + " wandb.log({\"samplers_table\":table})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7df56d25", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/dlai/03_llm_eval.ipynb b/dlai/03_llm_eval.ipynb new file mode 100644 index 00000000..567bb33a --- /dev/null +++ b/dlai/03_llm_eval.ipynb @@ -0,0 +1,577 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "98d52240-af93-4c87-a11e-309b23bdae9c", + "metadata": {}, + "outputs": [], + "source": [ + "# Install wandb-addons, this will be added to wandb soon\n", + "# !git clone https://github.com/soumik12345/wandb-addons.git\n", + "# !pip install ./wandb-addons[prompts] openai wandb -qqq" + ] + }, + { + "cell_type": "markdown", + "id": "53c0d4d6-3d2b-45e5-90fa-ba7953496ec2", + "metadata": {}, + "source": [ + "# LLM Evaluation and Tracing with W&B\n", + "\n", + "## 1. Using Tables for Evaluation\n", + "\n", + "In this section, we will call OpenAI LLM to generate names of our game assets. 
We will use W&B Tables to evaluate the generations. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6512739b-fe35-4901-acb3-05df46b5ed9c", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "import time\n", + "import datetime\n", + "\n", + "import openai\n", + "\n", + "from tenacity import (\n", + " retry,\n", + " stop_after_attempt,\n", + " wait_random_exponential, # for exponential backoff\n", + ") \n", + "import wandb\n", + "from wandb_addons.prompts import Trace" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83639bac-5860-4db1-9867-7c89f3ca25a6", + "metadata": {}, + "outputs": [], + "source": [ + "PROJECT = \"dlai-llm\"\n", + "MODEL_NAME = \"gpt-3.5-turbo\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb575380", + "metadata": {}, + "outputs": [], + "source": [ + "# wandb.login() # uncomment if you want to login to wandb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c304c2b-dcd8-463c-aba4-aa47094dc16b", + "metadata": {}, + "outputs": [], + "source": [ + "run = wandb.init(project=PROJECT, job_type=\"generation\", anonymous=\"allow\")" + ] + }, + { + "cell_type": "markdown", + "id": "4e7bcf11", + "metadata": {}, + "source": [ + "### Simple generations\n", + "Let's start by generating names for our game assets using OpenAI `ChatCompletion`, and saving the resulting generations in W&B Tables. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2ab394b-295b-4cfa-aade-aa274003a56a", + "metadata": {}, + "outputs": [], + "source": [ + "@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))\n", + "def completion_with_backoff(**kwargs):\n", + " return openai.ChatCompletion.create(**kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "736fe64f-5cca-4316-8842-588b948193de", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_and_print(system_prompt, user_prompt, table, n=5):\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": user_prompt},\n", + " ]\n", + " start_time = time.time()\n", + " responses = completion_with_backoff(\n", + " model=MODEL_NAME,\n", + " messages=messages,\n", + " n = n,\n", + " )\n", + " elapsed_time = time.time() - start_time\n", + " for response in responses.choices:\n", + " generation = response.message.content\n", + " print(generation)\n", + " table.add_data(system_prompt,\n", + " user_prompt,\n", + " [response.message.content for response in responses.choices],\n", + " elapsed_time,\n", + " datetime.datetime.fromtimestamp(responses.created),\n", + " responses.model,\n", + " responses.usage.prompt_tokens,\n", + " responses.usage.completion_tokens,\n", + " responses.usage.total_tokens\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "690e6e0a-193b-41c8-86c4-526f8061dd94", + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt = \"\"\"You are a creative copywriter.\n", + "You're given a category of game asset, and your goal is to design a name of that asset.\n", + "The game is set in a fantasy world where everyone laughs and respects each other, \n", + "while celebrating diversity.\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "395880fa", + "metadata": {}, + "outputs": [], + "source": [ + "# Define W&B 
Table to store generations\n", + "columns = [\"system_prompt\", \"user_prompt\", \"generations\", \"elapsed_time\", \"timestamp\",\\\n", + " \"model\", \"prompt_tokens\", \"completion_tokens\", \"total_tokens\"]\n", + "table = wandb.Table(columns=columns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6fb07587", + "metadata": {}, + "outputs": [], + "source": [ + "user_prompt = \"hero\"\n", + "generate_and_print(system_prompt, user_prompt, table)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8343121b-2d47-47d1-b343-ec2393b8f02f", + "metadata": {}, + "outputs": [], + "source": [ + "user_prompt = \"jewel\"\n", + "generate_and_print(system_prompt, user_prompt, table)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3266487e-150b-4dd8-9555-94e94a66aac1", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.log({\"simple_generations\": table})\n", + "run.finish()" + ] + }, + { + "cell_type": "markdown", + "id": "16d6d513-389d-4c67-a942-a922bce6ff1a", + "metadata": {}, + "source": [ + "## 2. Using Tracer to log more complex chains\n", + "\n", + "How can we get more creative outputs? Let's design an LLM chain that will first randomly pick a fantasy world, and then generate character names. We will demonstrate how to use Tracer in such scenario. We will log the inputs and outputs, start and end times, whether the OpenAI call was successful, the token usage, and additional metadata." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c9fd404-51fd-44cf-b41e-b81dc589a4af", + "metadata": {}, + "outputs": [], + "source": [ + "worlds = [\n", + " \"a mystic medieval island inhabited by intelligent and funny frogs\",\n", + " \"a modern castle sitting on top of a volcano in a faraway galaxy\",\n", + " \"a digital world inhabited by friendly machine learning engineers\"\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0db1e20a-87a8-4386-9a8d-727db9569cd7", + "metadata": {}, + "outputs": [], + "source": [ + "# define your conifg\n", + "model_name = \"gpt-3.5-turbo\"\n", + "temperature = 0.7\n", + "system_message = \"\"\"You are a creative copywriter. \n", + "You're given a category of game asset and a fantasy world.\n", + "Your goal is to design a name of that asset.\n", + "Provide the resulting name only, no additional description.\n", + "Single name, max 3 words output, remember!\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a86f95e-ed0d-4989-8c1d-5b88cdac7999", + "metadata": {}, + "outputs": [], + "source": [ + "def run_creative_chain(query):\n", + " # part 1 - a chain is started...\n", + " start_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n", + "\n", + " root_span = Trace(\n", + " name=\"MyCreativeChain\",\n", + " kind=\"chain\",\n", + " start_time_ms=start_time_ms,\n", + " metadata={\"user\": \"student_1\"},\n", + " model_dict={\"_kind\": \"CreativeChain\"}\n", + " )\n", + "\n", + " # part 2 - your chain picks a fantasy world\n", + " time.sleep(3)\n", + " world = random.choice(worlds)\n", + " expanded_prompt = f'Game asset category: {query}; fantasy world description: {world}'\n", + " tool_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n", + "\n", + " # create a Tool span \n", + " tool_span = Trace(\n", + " name=\"WorldPicker\",\n", + " kind=\"tool\",\n", + " status_code=\"success\",\n", + " start_time_ms=start_time_ms,\n", + " 
end_time_ms=tool_end_time_ms,\n", + " inputs={\"input\": query},\n", + " outputs={\"result\": expanded_prompt},\n", + " model_dict={\"_kind\": \"tool\", \"num_worlds\": len(worlds)}\n", + " )\n", + "\n", + " # add the TOOL span as a child of the root\n", + " root_span.add_child(tool_span)\n", + "\n", + " # part 3 - the LLMChain calls an OpenAI LLM...\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_message},\n", + " {\"role\": \"user\", \"content\": expanded_prompt}\n", + " ]\n", + "\n", + " response = openai.ChatCompletion.create(model=model_name,\n", + " messages=messages,\n", + " max_tokens=12,\n", + " temperature=temperature) \n", + "\n", + " llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n", + " response_text = response[\"choices\"][0][\"message\"][\"content\"]\n", + " token_usage = response[\"usage\"].to_dict()\n", + "\n", + " llm_span = Trace(\n", + " name=\"OpenAI\",\n", + " kind=\"llm\",\n", + " status_code=\"success\",\n", + " metadata={\"temperature\":temperature,\n", + " \"token_usage\": token_usage, \n", + " \"model_name\":model_name},\n", + " start_time_ms=tool_end_time_ms,\n", + " end_time_ms=llm_end_time_ms,\n", + " inputs={\"system_prompt\":system_message, \"query\":expanded_prompt},\n", + " outputs={\"response\": response_text},\n", + " model_dict={\"_kind\": \"Openai\", \"engine\": response[\"model\"], \"model\": response[\"object\"]}\n", + " )\n", + "\n", + " # add the LLM span as a child of the Chain span...\n", + " root_span.add_child(llm_span)\n", + "\n", + " # update the end time of the Chain span\n", + " root_span.add_inputs_and_outputs(\n", + " inputs={\"query\":query},\n", + " outputs={\"response\": response_text})\n", + "\n", + " # update the Chain span's end time\n", + " root_span._span.end_time_ms = llm_end_time_ms\n", + "\n", + " # part 4 - log all spans to W&B by logging the root span\n", + " root_span.log(name=\"creative_trace\")\n", + " print(f\"Result: {response_text}\")\n" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "f8500843-6d4b-4fc6-93b9-4cadf5813e4a", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's start a new wandb run\n", + "wandb.init(project=PROJECT, job_type=\"generation\", anonymous=\"allow\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7409a004", + "metadata": {}, + "outputs": [], + "source": [ + "run_creative_chain(\"hero\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "538d7bf3-4ae1-4b57-8a96-a34ea0614ec3", + "metadata": {}, + "outputs": [], + "source": [ + "run_creative_chain(\"jewel\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45de1fb0-3630-4673-8ac0-0dffe0a52071", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "id": "1ccc075f-32bf-4451-b7ad-ab2a49cc86b6", + "metadata": {}, + "source": [ + "## Langchain agent\n", + "\n", + "In the third scenario, we'll introduce an agent that will use tools such as WorldPicker and NameValidator to come up with the ultimate name. We will also use Langchain here and demonstrate its W&B integration." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "726e0a6a-699b-434d-8c51-7542b4f981dd", + "metadata": {}, + "outputs": [], + "source": [ + "# Import things that are needed generically\n", + "from langchain.agents import AgentType, initialize_agent\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.tools import BaseTool\n", + "\n", + "from typing import Optional\n", + "\n", + "from langchain.callbacks.manager import (\n", + " AsyncCallbackManagerForToolRun,\n", + " CallbackManagerForToolRun,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5738431a-e281-4abf-9837-44fec6811ff4", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.init(project=PROJECT, job_type=\"generation\", anonymous=\"allow\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac08f78b-0962-4d84-b39a-21ee5e5d606b", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"LANGCHAIN_WANDB_TRACING\"] = \"true\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "539bc081-d1e3-4376-a817-23aa1d7ab2b3", + "metadata": {}, + "outputs": [], + "source": [ + "class WorldPickerTool(BaseTool):\n", + " name = \"pick_world\"\n", + " description = \"pick a virtual game world for your character or item naming\"\n", + " worlds = [\n", + " \"a mystic medieval island inhabited by intelligent and funny frogs\",\n", + " \"a modern anthill featuring a cyber-ant queen and her cyber-ant-workers\",\n", + " \"a digital world inhabited by friendly machine learning engineers\"\n", + " ]\n", + "\n", + " def _run(\n", + " self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None\n", + " ) -> str:\n", + " \"\"\"Use the tool.\"\"\"\n", + " time.sleep(1)\n", + " return random.choice(self.worlds)\n", + "\n", + " async def _arun(\n", + " self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None\n", + " ) -> str:\n", + " \"\"\"Use the tool 
asynchronously.\"\"\"\n", + " raise NotImplementedError(\"pick_world does not support async\")\n", + " \n", + "class NameValidatorTool(BaseTool):\n", + " name = \"validate_name\"\n", + " description = \"validate if the name is properly generated\"\n", + "\n", + " def _run(\n", + " self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None\n", + " ) -> str:\n", + " \"\"\"Use the tool.\"\"\"\n", + " time.sleep(1)\n", + " if len(query) < 20:\n", + " return f\"This is a correct name: {query}\"\n", + " else:\n", + " return f\"This name is too long. It should be shorter than 20 characters.\"\n", + "\n", + " async def _arun(\n", + " self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None\n", + " ) -> str:\n", + " \"\"\"Use the tool asynchronously.\"\"\"\n", + " raise NotImplementedError(\"validate_name does not support async\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c515ee33-1d6f-47e7-aceb-845c363eee29", + "metadata": {}, + "outputs": [], + "source": [ + "llm = ChatOpenAI(temperature=0.7)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "989407f4-0e10-4446-90d1-992c3b4c9483", + "metadata": {}, + "outputs": [], + "source": [ + "tools = [WorldPickerTool(), NameValidatorTool()]\n", + "agent = initialize_agent(\n", + " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d4bd42d-9c95-4e02-8679-99ca43d0aa71", + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(\n", + " \"Find a virtual game world for me and imagine the name of a hero in that world\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dbb5ea87-a9b9-462f-80bf-b56d681dec8c", + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(\n", + " \"Find a virtual game world for me and imagine the name of a jewel in that world\"\n", + ")" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "4d101fcd-cd7d-4ede-ad95-412c1cd72e46", + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(\n", + " \"Find a virtual game world for me and imagine the name of food in that world\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "486c688c-2ca2-4fe5-8f22-afd194b3e34d", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.finish()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18b6af79-9de7-4bfd-b8ea-6b4f2b405d0a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "643f6295", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93462bd0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/dlai/04_train_llm.ipynb b/dlai/04_train_llm.ipynb new file mode 100644 index 00000000..e08cdd66 --- /dev/null +++ b/dlai/04_train_llm.ipynb @@ -0,0 +1,523 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1dfae479-9399-492d-acaa-d9751615ee86", + "metadata": { + "tags": [] + }, + "source": [ + "# Finetuning a language model\n", + "Let's see how to finetune a language model to generate character backstories using HuggingFace Trainer with wandb integration. We'll use a tiny language model (`TinyStories-33M`) due to resource constraints, but the lessons you learn here should be applicable to large models too!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a1f0e67f", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "from datasets import load_dataset\n", + "from transformers import AutoModelForCausalLM\n", + "from transformers import Trainer, TrainingArguments\n", + "\n", + "import wandb" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f79c25e3-5f18-4457-84e1-ed2c0d262222", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcapecape\u001b[0m (\u001b[33mdeeplearning-ai-temp\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "wandb.login(anonymous=\"allow\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2286ae41-213d-480d-a4ba-8c4e2e1c4771", + "metadata": {}, + "outputs": [], + "source": [ + "model_checkpoint = \"roneneldan/TinyStories-33M\"" + ] + }, + { + "cell_type": "markdown", + "id": "3fd80268-c4a1-4e1a-aed3-cd5c3ab4d48f", + "metadata": {}, + "source": [ + "### Preparing data\n", + "\n", + "We'll start by loading a dataset containing Dungeons and Dragons character biographies from Huggingface. 
" + ] + }, + { + "cell_type": "markdown", + "id": "c9288a8e-b19b-4bd2-a72c-7dda03632282", + "metadata": {}, + "source": [ + "> You can expect to get some warning here, this is ok" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a7535b8b-d220-44e8-a56c-97e250c36596", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset parquet (/Users/tcapelle/.cache/huggingface/datasets/MohamedRashad___parquet/MohamedRashad--characters_backstories-6398ba4bb1a6e421/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b1d8315d3ae54248840650543b19d386", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>\n" + ] + } + ], + "source": [ + "# Let's check out one prepared example\n", + "print(tokenizer.decode(tokenized_datasets[\"train\"][900]['input_ids']))" + ] + }, + { + "cell_type": "markdown", + "id": "2e8d6b17-a63d-41f1-92cf-416064b52156", + "metadata": {}, + "source": [ + "### Training\n", + "Let's finetune a pretrained language model on our dataset using HF Transformers and their wandb integration. " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b4f131eb-979e-40f6-9e28-19756beaa8e4", + "metadata": {}, + "outputs": [], + "source": [ + "# We will train a causal (autoregressive) language model from a pretrained checkpoint\n", + "model = AutoModelForCausalLM.from_pretrained(model_checkpoint);" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "7345ab23-8d12-4d4c-a39d-bb2202bff218", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "wandb version 0.15.5 is available! 
To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.15.4" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /Users/tcapelle/work/edu/dlai/wandb/run-20230718_172033-c2lx2628" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run zany-eon-5 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/deeplearning-ai-temp/dlai-lm-tuning" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/deeplearning-ai-temp/dlai-lm-tuning/runs/c2lx2628" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Start a new wandb run\n", + "run = wandb.init(project='dlai-lm-tuning', job_type=\"training\", anonymous=\"allow\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d74ee155-3c30-4ef2-9c4d-fd8ee222c50c", + "metadata": {}, + "outputs": [], + "source": [ + "# Define training arguments\n", + "model_name = model_checkpoint.split(\"/\")[-1]\n", + "training_args = TrainingArguments(\n", + " f\"{model_name}-finetuned-characters-backstories\",\n", + " report_to=\"wandb\", # we need one line to track experiments in wandb\n", + " num_train_epochs=1,\n", + " logging_steps=1,\n", + " evaluation_strategy = \"epoch\",\n", + " learning_rate=1e-4,\n", + " weight_decay=0.01,\n", + " no_cuda=True, # force cpu use, will be renamed `use_cpu`\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "af62105f-a478-436f-88a2-5c1d78b9d20a", + "metadata": {}, + "outputs": [], + "source": [ + "# We'll use HF Trainer\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=tokenized_datasets[\"train\"],\n", + " eval_dataset=tokenized_datasets[\"test\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "816f4c88-bcf2-474a-afbc-b646f89df86c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cpu')" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"trainer.accelerator.device" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "01958a56-c22a-4a27-bc71-41c59fc97f05", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [233/233 02:49, Epoch 1/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation Loss
15.3216003.384721

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=233, training_loss=3.7527249718940308, metrics={'train_runtime': 170.973, 'train_samples_per_second': 10.861, 'train_steps_per_second': 1.363, 'total_flos': 40423258718208.0, 'train_loss': 3.7527249718940308, 'epoch': 1.0})" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Let's train!\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "7911e43f-f4ce-4855-9f68-662438af8d24", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. 
Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + ] + } + ], + "source": [ + "prefix = \"Generate Backstory based on following information Character Name: \"\n", + "\n", + "prompts = [\n", + " \"Frogger Character Race: Aarakocra Character Class: Ranger Output: \",\n", + " \"Smarty Character Race: Aasimar Character Class: Cleric Output: \",\n", + " \"Volcano Character Race: Android Character Class: Paladin Output: \",\n", + "]\n", + "\n", + "table = wandb.Table(columns=[\"prompt\", \"generation\"])\n", + "\n", + "for prompt in prompts:\n", + " input_ids = tokenizer.encode(prefix + prompt, return_tensors=\"pt\")\n", + " output = model.generate(input_ids, do_sample=True, max_new_tokens=50, top_p=0.3)\n", + " output_text = tokenizer.decode(output[0], skip_special_tokens=True)\n", + " table.add_data(prefix + prompt, output_text)\n", + " \n", + "wandb.log({'tiny_generations': table})" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "3083c6a3-fdb8-44ab-a028-c0a222a2fdef", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.finish()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/dlai/README.md b/dlai/README.md new file mode 100644 index 00000000..13db46a9 --- /dev/null +++ b/dlai/README.md @@ -0,0 +1,13 @@ +[![](https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-gradient.svg)](https://wandb.ai/capecape/dlai_diffusion) + +# DLAI with W&B 😎 + +We instrument various notebooks from the generative AI course with W&B to track 
metrics, hyperparameters, and artifacts.
+
+- [00_intro](00_intro.ipynb) In this notebook we learn about using Weights and Biases! We train a simple classifier on the Sprites dataset and log the results to W&B.
+- [01_diffusion_training](01_diffusion_training.ipynb) In this notebook we train a diffusion model to generate images from the Sprites dataset. We log the training metrics to W&B. We sample from the model and log the images to W&B.
+- [02_diffusion_sampling](02_diffusion_sampling.ipynb) In this notebook we sample from the trained model and log the images to W&B. We compare different sampling methods and log the results.
+- [03 LLM evaluation and debugging](03_llm_eval.ipynb) In this notebook we generate character names using LLMs and use W&B autologgers and Tracer to evaluate and debug our generations.
+- [04 WIP](04_train_llm.ipynb) Finetuning an LLM on a character-based dataset to create hero descriptions!
+
+The W&B dashboard: https://wandb.ai/deeplearning-ai-temp
\ No newline at end of file
diff --git a/dlai/data/sprite_labels_nc_1788_16x16.npy b/dlai/data/sprite_labels_nc_1788_16x16.npy
new file mode 100644
index 00000000..b5eec1e2
--- /dev/null
+++ b/dlai/data/sprite_labels_nc_1788_16x16.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b71222bd58b57cd99b1b92d830393e86ce215e0f69602f2c82aad1522f030ed7
+size 3576128
diff --git a/dlai/data/sprites_1788_16x16.npy b/dlai/data/sprites_1788_16x16.npy
new file mode 100644
index 00000000..1055e7de
--- /dev/null
+++ b/dlai/data/sprites_1788_16x16.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61cf3b7e3184f57f2bc2bf5e75fbcf08ba379241f58966c62a9716ef581b2916
+size 68659328
diff --git a/dlai/data/weights/context_model.pth b/dlai/data/weights/context_model.pth
new file mode 100644
index 00000000..451319cf
--- /dev/null
+++ b/dlai/data/weights/context_model.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid 
sha256:51535239b6f3e953db8ff9631278c3e6b133a5a500780bda5092db620ca8f570 +size 5989463 diff --git a/dlai/requirements.txt b/dlai/requirements.txt new file mode 100644 index 00000000..96f321ce --- /dev/null +++ b/dlai/requirements.txt @@ -0,0 +1,198 @@ +accelerate==0.21.0 +aiohttp==3.8.4 +aiosignal==1.3.1 +anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1688651106312/work/dist +appdirs==1.4.4 +argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work +argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1649500321618/work +asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work +async-lru @ file:///home/conda/feedstock_root/build_artifacts/async-lru_1688997201545/work +async-timeout==4.0.2 +attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1683424013410/work +Babel @ file:///home/conda/feedstock_root/build_artifacts/babel_1677767029043/work +backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work +backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1687772187254/work +beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1680888073205/work +bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work +Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1648883617327/work +build==0.10.0 +CacheControl==0.12.14 +certifi==2023.5.7 +cffi @ file:///tmp/abs_98z5h56wf8/croots/recipe/cffi_1659598650955/work +charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1688813409104/work +cleo==2.0.1 +click==8.1.5 +cmake==3.26.4 +contourpy==1.1.0 +crashtest==0.4.1 +cryptography==41.0.2 +cycler==0.11.0 +dataclasses-json==0.5.9 +datasets==2.13.1 +decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work +defusedxml 
@ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work +dill==0.3.6 +distlib==0.3.6 +docker-pycreds==0.4.0 +dulwich==0.21.5 +entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work +exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1688381075899/work +executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work +fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1684761244589/work/dist +filelock==3.12.2 +flit_core @ file:///home/conda/feedstock_root/build_artifacts/flit-core_1684084314667/work/source/flit_core +fonttools==4.41.0 +frozenlist==1.4.0 +fsspec==2023.6.0 +gitdb==4.0.10 +GitPython==3.1.32 +greenlet==2.0.2 +html5lib==1.1 +huggingface-hub==0.16.4 +idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work +importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1688754491823/work +importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1689017639396/work +installer==0.7.0 +ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1620912942381/work/dist/ipykernel-5.5.5-py3-none-any.whl +ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1685727741709/work +ipython-genutils==0.2.0 +ipywidgets==8.0.7 +jaraco.classes==3.3.0 +jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work +jeepney==0.8.0 +Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work +json5 @ file:///home/conda/feedstock_root/build_artifacts/json5_1688248289187/work +jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1669810440410/work +jupyter-events @ file:///home/conda/feedstock_root/build_artifacts/jupyter_events_1673559782596/work +jupyter-lsp @ 
file:///home/conda/feedstock_root/build_artifacts/jupyter-lsp-meta_1685453365113/work/jupyter-lsp +jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1687700988094/work +jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1686775603087/work +jupyter_server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1687869799272/work +jupyter_server_terminals @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_terminals_1673491454549/work +jupyterlab @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_1689253413907/work +jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work +jupyterlab-widgets==3.0.8 +jupyterlab_server @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_server_1686659921555/work +keyring==23.13.1 +kiwisolver==1.4.4 +langchain==0.0.232 +langsmith==0.0.5 +lit==16.0.6 +lockfile==0.12.2 +markdown-it-py==3.0.0 +MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work +marshmallow==3.19.0 +marshmallow-enum==1.5.1 +matplotlib==3.7.2 +matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work +mdurl==0.1.2 +mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1686313613819/work/dist +more-itertools==9.1.0 +mpmath==1.3.0 +msgpack==1.0.5 +multidict==6.0.4 +multiprocess==0.70.14 +mypy-extensions==1.0.0 +nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1684790896106/work +nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1687202153002/work +nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1688996247388/work +networkx==3.1 +notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1682360583588/work +numexpr==2.8.4 +numpy==1.25.1 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-cupti-cu11==11.7.101 +nvidia-cuda-nvrtc-cu11==11.7.99 
+nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.2.10.91 +nvidia-cusolver-cu11==11.4.0.1 +nvidia-cusparse-cu11==11.7.4.91 +nvidia-nccl-cu11==2.14.3 +nvidia-nvtx-cu11==11.7.91 +openai==0.27.8 +openapi-schema-pydantic==1.2.4 +overrides @ file:///home/conda/feedstock_root/build_artifacts/overrides_1666057828264/work +packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1681337016113/work +pandas==2.0.3 +pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work +parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work +pathtools==0.1.2 +pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work +pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work +Pillow==10.0.0 +pkginfo==1.9.6 +pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1633981968097/work +platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1688739404342/work +poetry==1.5.1 +poetry-core==1.6.1 +poetry-plugin-export==1.4.0 +prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1689032443210/work +prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1688565951714/work +protobuf==4.23.4 +psutil==5.9.5 +ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl +pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work +pyarrow==12.0.1 +pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work +pydantic==1.10.11 +Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1681904169130/work +pyparsing==3.0.9 +pyproject_hooks==1.0.0 +pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1636110951836/work +PySocks @ 
file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work +python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work +python-json-logger @ file:///home/conda/feedstock_root/build_artifacts/python-json-logger_1677079630776/work +pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1680088766131/work +PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1648757097602/work +pyzmq @ file:///croot/pyzmq_1686601365461/work +rapidfuzz==2.15.1 +regex==2023.6.3 +requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1684774241324/work +requests-toolbelt==1.0.0 +rfc3339-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3339-validator_1638811747357/work +rfc3986-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3986-validator_1598024191506/work +rich==13.4.2 +safetensors==0.3.1 +SecretStorage==3.3.3 +Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1682601222253/work +sentry-sdk==1.28.1 +setproctitle==1.3.2 +shellingham==1.5.0.post1 +six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work +smmap==5.0.0 +sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1662051266223/work +soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work +SQLAlchemy==2.0.18 +stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work +sympy==1.12 +tenacity==8.2.2 +terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1670253674810/work +tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work +tokenizers==0.13.3 +tomli @ file:///home/conda/feedstock_root/build_artifacts/tomli_1644342247877/work +tomlkit==0.11.8 +torch==2.0.1 +torchvision==0.15.2 +tornado @ file:///opt/conda/conda-bld/tornado_1662061693373/work +tqdm==4.65.0 +traitlets @ 
file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work +transformers==4.30.2 +triton==2.0.0 +trove-classifiers==2023.7.6 +typing-inspect==0.9.0 +typing-utils @ file:///home/conda/feedstock_root/build_artifacts/typing_utils_1622899189314/work +typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1688315532570/work +tzdata==2023.3 +urllib3==1.26.16 +virtualenv==20.23.1 +wandb==0.15.5 +wandb-addons @ file:///home/darek/projects/edu/dlai/wandb-addons +wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work +webencodings==0.5.1 +websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1687789148259/work +widgetsnbextension==4.0.8 +xxhash==3.2.0 +yarl==1.9.2 +zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1689027407711/work diff --git a/dlai/utilities.py b/dlai/utilities.py new file mode 100644 index 00000000..a11ea3b8 --- /dev/null +++ b/dlai/utilities.py @@ -0,0 +1,489 @@ +import os, sys +import random +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms as transforms +from tqdm.auto import tqdm +from matplotlib.animation import FuncAnimation, PillowWriter +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision.utils import make_grid, save_image + + +def get_device(): + "Pick GPU if cuda is available, mps if Mac, else CPU" + if torch.cuda.is_available(): + return torch.device("cuda") + elif sys.platform == "darwin" and torch.backends.mps.is_available(): + return torch.device("mps") + else: + return torch.device("cpu") + +def _fig_bounds(x): + r = x//32 + return min(5, max(1,r)) + +def show_image(im, ax=None, figsize=None, title=None, **kwargs): + "Show a PIL or PyTorch image on `ax`." 
+ cmap=None + # Handle pytorch axis order + if isinstance(im, torch.Tensor): + im = im.data.cpu() + if im.shape[0]<5: im=im.permute(1,2,0) + elif not isinstance(im, np.ndarray): + im=np.array(im) + # Handle 1-channel images + if im.shape[-1]==1: + cmap = "gray" + im=im[...,0] + + if figsize is None: + figsize = (_fig_bounds(im.shape[0]), _fig_bounds(im.shape[1])) + if ax is None: + _,ax = plt.subplots(figsize=figsize) + ax.imshow(im, cmap=cmap, **kwargs) + if title is not None: + ax.set_title(title) + ax.axis('off') + return ax + +class ContextUnet(nn.Module): + def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28): # cfeat - context features + super(ContextUnet, self).__init__() + + # number of input channels, number of intermediate feature maps and number of classes + self.in_channels = in_channels + self.n_feat = n_feat + self.n_cfeat = n_cfeat + self.h = height #assume h == w. must be divisible by 4, so 28,24,20,16... + + # Initialize the initial convolutional layer + self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True) + + # Initialize the down-sampling path of the U-Net with two levels + self.down1 = UnetDown(n_feat, n_feat) # down1 #[10, 256, 8, 8] + self.down2 = UnetDown(n_feat, 2 * n_feat) # down2 #[10, 256, 4, 4] + + # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU()) + self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU()) + + # Embed the timestep and context labels with a one-layer fully connected neural network + self.timeembed1 = EmbedFC(1, 2*n_feat) + self.timeembed2 = EmbedFC(1, 1*n_feat) + self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat) + self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat) + + # Initialize the up-sampling path of the U-Net with three levels + self.up0 = nn.Sequential( + nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4), # up-sample + nn.GroupNorm(8, 2 * n_feat), # normalize + nn.ReLU(), + ) + self.up1 = UnetUp(4 * n_feat, n_feat) + self.up2 = UnetUp(2 * n_feat, n_feat) + + 
# Initialize the final convolutional layers to map to the same number of channels as the input image + self.out = nn.Sequential( + nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), # reduce number of feature maps #in_channels, out_channels, kernel_size, stride=1, padding=0 + nn.GroupNorm(8, n_feat), # normalize + nn.ReLU(), + nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input + ) + + def forward(self, x, t, c=None): + """ + x : (batch, n_feat, h, w) : input image + t : (batch, n_cfeat) : time step + c : (batch, n_classes) : context label + """ + # x is the input image, c is the context label, t is the timestep, context_mask says which samples to block the context on + + # pass the input image through the initial convolutional layer + x = self.init_conv(x) + # pass the result through the down-sampling path + down1 = self.down1(x) #[10, 256, 8, 8] + down2 = self.down2(down1) #[10, 256, 4, 4] + + # convert the feature maps to a vector and apply an activation + hiddenvec = self.to_vec(down2) + + # mask out context if context_mask == 1 + if c is None: + c = torch.zeros(x.shape[0], self.n_cfeat).to(x) + + # embed context and timestep + cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1) # (batch, 2*n_feat, 1,1) + temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1) + cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1) + temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1) + #print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. 
temb2 {temb2.shape}") + + + up1 = self.up0(hiddenvec) + up2 = self.up1(cemb1*up1 + temb1, down2) # add and multiply embeddings + up3 = self.up2(cemb2*up2 + temb2, down1) + out = self.out(torch.cat((up3, x), 1)) + return out + +class ResidualConvBlock(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, is_res: bool = False + ) -> None: + super().__init__() + + # Check if input and output channels are the same for the residual connection + self.same_channels = in_channels == out_channels + + # Flag for whether or not to use residual connection + self.is_res = is_res + + # First convolutional layer + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1 + nn.BatchNorm2d(out_channels), # Batch normalization + nn.GELU(), # GELU activation function + ) + + # Second convolutional layer + self.conv2 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1 + nn.BatchNorm2d(out_channels), # Batch normalization + nn.GELU(), # GELU activation function + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # If using residual connection + if self.is_res: + # Apply first convolutional layer + x1 = self.conv1(x) + + # Apply second convolutional layer + x2 = self.conv2(x1) + + # If input and output channels are the same, add residual connection directly + if self.same_channels: + out = x + x2 + else: + # If not, apply a 1x1 convolutional layer to match dimensions before adding residual connection + shortcut = nn.Conv2d(x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0).to(x.device) + out = shortcut(x) + x2 + #print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}") + + # Normalize output tensor + return out / 1.414 + + # If not using residual connection, return output of second convolutional layer + else: + x1 = self.conv1(x) + x2 = self.conv2(x1) + return x2 + + # Method to get the number of output 
channels for this block + def get_out_channels(self): + return self.conv2[0].out_channels + + # Method to set the number of output channels for this block + def set_out_channels(self, out_channels): + self.conv1[0].out_channels = out_channels + self.conv2[0].in_channels = out_channels + self.conv2[0].out_channels = out_channels + + + +class UnetUp(nn.Module): + def __init__(self, in_channels, out_channels): + super(UnetUp, self).__init__() + + # Create a list of layers for the upsampling block + # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers + layers = [ + nn.ConvTranspose2d(in_channels, out_channels, 2, 2), + ResidualConvBlock(out_channels, out_channels), + ResidualConvBlock(out_channels, out_channels), + ] + + # Use the layers to create a sequential model + self.model = nn.Sequential(*layers) + + def forward(self, x, skip): + # Concatenate the input tensor x with the skip connection tensor along the channel dimension + x = torch.cat((x, skip), 1) + + # Pass the concatenated tensor through the sequential model and return the output + x = self.model(x) + return x + + +class UnetDown(nn.Module): + def __init__(self, in_channels, out_channels): + super(UnetDown, self).__init__() + + # Create a list of layers for the downsampling block + # Each block consists of two ResidualConvBlock layers, followed by a MaxPool2d layer for downsampling + layers = [ResidualConvBlock(in_channels, out_channels), ResidualConvBlock(out_channels, out_channels), nn.MaxPool2d(2)] + + # Use the layers to create a sequential model + self.model = nn.Sequential(*layers) + + def forward(self, x): + # Pass the input through the sequential model and return the output + return self.model(x) + +class EmbedFC(nn.Module): + def __init__(self, input_dim, emb_dim): + super(EmbedFC, self).__init__() + ''' + This class defines a generic one layer feed-forward neural network for embedding input data of + dimensionality input_dim to an embedding 
space of dimensionality emb_dim. + ''' + self.input_dim = input_dim + + # define the layers for the network + layers = [ + nn.Linear(input_dim, emb_dim), + nn.GELU(), + nn.Linear(emb_dim, emb_dim), + ] + + # create a PyTorch sequential model consisting of the defined layers + self.model = nn.Sequential(*layers) + + def forward(self, x): + # flatten the input tensor + x = x.view(-1, self.input_dim) + # apply the model layers to the flattened tensor + return self.model(x) + +def unorm(x): + # unity norm. results in range of [0,1] + # assume x (h,w,3) + xmax = x.max((0,1)) + xmin = x.min((0,1)) + return(x - xmin)/(xmax - xmin) + +def norm_all(store, n_t, n_s): + # runs unity norm on all timesteps of all samples + nstore = np.zeros_like(store) + for t in range(n_t): + for s in range(n_s): + nstore[t,s] = unorm(store[t,s]) + return nstore + +def norm_torch(x_all): + # runs unity norm on all timesteps of all samples + # input is (n_samples, 3,h,w), the torch image format + x = x_all.cpu().numpy() + xmax = x.max((2,3)) + xmin = x.min((2,3)) + xmax = np.expand_dims(xmax,(2,3)) + xmin = np.expand_dims(xmin,(2,3)) + nstore = (x - xmin)/(xmax - xmin) + return torch.from_numpy(nstore) + +def gen_tst_context(n_cfeat): + """ + Generate test context vectors + """ + vec = torch.tensor([ + [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0], # human, non-human, food, spell, side-facing + [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0], # human, non-human, food, spell, side-facing + [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0], # human, non-human, food, spell, side-facing + [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0], # human, non-human, food, spell, side-facing + [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0], # human, non-human, food, spell, side-facing + [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0]] # 
human, non-human, food, spell, side-facing + ) + return len(vec), vec + +def plot_grid(x,n_sample,n_rows,save_dir,w): + # x:(n_sample, 3, h, w) + ncols = n_sample//n_rows + grid = make_grid(norm_torch(x), nrow=ncols) # curiously, nrow is number of columns.. or number of items in the row. + save_image(grid, save_dir + f"run_image_w{w}.png") + print('saved image at ' + save_dir + f"run_image_w{w}.png") + return grid + +def plot_sample(x_gen_store,n_sample,nrows,save_dir, fn, w, save=False): + ncols = n_sample//nrows + sx_gen_store = np.moveaxis(x_gen_store,2,4) # change to Numpy image format (h,w,channels) vs (channels,h,w) + nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample) # unity norm to put in range [0,1] for np.imshow + + # create gif of images evolving over time, based on x_gen_store + fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,figsize=(ncols,nrows)) + def animate_diff(i, store): + print(f'gif animating frame {i} of {store.shape[0]}', end='\r') + plots = [] + for row in range(nrows): + for col in range(ncols): + axs[row, col].clear() + axs[row, col].set_xticks([]) + axs[row, col].set_yticks([]) + plots.append(axs[row, col].imshow(store[i,(row*ncols)+col])) + return plots + ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200, blit=False, repeat=True, frames=nsx_gen_store.shape[0]) + plt.close() + if save: + ani.save(save_dir + f"{fn}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5)) + print('saved gif at ' + save_dir + f"{fn}_w{w}.gif") + return ani + + +default_tfms = transforms.Compose([ + transforms.ToTensor(), # from [0,255] to range [0.0,1.0] + transforms.RandomHorizontalFlip(), # randomly flip and rotate + transforms.Normalize((0.5,), (0.5,)) # range [-1,1] +]) + +class CustomDataset(Dataset): + def __init__(self, sprites, slabels, transform=default_tfms, null_context=False, argmax=False): + self.sprites = sprites + if argmax: + self.slabels = np.argmax(slabels, axis=1) + else: + 
self.slabels = slabels + self.transform = transform + self.null_context = null_context + + @classmethod + def from_np(cls, + path, + sfilename="sprites_1788_16x16.npy", lfilename="sprite_labels_nc_1788_16x16.npy", transform=default_tfms, null_context=False, argmax=False): + sprites = np.load(Path(path)/sfilename) + slabels = np.load(Path(path)/lfilename) + return cls(sprites, slabels, transform, null_context, argmax) + + # Return the number of images in the dataset + def __len__(self): + return len(self.sprites) + + # Get the image and label at a given index + def __getitem__(self, idx): + # Return the image and label as a tuple + if self.transform: + image = self.transform(self.sprites[idx]) + if self.null_context: + label = torch.tensor(0).to(torch.int64) + else: + label = torch.tensor(self.slabels[idx]).to(torch.int64) + return (image, label) + + + def subset(self, slice_size=1000): + # return a subset of the dataset + indices = random.sample(range(len(self)), slice_size) + return CustomDataset(self.sprites[indices], self.slabels[indices], self.transform, self.null_context) + + def split(self, pct=0.2): + "split dataset into train and test" + train_size = int((1-pct)*len(self)) + test_size = len(self) - train_size + train_dataset, test_dataset = torch.utils.data.random_split(self, [train_size, test_size]) + return train_dataset, test_dataset + +def get_dataloaders(data_dir, batch_size, slice_size=None, valid_pct=0.2): + "Get train/val dataloaders for classification on sprites dataset" + dataset = CustomDataset.from_np(Path(data_dir), argmax=True) + if slice_size: + dataset = dataset.subset(slice_size) + + train_ds, valid_ds = dataset.split(valid_pct) + + train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=1) + valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, num_workers=1) + + return train_dl, valid_dl + + +## diffusion functions + +def setup_ddpm(beta1, beta2, timesteps, device): + # construct DDPM noise 
schedule and sampling functions + b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1 + a_t = 1 - b_t + ab_t = torch.cumsum(a_t.log(), dim=0).exp() + ab_t[0] = 1 + + # helper function: perturbs an image to a specified noise level + def perturb_input(x, t, noise): + return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t[t, None, None, None]) * noise + + # helper function; removes the predicted noise (but adds some noise back in to avoid collapse) + def _denoise_add_noise(x, t, pred_noise, z=None): + if z is None: + z = torch.randn_like(x) + noise = b_t.sqrt()[t] * z + mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt() + return mean + noise + + # sample with context using standard algorithm + # we make a change to the original algorithm to allow for context explicitely (the noises) + @torch.no_grad() + def sample_ddpm_context(nn_model, noises, context, save_rate=20): + # array to keep track of generated steps for plotting + intermediate = [] + pbar = tqdm(range(timesteps, 0, -1), leave=False) + for i in pbar: + pbar.set_description(f'sampling timestep {i:3d}') + + # reshape time tensor + t = torch.tensor([i / timesteps])[:, None, None, None].to(noises.device) + + # sample some random noise to inject back in. 
For i = 1, don't add back in noise + z = torch.randn_like(noises) if i > 1 else 0 + + eps = nn_model(noises, t, c=context) # predict noise e_(x_t,t, ctx) + noises = _denoise_add_noise(noises, i, eps, z) + if i % save_rate==0 or i==timesteps or i<8: + intermediate.append(noises.detach().cpu().numpy()) + + intermediate = np.stack(intermediate) + return noises.clip(-1, 1), intermediate + + return perturb_input, sample_ddpm_context + + +def setup_ddim(beta1, beta2, timesteps, device): + # define sampling function for DDIM + b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1 + a_t = 1 - b_t + ab_t = torch.cumsum(a_t.log(), dim=0).exp() + ab_t[0] = 1 + # removes the noise using ddim + def denoise_ddim(x, t, t_prev, pred_noise): + ab = ab_t[t] + ab_prev = ab_t[t_prev] + + x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise) + dir_xt = (1 - ab_prev).sqrt() * pred_noise + + return x0_pred + dir_xt + + # fast sampling algorithm with context + @torch.no_grad() + def sample_ddim_context(nn_model, noises, context, n=25): + # array to keep track of generated steps for plotting + intermediate = [] + step_size = timesteps // n + pbar=tqdm(range(timesteps, 0, -step_size), leave=False) + for i in pbar: + pbar.set_description(f'sampling timestep {i:3d}') + + # reshape time tensor + t = torch.tensor([i / timesteps])[:, None, None, None].to(device) + + eps = nn_model(noises, t, c=context) # predict noise e_(x_t,t) + noises = denoise_ddim(noises, i, i - step_size, eps) + intermediate.append(noises.detach().cpu().numpy()) + + intermediate = np.stack(intermediate) + return noises.clip(-1, 1), intermediate + + return sample_ddim_context + +def to_classes(ctx_vector): + classes = "hero,non-hero,food,spell,side-facing".split(",") + return [classes[i] for i in ctx_vector.argmax(dim=1)] \ No newline at end of file