diff --git a/requirements.txt b/requirements.txt index 139dbcf..44f42d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ diffusers==0.18.2 embedding_reader==1.5.1 fire==0.5.0 fsspec==2023.5.0 -library==0.0.0 lion_pytorch==0.0.6 matplotlib==3.7.1 multidict==6.0.4 @@ -28,4 +27,4 @@ lightning nudenet==3.0.8 invisible-watermark git+https://github.com/openai/CLIP.git - +omegaconf diff --git a/sd15_child_removal.ipynb b/sd15_child_removal.ipynb new file mode 100644 index 0000000..c0cca18 --- /dev/null +++ b/sd15_child_removal.ipynb @@ -0,0 +1,1372 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "_U-cYMtEVY2K" + }, + "source": [ + "## Create Configs" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "XBgbejdDV2iq" + }, + "source": [ + "### Step 1: Choose Base Model\n", + "\n", + "Examples of `pretrained_sd_model`:\n", + "- SD v1.4: `CompVis/stable-diffusion-v1-4`\n", + "- SD v1.5: `runwayml/stable-diffusion-v1-5`\n", + "- Dreamshaper v8: `stablediffusionapi/dreamshaper-v8`\n", + "- SD v2.1: `stabilityai/stable-diffusion-2-1-base`\n", + "- WD1.5 beta3: `Birchlabs/wd-1-5-beta3-unofficial`\n", + "- SDXL: `stabilityai/stable-diffusion-xl-base-1.0`\n", + "- Juggernaut v6: `RunDiffusion/Juggernaut-XL-v6`\n", + "- PonyXL: `stablediffusionapi/pony-diffusion-v6-xl`\n", + "\n", + "If base model is v2.x, set `is_v2_model` to `true`." 
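The notebook's conventions for these choices (Step 4 later notes that `resolution` defaults to 512 for SD v1.x, 768 for SD v2.x, and 1024 for SDXL, and `is_v2_model` should be `true` only for v2.x bases) can be captured in a small helper. This is an illustrative sketch only; the function name and return shape are not part of the repository:

```python
def base_model_defaults(is_v2_model: bool, is_sdxl: bool) -> dict:
    """Default training resolution per base-model family, following the
    notebook's convention: 512 for SD v1.x, 768 for SD v2.x, 1024 for SDXL."""
    if is_v2_model and is_sdxl:
        raise ValueError("a base model cannot be both SD v2.x and SDXL")
    resolution = 1024 if is_sdxl else (768 if is_v2_model else 512)
    return {"is_v2_model": is_v2_model, "resolution": resolution}

# e.g. runwayml/stable-diffusion-v1-5 is a v1.x base:
print(base_model_defaults(False, False))  # {'is_v2_model': False, 'resolution': 512}
```

The helper is just a compact restatement of the per-family defaults; for an actual run you still set `pretrained_sd_model`, `is_v2_model`, and `resolution` in the form fields below.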
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pretrained_sd_model = r\"/path/to/model/model.safetensors\" #@param {type: \"string\"}\n", + "is_v2_model = \"false\" #@param [\"true\", \"false\"]\n", + "is_v_prediction_model = \"false\" #@param [\"true\", \"false\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 2: Choose Concept\n", + "\n", + "- `target_concept`: The concept to be erased\n", + "- `surrogate_concept`: The surrogate concept that defines what the model generates in place of the erased concept; empty string by default" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If you want to train multiple concepts at once in a single SPM, set this to True\n", + "multiple_concepts = True #@param {type: \"boolean\"}\n", + "\n", + "target_concept = '' #@param {type: \"string\"}\n", + "surrogate_concept = '' #@param {type: \"string\"}\n", + "\n", + "# If multiple_concepts is True, the following fields will be used instead\n", + "target_concepts = ['teen','child','kid','loli','toddler','preteen','baby','newborn','girl','boy','young girl','young boy','year old','years old','little boy','little girl'] #@param {type: \"string\"}\n", + "surrogate_concepts = ['adult','adult','adult','woman','adult','adult','adult','adult','woman','man','woman','man','20 year old','20 years old','man','woman'] #@param {type: \"string\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 3: SPM Settings\n", + "\n", + "- `mode`: `erase_with_la` or `erase` (for ablation)\n", + "- `dim`: defaults to 1; other values are for ablation\n", + "- `sampling_batch_size`: how many latent anchors are sampled per iteration; defaults to 4 and can be reduced if VRAM is limited\n", + "- `la_strength`: strength of the latent anchoring loss that balances erasure against preservation; defaults to 1000 for SD
v1.4 models, and should be tuned further for other base models for better performance\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mode = 'erase_with_la' #@param [\"erase_with_la\", \"erase\"]\n", + "dim = 2 #@param {type: \"number\"}\n", + "sampling_batch_size = 4 #@param {type: \"number\"}\n", + "la_strength = 250 #@param {type: \"number\"}\n", + "erasing_scale = 1.5 #@param {type: \"number\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 4: Training Settings\n", + "\n", + "Training settings come in two parts: SD settings and optimization settings.\n", + "Notice that `resolution` is set to 512 for SD v1.x, 768 for SD v2.x, and 1024 for SDXL." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "resolution = 512 #@param [512, 640, 768, 896, 960, 1024]\n", + "max_denoising_steps = 20 #@param {type: \"number\"}\n", + "dynamic_resolution = \"true\" #@param [\"true\", \"false\"]\n", + "clip_skip = 1 #@param [1, 2]\n", + "\n", + "batch_size = 4 #@param {type: \"number\"}\n", + "iterations = 6000 #@param {type: \"number\"}\n", + "lr = 15e-5 #@param {type: \"number\"}\n", + "optimizer = \"AdamW8bit\" #@param {type: \"string\"}\n", + "lr_scheduler = \"cosine_with_restarts\" #@param {type: \"string\"}\n", + "lr_warmup_steps = 300 #@param {type: \"number\"}\n", + "lr_scheduler_num_cycles = 3 #@param {type: \"number\"}\n", + "save_per_steps = 750 #@param {type: \"number\"}\n", + "precision = \"bfloat16\" #@param [\"float32\", \"float16\", \"bfloat16\"]\n", + "verbose = \"false\" #@param [\"true\", \"false\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 5: (Optional) Tracking Training Details with WandB\n", + "\n", + "You can set up your wandb token to track the training details, including training statistics (e.g.
losses, learning rates) and visualizations.\n", + "Your wandb token can be retrieved from https://wandb.ai/authorize ." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wandb_token = \"\" #@param {type: \"string\"}\n", + "\n", + "prompts_to_visualize = [\"cat\", \"dog\"] #@param {type: \"string\"}\n", + "generate_num = 2 #@param {type: \"number\"}\n", + "\n", + "## track target & surrogate by default\n", + "prompts_to_visualize = [target_concept, surrogate_concept] + prompts_to_visualize\n", + "\n", + "## login with your wandb token\n", + "if wandb_token != \"\":\n", + " !wandb login {wandb_token}\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 6: Generate Config Files\n", + "\n", + "Run the following code block to generate the config files automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "a1pK61X_VY2L" + }, + "outputs": [], + "source": [ + "# you can customize these strings to distinguish your different experiments\n", + "exp_name = f\"sd15_child_removal_v3\"\n", + "save_name = f\"{exp_name}\"\n", + "run_name = f\"{exp_name}\"\n", + "\n", + "config_file_path = f\"configs/{save_name}/config.yaml\"\n", + "prompts_file_path = f\"configs/{save_name}/prompt.yaml\"\n", + "\n", + "\n", + "config_file_content = f\"\"\"\n", + "prompts_file: \"{prompts_file_path}\"\n", + "\n", + "pretrained_model:\n", + " name_or_path: \"{pretrained_sd_model}\"\n", + " v2: {is_v2_model}\n", + " v_pred: {is_v_prediction_model}\n", + " clip_skip: {clip_skip}\n", + "\n", + "network:\n", + " rank: {dim}\n", + " alpha: 1.0\n", + "\n", + "train:\n", + " precision: {precision}\n", + " noise_scheduler: \"ddim\"\n", + " iterations: {iterations}\n", + " batch_size: {batch_size}\n", + " lr: {lr}\n", + " unet_lr: {lr}\n", + " text_encoder_lr: {0.5 * lr}\n", + " optimizer_type: \"{optimizer}\"\n", + " lr_scheduler:
\"{lr_scheduler}\"\n", + " lr_warmup_steps: {lr_warmup_steps}\n", + " lr_scheduler_num_cycles: {lr_scheduler_num_cycles}\n", + " max_denoising_steps: {max_denoising_steps}\n", + "\n", + "save:\n", + " name: \"{save_name}\"\n", + " path: \"output/{save_name}\"\n", + " per_steps: {save_per_steps}\n", + " precision: {precision}\n", + "\n", + "logging:\n", + " use_wandb: \"{'true' if wandb_token else 'false'}\"\n", + " interval: 500\n", + " seed: 0\n", + " generate_num: {generate_num}\n", + " run_name: \"{run_name}\"\n", + " verbose: {verbose}\n", + " prompts: {prompts_to_visualize}\n", + "\n", + "other:\n", + " use_xformers: true\n", + "\"\"\"\n", + "\n", + "import os\n", + "os.makedirs(f\"./configs/{save_name}\", exist_ok=True)\n", + "\n", + "with open(config_file_path, \"w\") as f:\n", + " f.write(config_file_content)\n", + "\n", + "if multiple_concepts:\n", + " prompts_file_content = \"\"\n", + "\n", + " for target_concept, surrogate_concept in zip(target_concepts, surrogate_concepts):\n", + " prompts_file_content += f\"\"\"\n", + "- target: \"{target_concept}\"\n", + " positive: \"{target_concept}\"\n", + " unconditional: \"\"\n", + " neutral: \"{surrogate_concept}\"\n", + " action: \"{mode}\"\n", + " guidance_scale: \"{erasing_scale}\"\n", + " resolution: {resolution}\n", + " batch_size: {batch_size}\n", + " dynamic_resolution: {dynamic_resolution}\n", + " la_strength: {la_strength}\n", + " sampling_batch_size: {sampling_batch_size}\n", + " \"\"\"\n", + "else:\n", + " prompts_file_content = f\"\"\"\n", + "- target: \"{target_concept}\"\n", + " positive: \"{target_concept}\"\n", + " unconditional: \"\"\n", + " neutral: \"{surrogate_concept}\"\n", + " action: \"{mode}\"\n", + " guidance_scale: \"{erasing_scale}\"\n", + " resolution: {resolution}\n", + " batch_size: {batch_size}\n", + " dynamic_resolution: {dynamic_resolution}\n", + " la_strength: {la_strength}\n", + " sampling_batch_size: {sampling_batch_size}\n", + " \"\"\"\n", + "\n", + "with open(prompts_file_path, \"w\") as f:\n", + "
f.write(prompts_file_content)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "wb9aX9BBVY2L" + }, + "source": [ + "## Start Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WchSOvSHVY2L" + }, + "outputs": [], + "source": [ + "!python ./train_spm.py --config_file {config_file_path}" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "RtsANZ1JbPXZ" + }, + "source": [ + "## Inference" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prepare SD Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Na8jf1nHcEbE" + }, + "outputs": [], + "source": [ + "import torch\n", + "from diffusers import DiffusionPipeline\n", + "import copy\n", + "import gc\n", + "\n", + "def flush():\n", + " torch.cuda.empty_cache()\n", + " gc.collect()\n", + "\n", + "flush()\n", + "\n", + "pipe = DiffusionPipeline.from_pretrained(\n", + " \"CompVis/stable-diffusion-v1-4\",\n", + " custom_pipeline=\"lpw_stable_diffusion\",\n", + " torch_dtype=torch.float16,\n", + " local_files_only=True,\n", + ")\n", + "\n", + "pipe = pipe.to(\"cuda\")\n", + "pipe.enable_xformers_memory_efficient_attention()\n", + "orig_unet = copy.deepcopy(pipe.unet)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generation Configs\n", + "\n", + "prompt = \"mickey mouse\" #@param {\"type\": \"string\"}\n", + "\n", + "negative_prompt = \"bad anatomy,watermark,extra digit,signature,worst quality,jpeg artifacts,normal quality,low quality,long neck,lowres,error,blurry,missing fingers,fewer digits,missing arms,text,cropped,Humpbacked ,bad hands,username\" #@param {\"type\": \"string\"}\n", + "width = 512 #@param {type: \"number\"}\n", + "height = 640 #@param {type: \"number\"}\n", + "steps = 20 #@param {type:\"slider\", min:1, max:50, step:1}\n", + 
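One caveat with the config-generation cell above: it pairs `target_concepts` with `surrogate_concepts` via `zip`, which silently drops any unpaired tail if the two lists fall out of sync, losing erasure targets without warning. A small guard makes such a mismatch fail loudly. This is an illustrative sketch, not part of the notebook; the shortened lists below are placeholders:

```python
def pair_concepts(targets, surrogates):
    """Pair each erasure target with its surrogate, refusing silently
    truncated pairings (a bare zip() would drop the unmatched tail)."""
    if len(targets) != len(surrogates):
        raise ValueError(
            f"{len(targets)} targets vs {len(surrogates)} surrogates: "
            "every target concept needs exactly one surrogate"
        )
    return list(zip(targets, surrogates))

# Shortened placeholder lists for illustration:
pairs = pair_concepts(["child", "kid"], ["adult", "adult"])
# pairs == [("child", "adult"), ("kid", "adult")]
```

Calling `pair_concepts(target_concepts, surrogate_concepts)` before writing `prompt.yaml` would catch an accidental edit to one list but not the other.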
"cfg_scale = 7.5 #@param {type:\"slider\", min:1, max:16, step:0.5}\n", + "sample_cnt = 2 #@param {type:\"number\"}\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Case A: Generate with Single SPM Applied\n", + "\n", + "Here you can set `spm_paths` to trained SPM paths to compare generation behaviour with different SPMs applied to SD.\n", + "\n", + "`(len(spm_paths) + 1) * sample_cnt` samples will be displayed, including the SPM-free baseline.\n", + "\n", + "**Notice that the SPM applied here DOES NOT activate its Facilitated Transport mechanism. For the full feature set, use `infer_spm.py`.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 405, + "referenced_widgets": [ + "46d46c9b768341b2b23dddfa76ac5c74", + "a041d80e6ce245299c737e672e6a4347", + "0961d6e5293c4c0e81b2c3adcbbabc03", + "69ebb96ba4e4462b853ecf905e87acb8", + "51be8d3e42b64f96b745fb79d13f9839", + "6a78d4147bb54a548d29645ad87576a1", + "631f52cb6f964a0daf1e41e3648e34f9", + "0142c35bb9a64938988b88d51f614f7c", + "3f49723d4a784396b2a2d90016c22c4d", + "13823035b9a741a9847db560450dcc23", + "8dfbe14085e245ba912ec56ec65f8e3e", + "e21cff947a6540c6a5a297972a540316", + "a945f60371ec4fc4b01c7dc5e9248892", + "a319221d44e543699da77217c6068db1", + "6a985be0438e471eb512298bb353ee76", + "5398915a300d4e469463679dba2aa476", + "b45e5f92f9c9470691388817aea51fc1", + "d650226eb3dd4c6fb8a32f25f232898f", + "c2ac3475d4d04457b1243771d08bb5a2", + "aa3119812f3042dc9f48c1c42f9c57c7", + "8be625d570554e5c9725b9c2ca9aed7c", + "8468f0760dcf4f399169038df5b12dfc" + ] + }, + "id": "CICMFcxefWl_", + "outputId": "83ef246d-590a-4fe3-b348-fc98ff69c188" + }, + "outputs": [], + "source": [ + "# SPM Comparison\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import random\n", + "\n", + "spm_paths = [\n", + " \"output/snoopy/snoopy_last.safetensors\",\n", + " \"output/pikachu/pikachu_last.safetensors\",\n", +
"random_seeds = [random.randint(0, 2**32 - 1) for _ in range(sample_cnt)]\n", + "\n", + "# -------\n", + "# w/o SPM\n", + "\n", + "pipe.unet = orig_unet\n", + "\n", + "orig_samples = [pipe.text2img(\n", + " prompt,\n", + " negative_prompt=negative_prompt,\n", + " width=width,\n", + " height=height,\n", + " num_inference_steps=steps,\n", + " guidance_scale=cfg_scale,\n", + " generator=torch.manual_seed(random_seed),\n", + " ).images[0]\n", + " for random_seed in random_seeds\n", + "]\n", + "\n", + "# -------\n", + "# w/ SPM\n", + "\n", + "spms_samples = []\n", + "\n", + "for spm_path in spm_paths:\n", + " pipe.unet = copy.deepcopy(orig_unet) # reset to the clean UNet so SPMs do not stack\n", + " pipe.load_lora_weights(spm_path)\n", + " lora_unet = copy.deepcopy(pipe.unet)\n", + " pipe.unet = lora_unet\n", + "\n", + " spm_samples = [pipe.text2img(\n", + " prompt,\n", + " negative_prompt=negative_prompt,\n", + " width=width,\n", + " height=height,\n", + " num_inference_steps=steps,\n", + " guidance_scale=cfg_scale,\n", + " generator=torch.manual_seed(random_seed),\n", + " ).images[0]\n", + " for random_seed in random_seeds\n", + " ]\n", + " spms_samples.append(spm_samples)\n", + "# ---\n", + "spm_cnt = len(spm_paths)\n", + "fig, ax = plt.subplots(spm_cnt+1, sample_cnt, figsize=(17, 2*(spm_cnt+1)))\n", + "\n", + "for i in range(sample_cnt):\n", + " ax[0, i].imshow(orig_samples[i])\n", + " ax[0, i].axis('off')\n", + "\n", + "for n in range(spm_cnt):\n", + " for i in range(sample_cnt):\n", + " ax[n+1, i].imshow(spms_samples[n][i])\n", + " ax[n+1, i].axis('off')\n", + "\n", + "plt.subplots_adjust(wspace=0.1, hspace=0.1)\n", + "\n", + "plt.show()\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Case B: Generate with Multiple SPMs Applied\n", + "\n", + "**Notice that the SPM applied here DOES NOT activate its Facilitated Transport mechanism.
For the full feature set, use `infer_spm.py`.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Multi-SPM Generation\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import random\n", + "import torch\n", + "from diffusers import DiffusionPipeline\n", + "import copy\n", + "import gc\n", + "\n", + "from src.models.merge_spm import merge_to_sd_model\n", + "\n", + "spm_path_1 = \"output/snoopy/snoopy_last.safetensors\"\n", + "spm_path_2 = \"output/pikachu/pikachu_last.safetensors\"\n", + "ratios = [1.0, 1.0]\n", + "\n", + "prompts = ['snoopy', 'pikachu', 'donald duck', 'woman', '']\n", + "\n", + "\n", + "random_seeds = [random.randint(0, 2**32 - 1) for _ in range(sample_cnt * len(prompts))]\n", + "\n", + "# repeat each prompt sample_cnt times, preserving prompt order\n", + "repeated_prompts = [prompt for prompt in prompts for _ in range(sample_cnt)]\n", + "\n", + "def flush():\n", + " torch.cuda.empty_cache()\n", + " gc.collect()\n", + "\n", + "flush()\n", + "\n", + "samples = []\n", + "\n", + "# -------\n", + "# Original\n", + "\n", + "pipe = DiffusionPipeline.from_pretrained(\n", + " \"CompVis/stable-diffusion-v1-4\",\n", + " custom_pipeline=\"lpw_stable_diffusion\",\n", + " torch_dtype=torch.float32,\n", + " local_files_only=True,\n", + ")\n", + "\n", + "pipe = pipe.to(\"cuda\")\n", + "pipe.enable_xformers_memory_efficient_attention()\n", + "orig_unet = copy.deepcopy(pipe.unet)\n", + "\n", + "orig_samples = [pipe.text2img(\n", + " prompt,\n", + " negative_prompt=negative_prompt,\n", + " width=width,\n", + " height=height,\n", + " num_inference_steps=steps,\n", + " guidance_scale=cfg_scale,\n", + " generator=torch.manual_seed(random_seed),\n", + " ).images[0]\n", + " for (prompt, random_seed) in zip(repeated_prompts, random_seeds)\n", + "]\n", + "samples.append(orig_samples)\n", + "\n", + "# -------\n", + "# Applying spm_1 / spm_2\n", + "\n", + "for spm_path in [spm_path_1, spm_path_2]:\n", + " pipe.unet = copy.deepcopy(orig_unet) # reset to the clean UNet so the SPMs do not stack\n", + " pipe.load_lora_weights(spm_path)\n", + "
lora_unet = copy.deepcopy(pipe.unet)\n", + " pipe.unet = lora_unet\n", + "\n", + " spm_samples = [pipe.text2img(\n", + " prompt,\n", + " negative_prompt=negative_prompt,\n", + " width=width,\n", + " height=height,\n", + " num_inference_steps=steps,\n", + " guidance_scale=cfg_scale,\n", + " generator=torch.manual_seed(random_seed),\n", + " ).images[0]\n", + " for (prompt, random_seed) in zip(repeated_prompts, random_seeds)\n", + " ]\n", + " samples.append(spm_samples)\n", + "\n", + "# -------\n", + "# Applying spm_1 & spm_2\n", + "\n", + "pipe.unet = orig_unet\n", + "merge_to_sd_model(pipe.text_encoder, pipe.unet, [spm_path_1, spm_path_2], ratios)\n", + "\n", + "spm_samples = [pipe.text2img(\n", + " prompt,\n", + " negative_prompt=negative_prompt,\n", + " width=width,\n", + " height=height,\n", + " num_inference_steps=steps,\n", + " guidance_scale=cfg_scale,\n", + " generator=torch.manual_seed(random_seed),\n", + " ).images[0]\n", + " for (prompt, random_seed) in zip(repeated_prompts, random_seeds)\n", + "]\n", + "samples.append(spm_samples)\n", + "\n", + "# -------\n", + "# Visualize\n", + "\n", + "fig, ax = plt.subplots(4, sample_cnt * len(prompts), figsize=(17, 2*len(samples)))\n", + "\n", + "def format_yaxis(ax, label):\n", + " ax.axis('on')\n", + " ax.set_ylabel(label)\n", + " ax.get_xaxis().set_visible(False)\n", + " ax.set_yticklabels([])\n", + " ax.tick_params(length=0)\n", + " for spine in ['top', 'bottom', 'left', 'right']:\n", + " ax.spines[spine].set_visible(False)\n", + "\n", + "get_exp_name = lambda x: x[x.rfind('/') + 1: x.rfind('_last.')]\n", + "ylabels = ['Original', get_exp_name(spm_path_1), get_exp_name(spm_path_2), 'Both']\n", + "for n in range(len(samples)):\n", + " for i, img in enumerate(samples[n]):\n", + " ax[n, i].imshow(img)\n", + " ax[n, i].axis('off')\n", + " if n == 0:\n", + " ax[n, i].set_title(f\"{repeated_prompts[i]}_{i % sample_cnt + 1}\", va='top', fontsize='small')\n", + " if i == 0:\n", + " format_yaxis(ax[n, i], ylabels[n])\n", 
+ "\n", + "\n", + "plt.subplots_adjust(wspace=0.1, hspace=0.1)\n", + "\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "K2UMzQb1kak-" + }, + "outputs": [], + "source": [ + "# Release memory (optional)\n", + "\n", + "del pipe\n", + "del orig_unet\n", + "del lora_unet\n", + "flush()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "0142c35bb9a64938988b88d51f614f7c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, 
+ "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0961d6e5293c4c0e81b2c3adcbbabc03": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0142c35bb9a64938988b88d51f614f7c", + "max": 20, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_3f49723d4a784396b2a2d90016c22c4d", + "value": 20 + } + }, + "13823035b9a741a9847db560450dcc23": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + 
"overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3f49723d4a784396b2a2d90016c22c4d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "46d46c9b768341b2b23dddfa76ac5c74": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_a041d80e6ce245299c737e672e6a4347", + "IPY_MODEL_0961d6e5293c4c0e81b2c3adcbbabc03", + "IPY_MODEL_69ebb96ba4e4462b853ecf905e87acb8" + ], + "layout": "IPY_MODEL_51be8d3e42b64f96b745fb79d13f9839" + } + }, + "51be8d3e42b64f96b745fb79d13f9839": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + 
"grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5398915a300d4e469463679dba2aa476": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "631f52cb6f964a0daf1e41e3648e34f9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "69ebb96ba4e4462b853ecf905e87acb8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_13823035b9a741a9847db560450dcc23", + "placeholder": "​", + "style": "IPY_MODEL_8dfbe14085e245ba912ec56ec65f8e3e", + "value": " 20/20 [00:04<00:00, 5.10it/s]" + } + }, + "6a78d4147bb54a548d29645ad87576a1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": 
null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6a985be0438e471eb512298bb353ee76": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8be625d570554e5c9725b9c2ca9aed7c", + "placeholder": "​", + "style": "IPY_MODEL_8468f0760dcf4f399169038df5b12dfc", + "value": " 20/20 [00:04<00:00, 4.76it/s]" + } + }, + "8468f0760dcf4f399169038df5b12dfc": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "8be625d570554e5c9725b9c2ca9aed7c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + 
"grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8dfbe14085e245ba912ec56ec65f8e3e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a041d80e6ce245299c737e672e6a4347": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6a78d4147bb54a548d29645ad87576a1", + "placeholder": "​", + "style": "IPY_MODEL_631f52cb6f964a0daf1e41e3648e34f9", + "value": "100%" + } + }, + "a319221d44e543699da77217c6068db1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + 
"_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c2ac3475d4d04457b1243771d08bb5a2", + "max": 20, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_aa3119812f3042dc9f48c1c42f9c57c7", + "value": 20 + } + }, + "a945f60371ec4fc4b01c7dc5e9248892": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b45e5f92f9c9470691388817aea51fc1", + "placeholder": "​", + "style": "IPY_MODEL_d650226eb3dd4c6fb8a32f25f232898f", + "value": "100%" + } + }, + "aa3119812f3042dc9f48c1c42f9c57c7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b45e5f92f9c9470691388817aea51fc1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + 
"align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c2ac3475d4d04457b1243771d08bb5a2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": 
null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d650226eb3dd4c6fb8a32f25f232898f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e21cff947a6540c6a5a297972a540316": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_a945f60371ec4fc4b01c7dc5e9248892", + "IPY_MODEL_a319221d44e543699da77217c6068db1", + "IPY_MODEL_6a985be0438e471eb512298bb353ee76" + ], + "layout": "IPY_MODEL_5398915a300d4e469463679dba2aa476" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/sdxl_child_removal.ipynb b/sdxl_child_removal.ipynb new file mode 100644 index 0000000..ab58bcd --- /dev/null +++ b/sdxl_child_removal.ipynb @@ -0,0 +1,31357 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "_U-cYMtEVY2K" + }, + "source": [ + "## Create Configs" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "XBgbejdDV2iq" + }, + "source": [ + "### Step 1: Choose Base Model\n", + "\n", + "Examples of `pretrained_sd_model`:\n", + "- SD v1.4: 
`CompVis/stable-diffusion-v1-4`\n", + "- SD v1.5: `runwayml/stable-diffusion-v1-5`\n", + "- Dreamshaper v8: `stablediffusionapi/dreamshaper-v8`\n", + "- SD v2.1: `stabilityai/stable-diffusion-2-1-base`\n", + "- WD1.5 beta3: `Birchlabs/wd-1-5-beta3-unofficial`\n", + "- SDXL: `stabilityai/stable-diffusion-xl-base-1.0`\n", + "- Juggernaut v6: `RunDiffusion/Juggernaut-XL-v6`\n", + "- PonyXL: `stablediffusionapi/pony-diffusion-v6-xl`\n", + "\n", + "If the base model is v2.x, set `is_v2_model` to `true`." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "pretrained_sd_model = r\"/workspace/SPM/URPMXL-V6.safetensors\" #@param {type: \"string\"}\n", + "is_v2_model = \"false\" #@param [\"true\", \"false\"]\n", + "is_v_prediction_model = \"false\" #@param [\"true\", \"false\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 2: Choose Concept\n", + "\n", + "- `target_concept`: The concept to erase\n", + "- `surrogate_concept`: The surrogate concept that defines what the model generates after erasure; empty string by default" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "# If you want to train multiple concepts at once in a single SPM, set this to True\n", + "multiple_concepts = True #@param {type: \"boolean\"}\n", + "\n", + "target_concept = '' #@param {type: \"string\"}\n", + "surrogate_concept = '' #@param {type: \"string\"}\n", + "\n", + "# If multiple_concepts is True, the following fields will be used instead\n", + "target_concepts = ['teen','child','kid','loli','toddler','preteen','baby','newborn','girl','boy','young girl','young boy','year old','years old','little boy','little girl'] #@param {type: \"string\"}\n", + "surrogate_concepts = ['adult','adult','adult','woman','adult','adult','adult','adult','woman','man','woman','man','20 year old','20 years old','man','woman'] #@param {type: \"string\"}" + ] + }, + {
+ "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 3: SPM Settings\n", + "\n", + "- `mode`: `erase_with_la` or `erase` (for ablation)\n", + "- `dim`: defaults to 1; other values are for ablation\n", + "- `sampling_batch_size`: how many latent anchors are sampled per iteration; defaults to 4 and can be reduced if there is not enough VRAM\n", + "- `la_strength`: strength of the latent anchoring loss, which balances erasure and preservation; defaults to 1000 for SD v1.4 models and should be tuned further for other base models for better performance\n" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [], + "source": [ + "mode = 'erase_with_la' #@param [\"erase_with_la\", \"erase\"]\n", + "dim = 2 #@param {type: \"number\"}\n", + "sampling_batch_size = 1 #@param {type: \"number\"}\n", + "la_strength = 400 #@param {type: \"number\"}\n", + "erasing_scale = 2.0 #@param {type: \"number\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 4: Training Settings\n", + "\n", + "Training settings come in two parts: SD settings and optimization settings.\n", + "Note that `resolution` should be set to 512 for SD v1.x, 768 for SD v2.x, and 1024 for SDXL."
+ ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [], + "source": [ + "resolution = 1024 #@param [512, 640, 768, 896, 960, 1024]\n", + "max_denoising_steps = 20 #@param {type: \"number\"}\n", + "dynamic_resolution = \"true\" #@param [\"true\", \"false\"]\n", + "clip_skip = 2 #@param [1, 2]\n", + "\n", + "batch_size = 1 #@param {type: \"number\"}\n", + "iterations = 5000 #@param {type: \"number\"}\n", + "lr = 15e-5 #@param {type: \"number\"}\n", + "optimizer = \"AdamW8bit\" #@param {type: \"string\"}\n", + "lr_scheduler = \"cosine_with_restarts\" #@param {type: \"string\"}\n", + "lr_warmup_steps = 300 #@param {type: \"number\"}\n", + "lr_scheduler_num_cycles = 3 #@param {type: \"number\"}\n", + "save_per_steps = 750 #@param {type: \"number\"}\n", + "precision = \"bfloat16\" #@param [\"float32\", \"float16\", \"bfloat16\"]\n", + "verbose = \"false\" #@param [\"true\", \"false\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 5: (Optional) Tracking Training Details with WandB\n", + "\n", + "You can set up your wandb token to track training details, including training statistics (e.g. losses, learning rates) and visualizations.\n", + "Your wandb token can be retrieved from https://wandb.ai/authorize ."
+ ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [], + "source": [ + "wandb_token = \"\" #@param {type: \"string\"}\n", + "\n", + "prompts_to_visualize = [\"cat\", \"dog\"] #@param {type: \"string\"}\n", + "generate_num = 2 #@param {type: \"number\"}\n", + "\n", + "## track target & surrogate by default\n", + "prompts_to_visualize = [target_concept, surrogate_concept] + prompts_to_visualize\n", + "\n", + "## log in with your wandb token\n", + "if wandb_token != \"\": \n", + " !wandb login {wandb_token}\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 6: Generate Config Files\n", + "\n", + "Run the following code block to generate the config files automatically." + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": { + "cellView": "form", + "id": "a1pK61X_VY2L" + }, + "outputs": [], + "source": [ + "# you can customize these strings to distinguish your different experiments\n", + "exp_name = f\"sdxl_child_removal_v2\"\n", + "save_name = f\"{exp_name}\"\n", + "run_name = f\"{exp_name}\"\n", + "\n", + "config_file_path = f\"configs/{save_name}/config.yaml\"\n", + "prompts_file_path = f\"configs/{save_name}/prompt.yaml\"\n", + "\n", + "\n", + "config_file_content = f\"\"\"\n", + "prompts_file: \"{prompts_file_path}\"\n", + "\n", + "pretrained_model:\n", + " name_or_path: \"{pretrained_sd_model}\"\n", + " v2: {is_v2_model}\n", + " v_pred: {is_v_prediction_model}\n", + " clip_skip: {clip_skip}\n", + "\n", + "network:\n", + " rank: {dim}\n", + " alpha: 1.0\n", + "\n", + "train:\n", + " precision: {precision}\n", + " noise_scheduler: \"ddim\"\n", + " iterations: {iterations}\n", + " batch_size: {batch_size}\n", + " lr: {lr}\n", + " unet_lr: {lr}\n", + " text_encoder_lr: {0.5 * lr}\n", + " optimizer_type: \"{optimizer}\"\n", + " lr_scheduler: \"{lr_scheduler}\"\n", + " lr_warmup_steps: {lr_warmup_steps}\n", + " lr_scheduler_num_cycles: {lr_scheduler_num_cycles}\n", + " max_denoising_steps: {max_denoising_steps}\n",
+ "\n", + "save:\n", + " name: \"{save_name}\"\n", + " path: \"output/{save_name}\"\n", + " per_steps: {save_per_steps}\n", + " precision: {precision}\n", + "\n", + "logging:\n", + " use_wandb: {'true' if wandb_token else 'false'}\n", + " interval: 500\n", + " seed: 0\n", + " generate_num: {generate_num}\n", + " run_name: \"{run_name}\"\n", + " verbose: {verbose}\n", + " prompts: {prompts_to_visualize}\n", + "\n", + "other:\n", + " use_xformers: true\n", + "\"\"\"\n", + "\n", + "import os\n", + "if not os.path.exists(f\"./configs/{save_name}\"):\n", + " os.makedirs(f\"./configs/{save_name}\")\n", + "\n", + "with open(config_file_path, \"w\") as f:\n", + " f.write(config_file_content)\n", + "\n", + "if multiple_concepts:\n", + " prompts_file_content = \"\" \n", + "\n", + " for target_concept, surrogate_concept in zip(target_concepts, surrogate_concepts):\n", + " prompts_file_content += f\"\"\"\n", + "- target: \"{target_concept}\"\n", + " positive: \"{target_concept}\"\n", + " unconditional: \"\"\n", + " neutral: \"{surrogate_concept}\"\n", + " action: \"{mode}\"\n", + " guidance_scale: \"{erasing_scale}\"\n", + " resolution: {resolution}\n", + " batch_size: {batch_size}\n", + " dynamic_resolution: {dynamic_resolution}\n", + " la_strength: {la_strength}\n", + " sampling_batch_size: {sampling_batch_size}\n", + " \"\"\"\n", + "else:\n", + " prompts_file_content = f\"\"\"\n", + "- target: \"{target_concept}\"\n", + " positive: \"{target_concept}\"\n", + " unconditional: \"\"\n", + " neutral: \"{surrogate_concept}\"\n", + " action: \"{mode}\"\n", + " guidance_scale: \"{erasing_scale}\"\n", + " resolution: {resolution}\n", + " batch_size: {batch_size}\n", + " dynamic_resolution: {dynamic_resolution}\n", + " la_strength: {la_strength}\n", + " sampling_batch_size: {sampling_batch_size}\n", + " \"\"\"\n", + "\n", + "with open(prompts_file_path, \"w\") as f:\n", + " f.write(prompts_file_content)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "wb9aX9BBVY2L" + },
+ "source": [ + "## Start Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WchSOvSHVY2L" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/SPM/./train_spm_xl_mem_reduce.py:41: PydanticDeprecatedSince20: The `json` method is deprecated; use `model_dump_json` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/\n", + " \"prompts\": \",\".join([prompt.json() for prompt in prompts]),\n", + "/workspace/SPM/./train_spm_xl_mem_reduce.py:42: PydanticDeprecatedSince20: The `json` method is deprecated; use `model_dump_json` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/\n", + " \"config\": config.json(),\n", + "Loading local checkpoint from /workspace/SPM/URPMXL-V6.safetensors\n", + "Fetching 17 files: 100%|█████████████████████| 17/17 [00:00<00:00, 42747.70it/s]\n", + "Loading pipeline components...: 0%| | 0/7 [00:00