diff --git a/examples/HeartMuLa_Colab_Example.ipynb b/examples/HeartMuLa_Colab_Example.ipynb new file mode 100644 index 0000000..f412c4d --- /dev/null +++ b/examples/HeartMuLa_Colab_Example.ipynb @@ -0,0 +1,487 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "KZP-87a0WRpp" + }, + "source": [ + "# \ud83c\udfb5 HeartMuLa: Music Generation Inference\n", + "\n", + "This notebook allows you to run **HeartMuLa** (Heart Music Language), a state-of-the-art open-source music generation model.\n", + "\n", + " GitHub Repository | ArXiv Paper\n", + "\n", + "### \u26a0\ufe0f Runtime Requirement\n", + "**GPU is required.** Please ensure you are running on a GPU runtime.\n", + "* **Colab:** Runtime > Change runtime type > T4 GPU (Standard) or A100 (Premium).\n", + "* **Note:** T4 GPUs have limited memory. We use `lazy_load=True` to minimize VRAM usage." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2PnjQkCYWRpq" + }, + "source": [ + "## 1. Setup Environment" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "okdeSQ0uWRpr", + "outputId": "4ba66d27-7379-461c-91bd-e3b776f624ff" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2705 Installation complete.\n" + ] + } + ], + "source": [ + "# @title Install Dependencies\n", + "# @markdown This step clones the repository and installs the required Python packages.\n", + "\n", + "import os\n", + "import sys\n", + "from IPython.display import clear_output\n", + "\n", + "!nvidia-smi\n", + "\n", + "print(\"Installing dependencies... (this may take 1-2 minutes)\")\n", + "\n", + "# Clone repo\n", + "if not os.path.exists(\"/content/heartlib\"):\n", + " %cd /content\n", + " !git clone https://github.com/HeartMuLa/heartlib.git\n", + "\n", + "%cd /content/heartlib\n", + "\n", + "# Install dependencies\n", + "!pip install . 
--quiet\n", + "!pip install huggingface_hub --quiet\n", + "\n", + "clear_output()\n", + "print(\"\u2705 Installation complete.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5NtsvvuRWRpr" + }, + "source": [ + "## 2. Download Models\n", + "We need to download three components from Hugging Face:\n", + "1. **Configuration & Tokenizer**: `HeartMuLa/HeartMuLaGen`\n", + "2. **Music Model (3B)**: `HeartMuLa/HeartMuLa-oss-3B`\n", + "3. **Audio Codec**: `HeartMuLa/HeartCodec-oss`" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "cellView": "form", + "id": "2OSp2ezkWRps", + "outputId": "e3596a60-c1f6-40f4-c186-0272ad843b90" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u2705 Models downloaded successfully to ./ckpt\n" + ] + } + ], + "source": [ + "# @title Download Checkpoints\n", + "\n", + "import os\n", + "\n", + "# Define paths\n", + "ckpt_dir = \"/content/heartlib/ckpt\"\n", + "os.makedirs(ckpt_dir, exist_ok=True)\n", + "\n", + "print(\"Downloading configuration and tokenizer...\")\n", + "!huggingface-cli download HeartMuLa/HeartMuLaGen --local-dir {ckpt_dir} --local-dir-use-symlinks False\n", + "\n", + "print(\"Downloading HeartMuLa-oss-3B model...\")\n", + "!huggingface-cli download HeartMuLa/HeartMuLa-oss-3B --local-dir {ckpt_dir}/HeartMuLa-oss-3B --local-dir-use-symlinks False\n", + "\n", + "print(\"Downloading HeartCodec-oss...\")\n", + "!huggingface-cli download HeartMuLa/HeartCodec-oss --local-dir {ckpt_dir}/HeartCodec-oss --local-dir-use-symlinks False\n", + "\n", + "clear_output()\n", + "print(\"\u2705 Models downloaded successfully to ./ckpt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "40MGZow3WRps" + }, + "source": [ + "## 3. Load Model Pipeline\n", + "We initialize the pipeline here. 
We auto-detect if your GPU supports `bfloat16` (Ampere/A100) or if we should fall back to `float16` (T4/V100); the audio codec always runs in `float32`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LeHTVR5xWRps", + "outputId": "8ab95ad7-62ad-46b8-9dfa-51ab1136f7c6" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using bfloat16 precision (Optimal)\n", + "Model components will be loaded to devices as specified:\n", + " mula: cuda\n", + " codec: cuda\n", + "\u2705 Pipeline loaded successfully!\n" + ] + } + ], + "source": [ + "import torch\n", + "from heartlib import HeartMuLaGenPipeline\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "if torch.cuda.is_available() and torch.cuda.is_bf16_supported():\n", + " mula_dtype = torch.bfloat16\n", + " print(\"Using bfloat16 precision (Optimal)\")\n", + "else:\n", + " mula_dtype = torch.float16\n", + " print(\"Using float16 precision (Fallback for T4/V100)\")\n", + "codec_dtype = torch.float32\n", + "pipe = HeartMuLaGenPipeline.from_pretrained(\n", + " \"./ckpt\",\n", + " device={\"mula\": device, \"codec\": device},\n", + " dtype={\"mula\": mula_dtype, \"codec\": codec_dtype},\n", + " version=\"3B\",\n", + " lazy_load=True\n", + ")\n", + "print(\"\u2705 Pipeline loaded successfully!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PF4z9hLaWRps" + }, + "source": [ + "## 4. Generate Music\n", + "\n", + "Enter your Lyrics and Tags below.\n", + "* **Lyrics**: Structure them with tags like `[Verse]`, `[Chorus]`.\n", + "* **Tags**: Comma-separated genres or moods (e.g., `pop, happy, piano`).\n", + "* **Parameters**:\n", + " * `Max Duration`: Length of the song in milliseconds (1 min = 60000 ms).\n", + " * `CFG Scale`: Higher values follow text more closely (1.5 is standard).\n", + " * `Temperature`: Creativity (1.0 is standard)."
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 245, + "referenced_widgets": [ + "a9890e1691ff4c8d89cd5a8e545bed0e", + "d7d86758aafb4b64adeebb2e81edc71f", + "750d9ec077c44b6a9f04f222cc51a333", + "36486be0a6ed4be99f69b7dac85c11f6", + "22ae4ffc8f8847898a77c73f4e2cf2ef", + "69644cf1574d4842ae5e53bfe713d655", + "77d94de8de3c46da85b6b4310b5f8be2", + "c5b1d0249fff4d5a8c7a32e674229a6a", + "c4014d4bebb64d558ced03b470c38543", + "460a689a399f4d9a86cc6060dcfb598d", + "138679070de94d4c9663bfe84dbfcf19" + ] + }, + "id": "JO8dmoWjWRps", + "outputId": "0e122426-407f-44fb-dab4-457d12ddf307" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Generating music... Please wait.\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + } + ], + "source": [ + "# @title \ud83c\udfa7 Listen to the Song\n", + "from IPython.display import Audio, display\n", + "\n", + "if os.path.exists(output_path):\n", + " display(Audio(output_path))\n", + "else:\n", + " print(\"Audio file not found. 
Please run the generation cell first.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "id": "P8ZgHP9gWRpt", + "outputId": "c71680f8-9743-4bc5-cb4f-8b9c1a7b588d" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " async function download(id, filename, size) {\n", + " if (!google.colab.kernel.accessAllowed) {\n", + " return;\n", + " }\n", + " const div = document.createElement('div');\n", + " const label = document.createElement('label');\n", + " label.textContent = `Downloading \"${filename}\": `;\n", + " div.appendChild(label);\n", + " const progress = document.createElement('progress');\n", + " progress.max = size;\n", + " div.appendChild(progress);\n", + " document.body.appendChild(div);\n", + "\n", + " const buffers = [];\n", + " let downloaded = 0;\n", + "\n", + " const channel = await google.colab.kernel.comms.open(id);\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + "\n", + " for await (const message of channel.messages) {\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + " if (message.buffers) {\n", + " for (const buffer of message.buffers) {\n", + " buffers.push(buffer);\n", + " downloaded += buffer.byteLength;\n", + " progress.value = downloaded;\n", + " }\n", + " }\n", + " }\n", + " const blob = new Blob(buffers, {type: 'application/binary'});\n", + " const a = document.createElement('a');\n", + " a.href = window.URL.createObjectURL(blob);\n", + " a.download = filename;\n", + " div.appendChild(a);\n", + " a.click();\n", + " div.remove();\n", + " }\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "download(\"download_a80c3e12-7436-4a02-b10a-80fd88f924dd\", 
\"my_generation.mp3\", 482157)" + ] + }, + "metadata": {} + } + ], + "source": [ + "# @title (Optional) Download MP3\n", + "# @markdown Run this to download the generated file to your local computer.\n", + "from google.colab import files\n", + "if os.path.exists(output_path):\n", + " files.download(output_path)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "colab": { + "provenance": [], + "collapsed_sections": [ + "5NtsvvuRWRpr" + ], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file