diff --git a/altdiffusionTest.ipynb b/altdiffusionTest.ipynb new file mode 100644 index 00000000..39ed68af --- /dev/null +++ b/altdiffusionTest.ipynb @@ -0,0 +1,503 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "488083be", + "metadata": { + "execution": { + "iopub.execute_input": "2025-06-08T11:42:32.133030Z", + "iopub.status.busy": "2025-06-08T11:42:32.132825Z", + "iopub.status.idle": "2025-06-08T11:44:21.234306Z", + "shell.execute_reply": "2025-06-08T11:44:21.233302Z" + }, + "papermill": { + "duration": 109.105469, + "end_time": "2025-06-08T11:44:21.235868", + "exception": false, + "start_time": "2025-06-08T11:42:32.130399", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pytorch_lightning==1.9.5\r\n", + " Downloading pytorch_lightning-1.9.5-py3-none-any.whl.metadata (23 kB)\r\n", + "Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.11/dist-packages (from pytorch_lightning==1.9.5) (1.26.4)\r\n", + "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.11/dist-packages (from pytorch_lightning==1.9.5) (2.6.0+cu124)\r\n", + "Requirement already satisfied: tqdm>=4.57.0 in /usr/local/lib/python3.11/dist-packages (from pytorch_lightning==1.9.5) (4.67.1)\r\n", + "Requirement already satisfied: PyYAML>=5.4 in /usr/local/lib/python3.11/dist-packages (from pytorch_lightning==1.9.5) (6.0.2)\r\n", + "Requirement already satisfied: fsspec>2021.06.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>2021.06.0->pytorch_lightning==1.9.5) (2025.3.2)\r\n", + "Requirement already satisfied: torchmetrics>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from pytorch_lightning==1.9.5) (1.7.1)\r\n", + "Requirement already satisfied: packaging>=17.1 in /usr/local/lib/python3.11/dist-packages (from pytorch_lightning==1.9.5) (25.0)\r\n", + "Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch_lightning==1.9.5) (4.13.2)\r\n", + "Requirement already satisfied: lightning-utilities>=0.6.0.post0 in /usr/local/lib/python3.11/dist-packages (from pytorch_lightning==1.9.5) (0.14.3)\r\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>2021.06.0->pytorch_lightning==1.9.5) (3.11.18)\r\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from lightning-utilities>=0.6.0.post0->pytorch_lightning==1.9.5) (75.2.0)\r\n", + "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17.2->pytorch_lightning==1.9.5) (1.3.8)\r\n", + "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17.2->pytorch_lightning==1.9.5) (1.2.4)\r\n", + "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17.2->pytorch_lightning==1.9.5) (0.1.1)\r\n", + "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17.2->pytorch_lightning==1.9.5) (2025.1.0)\r\n", + "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17.2->pytorch_lightning==1.9.5) (2022.1.0)\r\n", + "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy>=1.17.2->pytorch_lightning==1.9.5) (2.4.1)\r\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (3.18.0)\r\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (3.4.2)\r\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (3.1.6)\r\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (12.4.127)\r\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (12.4.127)\r\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (12.4.127)\r\n", + "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.10.0->pytorch_lightning==1.9.5)\r\n", + " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\r\n", + "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.10.0->pytorch_lightning==1.9.5)\r\n", + " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\r\n", + "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.10.0->pytorch_lightning==1.9.5)\r\n", + " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\r\n", + "Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=1.10.0->pytorch_lightning==1.9.5)\r\n", + " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\r\n", + "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=1.10.0->pytorch_lightning==1.9.5)\r\n", + " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\r\n", + "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch>=1.10.0->pytorch_lightning==1.9.5)\r\n", + " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\r\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (0.6.2)\r\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (2.21.5)\r\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (12.4.127)\r\n", + "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch>=1.10.0->pytorch_lightning==1.9.5)\r\n", + " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\r\n", + "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (3.2.0)\r\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->pytorch_lightning==1.9.5) (1.13.1)\r\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=1.10.0->pytorch_lightning==1.9.5) (1.3.0)\r\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning==1.9.5) (2.6.1)\r\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning==1.9.5) (1.3.2)\r\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning==1.9.5) (25.3.0)\r\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning==1.9.5) (1.6.0)\r\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning==1.9.5) (6.4.3)\r\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning==1.9.5) (0.3.1)\r\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning==1.9.5) (1.20.0)\r\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=1.10.0->pytorch_lightning==1.9.5) (3.0.2)\r\n", + "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy>=1.17.2->pytorch_lightning==1.9.5) (2024.2.0)\r\n", + "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy>=1.17.2->pytorch_lightning==1.9.5) (2022.1.0)\r\n", + "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy>=1.17.2->pytorch_lightning==1.9.5) (1.3.0)\r\n", + "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy>=1.17.2->pytorch_lightning==1.9.5) (2024.2.0)\r\n", + "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy>=1.17.2->pytorch_lightning==1.9.5) (2024.2.0)\r\n", + "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.11/dist-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning==1.9.5) (3.10)\r\n", + "Downloading pytorch_lightning-1.9.5-py3-none-any.whl (829 kB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m829.5/829.5 kB\u001b[0m \u001b[31m18.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m22.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m13.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m81.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, pytorch_lightning\r\n", + " Attempting uninstall: nvidia-nvjitlink-cu12\r\n", + " Found existing installation: nvidia-nvjitlink-cu12 12.9.41\r\n", + " Uninstalling nvidia-nvjitlink-cu12-12.9.41:\r\n", + " Successfully uninstalled nvidia-nvjitlink-cu12-12.9.41\r\n", + " Attempting uninstall: nvidia-curand-cu12\r\n", + " Found existing installation: nvidia-curand-cu12 10.3.10.19\r\n", + " Uninstalling nvidia-curand-cu12-10.3.10.19:\r\n", + " Successfully uninstalled nvidia-curand-cu12-10.3.10.19\r\n", + " Attempting uninstall: nvidia-cufft-cu12\r\n", + " Found existing installation: nvidia-cufft-cu12 11.4.0.6\r\n", + " Uninstalling nvidia-cufft-cu12-11.4.0.6:\r\n", + " Successfully uninstalled nvidia-cufft-cu12-11.4.0.6\r\n", + " Attempting uninstall: nvidia-cublas-cu12\r\n", + " Found existing installation: nvidia-cublas-cu12 12.9.0.13\r\n", + " Uninstalling nvidia-cublas-cu12-12.9.0.13:\r\n", + " Successfully uninstalled nvidia-cublas-cu12-12.9.0.13\r\n", + " Attempting uninstall: nvidia-cusparse-cu12\r\n", + " Found existing installation: nvidia-cusparse-cu12 12.5.9.5\r\n", + " Uninstalling nvidia-cusparse-cu12-12.5.9.5:\r\n", + " Successfully uninstalled nvidia-cusparse-cu12-12.5.9.5\r\n", + " Attempting uninstall: nvidia-cudnn-cu12\r\n", + " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\r\n", + " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\r\n", + " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\r\n", + " Attempting uninstall: nvidia-cusolver-cu12\r\n", + " Found existing installation: nvidia-cusolver-cu12 11.7.4.40\r\n", + " Uninstalling nvidia-cusolver-cu12-11.7.4.40:\r\n", + " Successfully uninstalled nvidia-cusolver-cu12-11.7.4.40\r\n", + " Attempting uninstall: pytorch_lightning\r\n", + " Found existing installation: pytorch-lightning 2.5.1.post0\r\n", + " Uninstalling pytorch-lightning-2.5.1.post0:\r\n", + " Successfully uninstalled pytorch-lightning-2.5.1.post0\r\n", + "Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch_lightning-1.9.5\r\n", + "Cloning into 'FlagAI'...\r\n", + "remote: Enumerating objects: 10073, done.\u001b[K\r\n", + "remote: Counting objects: 100% (3162/3162), done.\u001b[K\r\n", + "remote: Compressing objects: 100% (1121/1121), done.\u001b[K\r\n", + "remote: Total 10073 (delta 2154), reused 2061 (delta 2040), pack-reused 6911 (from 4)\u001b[K\r\n", + "Receiving objects: 100% (10073/10073), 122.26 MiB | 30.48 MiB/s, done.\r\n", + "Resolving deltas: 100% (7143/7143), done.\r\n", + "Collecting git+https://github.com/whybfq/FlagAI.git\r\n", + " Cloning https://github.com/whybfq/FlagAI.git to /tmp/pip-req-build-one6lf4v\r\n", + " Running command git clone --filter=blob:none --quiet https://github.com/whybfq/FlagAI.git /tmp/pip-req-build-one6lf4v\r\n", + " Resolved https://github.com/whybfq/FlagAI.git to commit e2d4a9cb753e0741c439ed70f9b977d9076a1505\r\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\r\n", + "Requirement already satisfied: nltk>=3.6.7 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (3.9.1)\r\n", + "Requirement already satisfied: sentencepiece>=0.1.96 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (0.2.0)\r\n", + "Requirement already satisfied: boto3>=1.17.32 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (1.38.11)\r\n", + "Requirement already satisfied: pandas>=1.3.5 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (2.2.3)\r\n", + "Requirement already satisfied: jieba>=0.42.1 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (0.42.1)\r\n", + "Requirement already satisfied: scikit-learn>=1.0.2 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (1.2.2)\r\n", + "Requirement already satisfied: tensorboard>=2.9.0 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (2.18.0)\r\n", + "Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (4.51.3)\r\n", + "Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (3.6.0)\r\n", + "Requirement already satisfied: setuptools>=66.0.0 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (75.2.0)\r\n", + "Requirement already satisfied: protobuf>=3.19.6 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (3.20.3)\r\n", + "Collecting ftfy (from flagai==1.8.5)\r\n", + " Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)\r\n", + "Requirement already satisfied: Pillow>=9.3.0 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (11.1.0)\r\n", + "Requirement already satisfied: einops>=0.3.0 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (0.8.1)\r\n", + "Requirement already satisfied: diffusers>=0.7.2 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (0.32.2)\r\n", + "Requirement already satisfied: pytorch-lightning>=1.6.5 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (1.9.5)\r\n", + "Collecting taming-transformers-rom1504==0.0.6 (from flagai==1.8.5)\r\n", + " Downloading taming_transformers_rom1504-0.0.6-py3-none-any.whl.metadata (406 bytes)\r\n", + "Collecting rouge-score (from flagai==1.8.5)\r\n", + " Downloading rouge_score-0.1.2.tar.gz (17 kB)\r\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\r\n", + "Collecting sacrebleu>=2.3.1 (from flagai==1.8.5)\r\n", + " Downloading sacrebleu-2.5.1-py3-none-any.whl.metadata (51 kB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m51.8/51.8 kB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hCollecting jsonlines (from flagai==1.8.5)\r\n", + " Downloading jsonlines-4.0.0-py3-none-any.whl.metadata (1.6 kB)\r\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (1.5.2)\r\n", + "Requirement already satisfied: PyYAML>=5.4.1 in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (6.0.2)\r\n", + "Requirement already satisfied: safetensors in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (0.5.3)\r\n", + "Requirement already satisfied: timm in /usr/local/lib/python3.11/dist-packages (from flagai==1.8.5) (1.0.15)\r\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (from taming-transformers-rom1504==0.0.6->flagai==1.8.5) (2.6.0+cu124)\r\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (from taming-transformers-rom1504==0.0.6->flagai==1.8.5) (0.21.0+cu124)\r\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from taming-transformers-rom1504==0.0.6->flagai==1.8.5) (1.26.4)\r\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from taming-transformers-rom1504==0.0.6->flagai==1.8.5) (4.67.1)\r\n", + "Requirement already satisfied: omegaconf>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from taming-transformers-rom1504==0.0.6->flagai==1.8.5) (2.3.0)\r\n", + "Requirement already satisfied: botocore<1.39.0,>=1.38.11 in /usr/local/lib/python3.11/dist-packages (from boto3>=1.17.32->flagai==1.8.5) (1.38.11)\r\n", + "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from boto3>=1.17.32->flagai==1.8.5) (1.0.1)\r\n", + "Requirement already satisfied: s3transfer<0.13.0,>=0.12.0 in /usr/local/lib/python3.11/dist-packages (from boto3>=1.17.32->flagai==1.8.5) (0.12.0)\r\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->flagai==1.8.5) (3.18.0)\r\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->flagai==1.8.5) (19.0.1)\r\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->flagai==1.8.5) (0.3.8)\r\n", + "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->flagai==1.8.5) (2.32.3)\r\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->flagai==1.8.5) (3.5.0)\r\n", + "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->flagai==1.8.5) (0.70.16)\r\n", + "Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=2.0.0->flagai==1.8.5)\r\n", + " Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)\r\n", + "Requirement already satisfied: huggingface-hub>=0.24.0 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->flagai==1.8.5) (0.31.1)\r\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->flagai==1.8.5) (25.0)\r\n", + "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.11/dist-packages (from diffusers>=0.7.2->flagai==1.8.5) (8.7.0)\r\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from diffusers>=0.7.2->flagai==1.8.5) (2024.11.6)\r\n", + "Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.6.7->flagai==1.8.5) (8.1.8)\r\n", + "Requirement already satisfied: joblib in /usr/local/lib/python3.11/dist-packages (from nltk>=3.6.7->flagai==1.8.5) (1.5.0)\r\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.3.5->flagai==1.8.5) (2.9.0.post0)\r\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.3.5->flagai==1.8.5) (2025.2)\r\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.3.5->flagai==1.8.5) (2025.2)\r\n", + "Requirement already satisfied: torchmetrics>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-lightning>=1.6.5->flagai==1.8.5) (1.7.1)\r\n", + "Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-lightning>=1.6.5->flagai==1.8.5) (4.13.2)\r\n", + "Requirement already satisfied: lightning-utilities>=0.6.0.post0 in /usr/local/lib/python3.11/dist-packages (from pytorch-lightning>=1.6.5->flagai==1.8.5) (0.14.3)\r\n", + "Collecting portalocker (from sacrebleu>=2.3.1->flagai==1.8.5)\r\n", + " Downloading portalocker-3.1.1-py3-none-any.whl.metadata (8.6 kB)\r\n", + "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.11/dist-packages (from sacrebleu>=2.3.1->flagai==1.8.5) (0.9.0)\r\n", + "Requirement already satisfied: colorama in /usr/local/lib/python3.11/dist-packages (from sacrebleu>=2.3.1->flagai==1.8.5) (0.4.6)\r\n", + "Requirement already satisfied: lxml in /usr/local/lib/python3.11/dist-packages (from sacrebleu>=2.3.1->flagai==1.8.5) (5.3.1)\r\n", + "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.11/dist-packages (from scikit-learn>=1.0.2->flagai==1.8.5) (1.15.2)\r\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn>=1.0.2->flagai==1.8.5) (3.6.0)\r\n", + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.11/dist-packages (from tensorboard>=2.9.0->flagai==1.8.5) (1.4.0)\r\n", + "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.11/dist-packages (from tensorboard>=2.9.0->flagai==1.8.5) (1.72.0rc1)\r\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.11/dist-packages (from tensorboard>=2.9.0->flagai==1.8.5) (3.7)\r\n", + "Requirement already satisfied: six>1.9 in /usr/local/lib/python3.11/dist-packages (from tensorboard>=2.9.0->flagai==1.8.5) (1.17.0)\r\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from tensorboard>=2.9.0->flagai==1.8.5) (0.7.2)\r\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from tensorboard>=2.9.0->flagai==1.8.5) (3.1.3)\r\n", + "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers>=4.31.0->flagai==1.8.5) (0.21.1)\r\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate->flagai==1.8.5) (7.0.0)\r\n", + "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from ftfy->flagai==1.8.5) (0.2.13)\r\n", + "Requirement already satisfied: attrs>=19.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonlines->flagai==1.8.5) (25.3.0)\r\n", + "Requirement already satisfied: urllib3!=2.2.0,<3,>=1.25.4 in /usr/local/lib/python3.11/dist-packages (from botocore<1.39.0,>=1.38.11->boto3>=1.17.32->flagai==1.8.5) (2.4.0)\r\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=2.0.0->flagai==1.8.5) (3.11.18)\r\n", + "Requirement already satisfied: hf-xet<2.0.0,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.24.0->datasets>=2.0.0->flagai==1.8.5) (1.1.0)\r\n", + "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (1.3.8)\r\n", + "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (1.2.4)\r\n", + "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (0.1.1)\r\n", + "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (2025.1.0)\r\n", + "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (2022.1.0)\r\n", + "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (2.4.1)\r\n", + "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf>=2.0.0->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (4.9.3)\r\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets>=2.0.0->flagai==1.8.5) (3.4.2)\r\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets>=2.0.0->flagai==1.8.5) (3.10)\r\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets>=2.0.0->flagai==1.8.5) (2025.4.26)\r\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (3.4.2)\r\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (3.1.6)\r\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (12.4.127)\r\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (12.4.127)\r\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (12.4.127)\r\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (9.1.0.70)\r\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (12.4.5.8)\r\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (11.2.1.3)\r\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (10.3.5.147)\r\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (11.6.1.9)\r\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (12.3.1.170)\r\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (0.6.2)\r\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (2.21.5)\r\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (12.4.127)\r\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (12.4.127)\r\n", + "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (3.2.0)\r\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (1.13.1)\r\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (1.3.0)\r\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.11/dist-packages (from werkzeug>=1.0.1->tensorboard>=2.9.0->flagai==1.8.5) (3.0.2)\r\n", + "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.11/dist-packages (from importlib-metadata->diffusers>=0.7.2->flagai==1.8.5) (3.21.0)\r\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=2.0.0->flagai==1.8.5) (2.6.1)\r\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=2.0.0->flagai==1.8.5) (1.3.2)\r\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=2.0.0->flagai==1.8.5) (1.6.0)\r\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=2.0.0->flagai==1.8.5) (6.4.3)\r\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=2.0.0->flagai==1.8.5) (0.3.1)\r\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets>=2.0.0->flagai==1.8.5) (1.20.0)\r\n", + "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (2024.2.0)\r\n", + "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (2022.1.0)\r\n", + "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (1.3.0)\r\n", + "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (2024.2.0)\r\n", + "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->taming-transformers-rom1504==0.0.6->flagai==1.8.5) (2024.2.0)\r\n", + "Downloading taming_transformers_rom1504-0.0.6-py3-none-any.whl (51 kB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m51.5/51.5 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading sacrebleu-2.5.1-py3-none-any.whl (104 kB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m104.1/104.1 kB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading ftfy-6.3.1-py3-none-any.whl (44 kB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)\r\n", + "Downloading fsspec-2025.3.0-py3-none-any.whl (193 kB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m193.6/193.6 kB\u001b[0m \u001b[31m11.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading portalocker-3.1.1-py3-none-any.whl (19 kB)\r\n", + "Building wheels for collected packages: flagai, rouge-score\r\n", + " Building wheel for flagai (setup.py) ... \u001b[?25l\u001b[?25hdone\r\n", + " Created wheel for flagai: filename=flagai-1.8.5-py3-none-any.whl size=740613 sha256=48ca4c27a34abf880aa9a8673bb4285c7cc1d5160dab909f24f893435c3c0f36\r\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-26izrlee/wheels/13/22/48/d7c37afdddf59498b9426d2c537677807d62790ab7f7bfc024\r\n", + " Building wheel for rouge-score (setup.py) ... \u001b[?25l\u001b[?25hdone\r\n", + " Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=25487898f1317c4a0cde843464067d126aaa1f7da920b5874014ad2f8577eba9\r\n", + " Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148\r\n", + "Successfully built flagai rouge-score\r\n", + "Installing collected packages: portalocker, jsonlines, ftfy, fsspec, taming-transformers-rom1504, sacrebleu, rouge-score, flagai\r\n", + " Attempting uninstall: fsspec\r\n", + " Found existing installation: fsspec 2025.3.2\r\n", + " Uninstalling fsspec-2025.3.2:\r\n", + " Successfully uninstalled fsspec-2025.3.2\r\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\r\n", + "cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.\r\n", + "bigframes 1.42.0 requires rich<14,>=12.4.4, but you have rich 14.0.0 which is incompatible.\r\n", + "gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2025.3.0 which is incompatible.\u001b[0m\u001b[31m\r\n", + "\u001b[0mSuccessfully installed flagai-1.8.5 fsspec-2025.3.0 ftfy-6.3.1 jsonlines-4.0.0 portalocker-3.1.1 rouge-score-0.1.2 sacrebleu-2.5.1 taming-transformers-rom1504-0.0.6\r\n" + ] + } + ], + "source": [ + "!pip install pytorch_lightning==1.9.5\n", + "!git clone https://github.com/whybfq/FlagAI.git\n", + "!pip install git+https://github.com/whybfq/FlagAI.git" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ad5dd4f4", + "metadata": { + "execution": { + "iopub.execute_input": "2025-06-08T11:44:21.289989Z", + "iopub.status.busy": "2025-06-08T11:44:21.289756Z", + "iopub.status.idle": "2025-06-08T11:44:21.295250Z", + "shell.execute_reply": "2025-06-08T11:44:21.294482Z" + }, + "papermill": { + "duration": 0.033671, + "end_time": "2025-06-08T11:44:21.296395", + "exception": false, + "start_time": "2025-06-08T11:44:21.262724", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/kaggle/working/FlagAI/examples/AltDiffusion\n" + ] + } + ], + "source": [ + "cd FlagAI/examples/AltDiffusion" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "51bf92e0", + "metadata": { + "execution": { + "iopub.execute_input": "2025-06-08T11:44:21.349394Z", + "iopub.status.busy": "2025-06-08T11:44:21.349206Z", + "iopub.status.idle": "2025-06-08T11:53:04.783514Z", + "shell.execute_reply": "2025-06-08T11:53:04.782487Z" + }, + "papermill": { + "duration": 523.463272, + "end_time": "2025-06-08T11:53:04.785131", + "exception": false, + "start_time": "2025-06-08T11:44:21.321859", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dreambooth.py generate.py imgs instance_images LICENSE README.md\r\n", + "[2025-06-08 11:44:35,411] [INFO] [logger.py:85:log_dist] [Rank -1] Unsupported bmtrain\r\n", + "******************** text2img altdiffusion-m18\r\n", + "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/utilities/distributed.py:258: LightningDeprecationWarning: `pytorch_lightning.utilities.distributed.rank_zero_only` has been deprecated in v1.8.1 and will be removed in v2.0.0. You can import it from `pytorch_lightning.utilities` instead.\r\n", + " rank_zero_deprecation(\r\n", + "model files:['config.yaml', 'model.ckpt', 'preprocessor_config.json', 'sentencepiece.bpe.model', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json']\r\n", + "Downloading: 0%| | 0.00/1.81k [00:00\r\n", + " scaler.step(optimizer) # 更新参数\r\n", + " ^^^^^^^^^^^^^^^^^^^^^^\r\n", + " File \"/usr/local/lib/python3.11/dist-packages/torch/amp/grad_scaler.py\", line 457, in step\r\n", + " retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)\r\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\r\n", + " File \"/usr/local/lib/python3.11/dist-packages/torch/amp/grad_scaler.py\", line 352, in _maybe_opt_step\r\n", + " retval = optimizer.step(*args, **kwargs)\r\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\r\n", + " File \"/usr/local/lib/python3.11/dist-packages/torch/optim/optimizer.py\", line 493, in wrapper\r\n", + " out = func(*args, **kwargs)\r\n", + " ^^^^^^^^^^^^^^^^^^^^^\r\n", + " File \"/usr/local/lib/python3.11/dist-packages/torch/optim/optimizer.py\", line 91, in _use_grad\r\n", + " ret = func(self, *args, **kwargs)\r\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^\r\n", + " File \"/usr/local/lib/python3.11/dist-packages/torch/optim/adamw.py\", line 232, in step\r\n", + " has_complex = self._init_group(\r\n", + " ^^^^^^^^^^^^^^^^^\r\n", + " File \"/usr/local/lib/python3.11/dist-packages/torch/optim/adamw.py\", line 171, in _init_group\r\n", + " state[\"exp_avg\"] = torch.zeros_like(\r\n", + " ^^^^^^^^^^^^^^^^^\r\n", + "torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 10.12 MiB is free. Process 4290 has 14.73 GiB memory in use. Of the allocated memory 14.37 GiB is allocated by PyTorch, and 227.71 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)\r\n" + ] + } + ], + "source": [ + "!ls\n", + "!python dreambooth.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2cc43ba3", + "metadata": { + "papermill": { + "duration": 0.103845, + "end_time": "2025-06-08T11:53:04.999480", + "exception": false, + "start_time": "2025-06-08T11:53:04.895635", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "machine_shape": "hm" + }, + "kaggle": { + "accelerator": "nvidiaTeslaT4", + "dataSources": [], + "dockerImageVersionId": 31041, + "isGpuEnabled": true, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + }, + "papermill": { + "default_parameters": {}, + "duration": 638.08639, + "end_time": "2025-06-08T11:53:05.471504", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2025-06-08T11:42:27.385114", + "version": "2.6.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/AltCLIP-m18/altclip_inference.py b/examples/AltCLIP-m18/altclip_inference.py index 7b466654..12a34ea8 100644 --- a/examples/AltCLIP-m18/altclip_inference.py +++ b/examples/AltCLIP-m18/altclip_inference.py @@ -6,7 +6,7 @@ loader = AutoLoader( task_name="txt_img_matching", - model_name="AltCLIP-XLMR-L-m18", # Load the checkpoints from Modelhub(model.baai.ac.cn/models) + model_name="AltCLIP-XLMR-L-m18", # Load the checkpoints from Modelhub(model.baai.ac.cn/models) model_dir="./checkpoints" ) @@ -23,11 +23,11 @@ def inference(): image = Image.open("./examples/AltCLIP-m18//dog.jpeg") image = transform(image) image = torch.tensor(image["pixel_values"]).to(device) - tokenizer_out = tokenizer(["a rat", "a dog", "a cat"], - padding=True, - truncation=True, - max_length=77, - return_tensors='pt') + tokenizer_out = tokenizer(["a rat", "a dog", "a cat"], + padding=True, + truncation=True, + max_length=77, + return_tensors='pt') text = tokenizer_out["input_ids"].to(device) attention_mask = tokenizer_out["attention_mask"].to(device) @@ -38,5 +38,5 @@ def inference(): print(text_probs.cpu().numpy()[0].tolist()) -if __name__=="__main__": - inference() \ No newline at end of file +if __name__ == "__main__": + inference() diff --git a/examples/AltDiffusion-m18/generate.py b/examples/AltDiffusion-m18/generate.py index 98ee65d9..b78cf4d0 100755 --- a/examples/AltDiffusion-m18/generate.py +++ b/examples/AltDiffusion-m18/generate.py @@ -16,9 +16,14 @@ model.eval() model.to(device) predictor = Predictor(model) -prompt = "Daenerys Targaryen as a mermeid with a piercing gaze wearing an enchanted bikini in an underwater magical forest, highly detailed face, realistic face, beautiful detailed eyes, fantasy art, in the style of artgerm, illustration, epic, fantasy, intricate, hyper detailed, artstation, concept art, smooth, sharp focus, ray tracing, vibrant, photorealistic" -negative_prompt = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, extra head, extra legs,fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" +prompt = "สาวสวย" +# negative_prompt = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, extra head, extra legs,fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" seed = 553124 -predictor.predict_generate_images( - prompt=prompt,negative_prompt=negative_prompt,seed=seed -) + +result = predictor.predict_generate_images( + prompt=prompt, + # negative_prompt=negative_prompt, + outpath="./AltDiffusionOutputs", + ddim_steps=50, + seed=seed) +print(type(result), result) diff --git a/examples/AltDiffusion/dreambooth.py b/examples/AltDiffusion/dreambooth.py index 404a8569..2815e1e1 100644 --- a/examples/AltDiffusion/dreambooth.py +++ b/examples/AltDiffusion/dreambooth.py @@ -1,9 +1,3 @@ -# Copyright © 2022 BAAI. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License") -# -# https://huggingface.co/spaces/BAAI/dreambooth-altdiffusion/blob/main/train_dreambooth.py - import os import sys import random @@ -12,66 +6,73 @@ import torch from torch.utils.data import Dataset +from torch.cuda.amp import autocast, GradScaler # 混合精度训练 from PIL import Image from torchvision import transforms -from flagai.trainer import Trainer from flagai.auto_model.auto_loader import AutoLoader +# 设置设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -instance_data_dir = "./examples/AltDiffusion/instance_images" -instance_prompt = "<鸣人>男孩" - -with_prior_preservation = False -class_data_dir = "Mix" -class_prompt = "男孩" -prior_loss_weight = 1.0 -num_class_images = 10 -resolution = 512 -center_crop = True - -train_text_encoder = False -train_only_unet = True - -num_train_epochs = 500 -batch_size = 4 -learning_rate = 5e-6 -adam_beta1 = 0.9 -adam_beta2 = 0.999 -adam_weight_decay = 1e-2 -adam_epsilon = 1e-08 - +# =============== 训练参数配置 =============== +instance_data_dir = "./instance_images" # 实例图片目录 +instance_prompt = "smile😁" # 实例提示词(包含特殊标识符) + +with_prior_preservation = False # 是否使用先验保留 +class_data_dir = "Mix" # 类别图片目录 +class_prompt = "smile" # 类别提示词 +prior_loss_weight = 1.0 # 先验损失权重 +num_class_images = 4 # 类别图片数量 +resolution = 128 # 图片分辨率 +center_crop = False # 是否中心裁剪 + +train_text_encoder = False # 是否训练文本编码器 +train_only_unet = True # 是否仅训练UNet + +# 训练超参数 +num_train_epochs = 10 # 训练轮数 +batch_size = 2 # 批次大小 +learning_rate = 5e-6 # 学习率 +adam_beta1 = 0.9 # Adam优化器参数 +adam_beta2 = 0.999 # Adam优化器参数 +adam_weight_decay = 1e-2 # 权重衰减 +adam_epsilon = 1e-08 # 数值稳定性常数 + +# =============== 模型初始化 =============== +# 加载AltDiffusion-m18文本生成图像模型 auto_loader = AutoLoader(task_name="text2img", - model_name="AltDiffusion") + model_name="AltDiffusion-m18") -model = auto_loader.get_model() -tokenizer = model.tokenizer +model = auto_loader.get_model() # 获取模型 +tokenizer = model.tokenizer # 获取分词器 +# =============== 数据集定义 =============== class DreamBoothDataset(Dataset): def __init__( - self, - instance_data_root, - instance_prompt, - tokenizer, - class_data_root=None, - class_prompt=None, - size=512, - center_crop=False, + self, + instance_data_root, # 实例图片路径 + instance_prompt, # 实例提示词 + tokenizer, # 分词器 + class_data_root=None, # 类别图片路径 + class_prompt=None, # 类别提示词 + size=512, # 图片大小 + center_crop=False, # 是否中心裁剪 ): self.size = size self.center_crop = center_crop self.tokenizer = tokenizer + # 处理实例图片 self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - self.instance_images_path = list(Path(instance_data_root).iterdir()) self.num_instance_images = len(self.instance_images_path) self.instance_prompt = instance_prompt self._length = self.num_instance_images + # 处理类别图片(用于先验保留) if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) @@ -83,12 +84,13 @@ def __init__( else: self.class_data_root = None + # 图片预处理流程 self.image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), + transforms.Normalize([0.5], [0.5]), # 归一化到[-1, 1] ] ) @@ -97,11 +99,13 @@ def __len__(self): def __getitem__(self, index): example = {} + # 加载实例图片 path = self.instance_images_path[index % self.num_instance_images] instance_image = Image.open(path) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") - + + # 应用预处理并获取提示词token example["instance_images"] = self.image_transforms(instance_image) example["instance_prompt_ids"] = self.tokenizer( self.instance_prompt, @@ -109,9 +113,9 @@ def __getitem__(self, index): truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids - print('*'*20, "instance_prompt=", self.instance_prompt) example["caption"] = self.instance_prompt + # 加载类别图片(如果使用先验保留) if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) if not class_image.mode == "RGB": @@ -126,6 +130,7 @@ def __getitem__(self, index): return example +# =============== 创建数据集和数据加载器 =============== train_dataset = DreamBoothDataset( instance_data_root=instance_data_dir, instance_prompt=instance_prompt, @@ -137,12 +142,12 @@ def __getitem__(self, index): ) def collate_fn(examples): + # 批量处理数据 input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] captions = [example["caption"] for example in examples] - # Concat class and instance examples for prior preservation. - # We do this to avoid doing two forward passes. + # 合并实例和类别数据(如果使用先验保留) if with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] @@ -156,6 +161,7 @@ def collate_fn(examples): "input_ids": input_ids, "pixel_values": pixel_values, "caption": captions, + "txt": captions } return batch @@ -163,54 +169,73 @@ def collate_fn(examples): train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn ) -vae = model.first_stage_model -text_encoder = model.cond_stage_model -unet = model.model.diffusion_model +# =============== 模型组件准备 =============== +vae = model.first_stage_model # 变分自编码器 +text_encoder = model.cond_stage_model # 文本编码器 +unet = model.model.diffusion_model # 扩散模型(UNet) +# 冻结不需要训练的组件 vae.requires_grad_(False) if not train_text_encoder: text_encoder.requires_grad_(False) +# =============== 优化器设置 =============== optimizer_class = torch.optim.AdamW params_to_optimize = ( itertools.chain(unet.parameters(), text_encoder.parameters()) if train_text_encoder else unet.parameters()) optimizer = optimizer_class( - params_to_optimize, - lr=learning_rate, - betas=(adam_beta1, adam_beta2), - weight_decay=adam_weight_decay, - eps=adam_epsilon, + params_to_optimize, + lr=learning_rate, + betas=(adam_beta1, adam_beta2), + weight_decay=adam_weight_decay, + eps=adam_epsilon, ) +# =============== 混合精度训练 =============== +scaler = GradScaler() # 梯度缩放器(用于混合精度训练) + +# 将模型移到设备 model.to(device) +vae = model.first_stage_model.to(device) +text_encoder = model.cond_stage_model.to(device) +unet = model.model.diffusion_model.to(device) + +# 训练循环 for epoch in range(num_train_epochs): unet.train() if train_text_encoder: text_encoder.train() for step, batch in enumerate(train_dataloader): - #x = batch["pixel_values"].to(device) - #x = model.encode_first_stage(batch["pixel_values"]).to(device) - #c = batch["caption"] + # 获取输入数据 x, c = model.get_input(batch, "pixel_values") - if with_prior_preservation: - x, x_prior = torch.chunk(x, 2, dim=0) - c, c_prior = torch.chunk(c, 2, dim=0) - loss, _ = model(x, c) - prior_loss, _ = model(x_prior, c_prior) - loss = loss + prior_loss_weight * prior_loss - else: - loss, _ = model(x, c) - - print('*'*20, "loss=", str(loss.detach().item())) - - loss.backward() - optimizer.step() - optimizer.zero_grad() - -## mkdir ./checkpoints/DreamBooth and copy ./checkpoints/AltDiffusion to ./checkpoints/DreamBooth/AltDiffusion -## overwrite model.ckpt for latter usage -chekpoint_path = './checkpoints/DreamBooth/AltDiffusion/model.ckpt' -torch.save(model.state_dict(), chekpoint_path) - + # 混合精度训练 + with autocast(): # 自动混合精度上下文 + if with_prior_preservation: + # 分离实例和先验数据 + x, x_prior = torch.chunk(x, 2, dim=0) + c, c_prior = torch.chunk(c, 2, dim=0) + # 计算实例损失和先验损失 + loss, _ = model(x, c) + prior_loss, _ = model(x_prior, c_prior) + # 组合损失 + loss = loss + prior_loss_weight * prior_loss + else: + # 仅计算实例损失 + loss, _ = model(x, c) + + print('*' * 20, "loss=", str(loss.detach().item())) + + # 反向传播和优化 + scaler.scale(loss).backward() # 缩放梯度 + scaler.step(optimizer) # 更新参数 + scaler.update() # 更新缩放器 + optimizer.zero_grad() # 清零梯度 + +# =============== 保存训练好的模型 =============== +checkpoint_path = './checkpoints/AltDiffusion-m18-new-trained/model.ckpt' +# 确保目录存在 +os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) +torch.save(model.state_dict(), checkpoint_path) +print(f"Model saved to {checkpoint_path}") diff --git a/examples/AltDiffusion/generate.py b/examples/AltDiffusion/generate.py index bafe1ae1..51af2819 100755 --- a/examples/AltDiffusion/generate.py +++ b/examples/AltDiffusion/generate.py @@ -9,7 +9,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") loader = AutoLoader(task_name="text2img", #contrastive learning - model_name="AltDiffusion-m9", + model_name="AltDiffusion-m18", # use m18 to do the experiment model_dir="./checkpoints", fp16=False) @@ -18,5 +18,11 @@ model.to(device) predictor = Predictor(model) predictor.predict_generate_images( - "Anime portrait of natalie portman as an anime girl by stanley artgerm lau, wlop, rossdraws, james jean, andrei riabovitchev, marc simonetti, and sakimichan, trending on artstation" -) \ No newline at end of file + prompt="smile😁", + # negative_prompt=negative_prompt, + # outpath="./AltDiffusionOutputs", + ddim_steps=20, + plms=True, + skip_grid=True, # or False if you want a grid image +) + diff --git a/examples/AltDiffusion/instance_images/1.jpg b/examples/AltDiffusion/instance_images/1.jpg new file mode 100644 index 00000000..943c65ca Binary files /dev/null and b/examples/AltDiffusion/instance_images/1.jpg differ diff --git a/examples/AltDiffusion/instance_images/2.jpg b/examples/AltDiffusion/instance_images/2.jpg new file mode 100644 index 00000000..e27788e2 Binary files /dev/null and b/examples/AltDiffusion/instance_images/2.jpg differ diff --git a/examples/AltDiffusion/instance_images/3.jpg b/examples/AltDiffusion/instance_images/3.jpg new file mode 100644 index 00000000..18c3173d Binary files /dev/null and b/examples/AltDiffusion/instance_images/3.jpg differ diff --git a/examples/AltDiffusion/instance_images/4.jpg b/examples/AltDiffusion/instance_images/4.jpg new file mode 100644 index 00000000..4b1e0596 Binary files /dev/null and b/examples/AltDiffusion/instance_images/4.jpg differ diff --git a/flagai/model/mm/AltCLIP.py b/flagai/model/mm/AltCLIP.py index f477b378..0aae80af 100755 --- a/flagai/model/mm/AltCLIP.py +++ b/flagai/model/mm/AltCLIP.py @@ -1,10 +1,17 @@ +from typing import Tuple, Any, Optional, Union + from transformers.models.clip.modeling_clip import * import torch.nn as nn import torch -from transformers.models.clip.modeling_clip import CLIPOutput -from transformers import CLIPProcessor +from transformers.models.clip.modeling_clip import CLIPOutput, CLIPVisionTransformer, clip_loss +from transformers import CLIPProcessor, CLIPConfig, CLIPVisionConfig import os + +from transformers.models.clip.modeling_flax_clip import CLIP_VISION_INPUTS_DOCSTRING, CLIP_INPUTS_DOCSTRING +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings + from flagai.model.base_model import BaseModel +from dataclasses import dataclass from .modeling_berts import BertSeriesConfig, RobertaSeriesConfig, BertSeriesModelWithTransformation, RobertaSeriesModelWithTransformation diff --git a/flagai/model/mm/AltDiffusion.py b/flagai/model/mm/AltDiffusion.py index 9a6b12cf..64ac9e0e 100755 --- a/flagai/model/mm/AltDiffusion.py +++ b/flagai/model/mm/AltDiffusion.py @@ -515,6 +515,7 @@ class LatentDiffusion(DDPM): def to(self, device): self.device = device self.cond_stage_model.to(device) + # self.logvar = self.logvar.to(dtype=torch.float16) # Move logvar to the same device super().to(device) def __init__(self, @@ -531,7 +532,7 @@ def __init__(self, tokenizer=None, *args, **kwargs): - self.device = "cpu" + self.device = "cpu" # original: cpu self.tokenizer = tokenizer self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std @@ -1266,8 +1267,9 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): x_recon = fold(o) / normalization else: - with autocast(): - x_recon = self.model(x_noisy, t, **cond) + # from torch import autocast + with autocast(): + x_recon = self.model(x_noisy, t, **cond) if isinstance(x_recon, tuple) and not return_ids: return x_recon[0] @@ -1319,7 +1321,11 @@ def p_losses(self, x_start, cond, t, noise=None): loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) - logvar_t = self.logvar[t].to(self.device) + logvar_t = self.logvar.to(t.device)[t] # logvar_t = self.logvar[t].to(self.device) + print(f"self.logvar device: {self.logvar.device}, t device: {t.device}, self.device: {self.device}") + print("self.logvar dtype:", self.logvar.dtype) + print("t dtype:", t.dtype) + loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: diff --git a/flagai/model/mm/AltDiffusionM18.py b/flagai/model/mm/AltDiffusionM18.py index acb30e95..a156733e 100755 --- a/flagai/model/mm/AltDiffusionM18.py +++ b/flagai/model/mm/AltDiffusionM18.py @@ -15,8 +15,9 @@ from flagai.model.mm.Sampler import DDIMSampler from flagai.model.base_model import BaseModel from torch.cuda.amp import autocast as autocast +from flagai.model.mm.modules.distributions.distributions import DiagonalGaussianDistribution +from flagai.model.mm.modules.ema import LitEma import pytorch_lightning as pl -from torch.cuda.amp import autocast as autocast __conditioning_keys__ = { diff --git a/flagai/model/mm/Unets/Unet.py b/flagai/model/mm/Unets/Unet.py index 798f8ff5..b8fea8ee 100644 --- a/flagai/model/mm/Unets/Unet.py +++ b/flagai/model/mm/Unets/Unet.py @@ -83,7 +83,7 @@ def forward(self, x, emb, context=None): elif isinstance(layer, SpatialTransformer): x = layer(x, context) else: - x = layer(x.half()) + x = layer(x.half()) # 让自动混合精度控制数据精度 x = layer(x.half()) return x diff --git a/flagai/model/mm/utils.py b/flagai/model/mm/utils.py index e9e6cee5..cfe6f779 100755 --- a/flagai/model/mm/utils.py +++ b/flagai/model/mm/utils.py @@ -102,14 +102,21 @@ def count_params(model, verbose=False): return total_params +# def instantiate_from_config(config): +# if not "target" in config: +# if config == '__is_first_stage__': +# return None +# elif config == "__is_unconditional__": +# return None +# raise KeyError("Expected key `target` to instantiate.") +# return get_obj_from_str(config["target"])(**config.get("params", dict())) + def instantiate_from_config(config): - if not "target" in config: - if config == '__is_first_stage__': - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) + target = get_obj_from_str(config["target"]) + params = config.get("params", {}) + # Filter out unexpected parameters + valid_params = {k: v for k, v in params.items() if k in target.__init__.__code__.co_varnames} + return target(**valid_params) def get_obj_from_str(string, reload=False): diff --git a/mindscope/DreamBooth-AltDiffusion.py b/mindscope/DreamBooth-AltDiffusion.py new file mode 100644 index 00000000..5f395139 --- /dev/null +++ b/mindscope/DreamBooth-AltDiffusion.py @@ -0,0 +1,11 @@ +from modelscope import AltDiffusionPipeline, DPMSolverMultistepScheduler +import torch + +pipe = AltDiffusionPipeline.from_pretrained("BAAI/DreamBooth-AltDiffusion") +pipe = pipe.to("cuda") + +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + +prompt = "一张<鸣人>男孩的照片" +image = pipe(prompt, num_inference_steps=25).images[0] +image.show() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f9250dbc..53f56d79 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ PyYAML==5.4.1 deepspeed==0.6.5 flash-attn==1.0.2 bminf +pytorch_lightning==1.9.5