diff --git a/examples/tts.py b/examples/tts.py new file mode 100644 index 0000000..0b5fe8e --- /dev/null +++ b/examples/tts.py @@ -0,0 +1,220 @@ +"""Kokoro-82M TTS Example. + +This example demonstrates text-to-speech synthesis using the Kokoro-82M model +with PyGPUkit's native LSTM kernel. + +Usage: + python examples/tts.py + python examples/tts.py --text "Hello world" --voice af_heart + python examples/tts.py --model F:/LLM/Kokoro-82M --output speech.wav +""" + +from __future__ import annotations + +import argparse +import sys +import time +from pathlib import Path + + +def test_lstm_kernel(): + """Test the native LSTM kernel works correctly.""" + import numpy as np + + import pygpukit as pk + + print("Testing native LSTM kernel...") + + batch = 2 + seq_len = 10 + input_size = 64 + hidden_size = 128 + + # Create random test inputs + x = pk.from_numpy(np.random.randn(batch, seq_len, input_size).astype(np.float32)) + W_ih = pk.from_numpy(np.random.randn(4 * hidden_size, input_size).astype(np.float32) * 0.1) + W_hh = pk.from_numpy(np.random.randn(4 * hidden_size, hidden_size).astype(np.float32) * 0.1) + b_ih = pk.from_numpy(np.zeros(4 * hidden_size, dtype=np.float32)) + b_hh = pk.from_numpy(np.zeros(4 * hidden_size, dtype=np.float32)) + + # Forward LSTM + output, h_n, c_n = pk.lstm_forward(x, W_ih, W_hh, b_ih, b_hh) + + print(f" Input shape: {x.shape}") + print(f" Output shape: {output.shape}") + print(f" h_n shape: {h_n.shape}") + print(f" c_n shape: {c_n.shape}") + + # Verify output is not all zeros + out_np = output.to_numpy() + assert not np.allclose(out_np, 0), "LSTM output should not be all zeros" + + print(" LSTM kernel test PASSED!") + return True + + +def test_bidirectional_lstm(): + """Test bidirectional LSTM.""" + import numpy as np + + import pygpukit as pk + + print("Testing bidirectional LSTM...") + + batch = 2 + seq_len = 10 + input_size = 64 + hidden_size = 128 + + # Create random test inputs + x = pk.from_numpy(np.random.randn(batch, seq_len, 
input_size).astype(np.float32)) + + # Forward direction weights + W_ih_fwd = pk.from_numpy(np.random.randn(4 * hidden_size, input_size).astype(np.float32) * 0.1) + W_hh_fwd = pk.from_numpy(np.random.randn(4 * hidden_size, hidden_size).astype(np.float32) * 0.1) + b_ih_fwd = pk.from_numpy(np.zeros(4 * hidden_size, dtype=np.float32)) + b_hh_fwd = pk.from_numpy(np.zeros(4 * hidden_size, dtype=np.float32)) + + # Backward direction weights + W_ih_bwd = pk.from_numpy(np.random.randn(4 * hidden_size, input_size).astype(np.float32) * 0.1) + W_hh_bwd = pk.from_numpy(np.random.randn(4 * hidden_size, hidden_size).astype(np.float32) * 0.1) + b_ih_bwd = pk.from_numpy(np.zeros(4 * hidden_size, dtype=np.float32)) + b_hh_bwd = pk.from_numpy(np.zeros(4 * hidden_size, dtype=np.float32)) + + # Bidirectional LSTM + output, h_n, c_n = pk.lstm_bidirectional( + x, + W_ih_fwd, + W_hh_fwd, + b_ih_fwd, + b_hh_fwd, + W_ih_bwd, + W_hh_bwd, + b_ih_bwd, + b_hh_bwd, + ) + + print(f" Input shape: {x.shape}") + print(f" Output shape: {output.shape} (2x hidden due to bidirectional)") + print(f" h_n shape: {h_n.shape}") + print(f" c_n shape: {c_n.shape}") + + # Verify shapes + assert output.shape == (batch, seq_len, 2 * hidden_size) + assert h_n.shape == (2, batch, hidden_size) + assert c_n.shape == (2, batch, hidden_size) + + print(" Bidirectional LSTM test PASSED!") + return True + + +def benchmark_lstm(): + """Benchmark LSTM performance.""" + import numpy as np + + import pygpukit as pk + + print("\nBenchmarking LSTM performance...") + + batch = 8 + seq_len = 100 + input_size = 768 + hidden_size = 512 + + # Create test inputs (typical TTS dimensions) + x = pk.from_numpy(np.random.randn(batch, seq_len, input_size).astype(np.float32)) + W_ih = pk.from_numpy(np.random.randn(4 * hidden_size, input_size).astype(np.float32) * 0.1) + W_hh = pk.from_numpy(np.random.randn(4 * hidden_size, hidden_size).astype(np.float32) * 0.1) + b_ih = pk.from_numpy(np.zeros(4 * hidden_size, dtype=np.float32)) + b_hh = 
pk.from_numpy(np.zeros(4 * hidden_size, dtype=np.float32)) + + # Warmup + for _ in range(3): + output, h_n, c_n = pk.lstm_forward(x, W_ih, W_hh, b_ih, b_hh) + + # Benchmark + iterations = 10 + start = time.perf_counter() + for _ in range(iterations): + output, h_n, c_n = pk.lstm_forward(x, W_ih, W_hh, b_ih, b_hh) + elapsed = time.perf_counter() - start + + ms_per_call = (elapsed / iterations) * 1000 + print(f" Config: batch={batch}, seq_len={seq_len}, input={input_size}, hidden={hidden_size}") + print(f" Time per forward: {ms_per_call:.2f} ms") + print(f" Throughput: {(batch * seq_len) / (ms_per_call / 1000):.0f} tokens/sec") + + return ms_per_call + + +def main(): + parser = argparse.ArgumentParser(description="Kokoro-82M TTS Example") + parser.add_argument("--model", type=str, default="F:/LLM/Kokoro-82M", help="Model path") + parser.add_argument( + "--text", + type=str, + default="Hello, this is a test of the Kokoro text to speech system.", + help="Text to synthesize", + ) + parser.add_argument("--voice", type=str, default="af_heart", help="Voice to use") + parser.add_argument("--output", type=str, default="output.wav", help="Output WAV file") + parser.add_argument("--test-only", action="store_true", help="Only run LSTM tests") + args = parser.parse_args() + + print("=" * 60) + print("PyGPUkit TTS Example - Kokoro-82M") + print("=" * 60) + + # Test LSTM kernel + if not test_lstm_kernel(): + print("LSTM kernel test failed!") + return 1 + + if not test_bidirectional_lstm(): + print("Bidirectional LSTM test failed!") + return 1 + + # Benchmark + benchmark_lstm() + + if args.test_only: + print("\nTest-only mode: skipping model loading") + return 0 + + # Try to load the TTS model + model_path = Path(args.model) + if not model_path.exists(): + print(f"\nModel not found at {model_path}") + print("Please download Kokoro-82M from HuggingFace:") + print(" huggingface-cli download hexgrad/Kokoro-82M --local-dir F:/LLM/Kokoro-82M") + return 1 + + print(f"\nLoading model 
from: {model_path}") + + from pygpukit.tts.kokoro import KokoroModel + + model = KokoroModel.from_pretrained(model_path, voice=args.voice) + model.print_info() + + print(f'\nSynthesizing: "{args.text}"') + start = time.perf_counter() + result = model.synthesize(args.text, voice=args.voice) + elapsed = time.perf_counter() - start + + # Phonemes may contain IPA characters that can't print on Windows cp932 + try: + print(f" Phonemes: {result.phonemes}") + except UnicodeEncodeError: + print(f" Phonemes: (contains IPA characters, {len(result.phonemes)} chars)") + print(f" Duration: {result.duration_sec:.2f} sec") + print(f" Synthesis time: {elapsed * 1000:.2f} ms") + print(f" RTF: {elapsed / result.duration_sec:.3f}x") + + result.to_wav(args.output) + print(f"\nAudio saved to: {args.output}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/native/CMakeLists.txt b/native/CMakeLists.txt index cfde596..e32db94 100644 --- a/native/CMakeLists.txt +++ b/native/CMakeLists.txt @@ -215,6 +215,7 @@ pybind11_add_module(${MODULE_NAME} bindings/nn/norm.cpp bindings/nn/attention.cpp bindings/nn/rope.cpp + bindings/nn/recurrent.cpp # Bindings - GEMM operations (by dtype combination) bindings/gemm/generic.cpp bindings/gemm/fp8xfp8_bf16.cpp diff --git a/native/bindings/bindings_common.hpp b/native/bindings/bindings_common.hpp index 1bd5f92..1ee0532 100644 --- a/native/bindings/bindings_common.hpp +++ b/native/bindings/bindings_common.hpp @@ -35,6 +35,7 @@ void init_nn_activation(py::module_& m); void init_nn_norm(py::module_& m); void init_nn_attention(py::module_& m); void init_nn_rope(py::module_& m); +void init_nn_recurrent(py::module_& m); void init_embedding_lookup(py::module_& m); void init_embedding_kv_cache(py::module_& m); diff --git a/native/bindings/nn/recurrent.cpp b/native/bindings/nn/recurrent.cpp new file mode 100644 index 0000000..25ed5a1 --- /dev/null +++ b/native/bindings/nn/recurrent.cpp @@ -0,0 +1,47 @@ +/** + * NN recurrent layers: LSTM 
+ */ +#include "../bindings_common.hpp" + +void init_nn_recurrent(py::module_& m) { + // LSTM forward (unidirectional) + m.def("lstm_forward", &ops::lstm_forward, + py::arg("x"), + py::arg("W_ih"), py::arg("W_hh"), + py::arg("b_ih"), py::arg("b_hh"), + py::arg("h0"), py::arg("c0"), + py::arg("reverse") = false, + "LSTM forward pass (unidirectional).\n\n" + "Args:\n" + " x: input [batch, seq_len, input_size]\n" + " W_ih: input-to-hidden weights [4*hidden_size, input_size]\n" + " W_hh: hidden-to-hidden weights [4*hidden_size, hidden_size]\n" + " b_ih: input bias [4*hidden_size]\n" + " b_hh: hidden bias [4*hidden_size]\n" + " h0: initial hidden state [batch, hidden_size] or empty\n" + " c0: initial cell state [batch, hidden_size] or empty\n" + " reverse: process sequence in reverse order\n\n" + "Returns:\n" + " tuple of (output, h_n, c_n)\n" + " output: [batch, seq_len, hidden_size]\n" + " h_n: [batch, hidden_size]\n" + " c_n: [batch, hidden_size]"); + + // LSTM bidirectional + m.def("lstm_bidirectional", &ops::lstm_bidirectional, + py::arg("x"), + py::arg("W_ih_fwd"), py::arg("W_hh_fwd"), + py::arg("b_ih_fwd"), py::arg("b_hh_fwd"), + py::arg("W_ih_bwd"), py::arg("W_hh_bwd"), + py::arg("b_ih_bwd"), py::arg("b_hh_bwd"), + "Bidirectional LSTM.\n\n" + "Args:\n" + " x: input [batch, seq_len, input_size]\n" + " W_ih_fwd, W_hh_fwd, b_ih_fwd, b_hh_fwd: forward LSTM weights\n" + " W_ih_bwd, W_hh_bwd, b_ih_bwd, b_hh_bwd: backward LSTM weights\n\n" + "Returns:\n" + " tuple of (output, h_n, c_n)\n" + " output: [batch, seq_len, 2*hidden_size] (concatenated fwd/bwd)\n" + " h_n: [2, batch, hidden_size]\n" + " c_n: [2, batch, hidden_size]"); +} diff --git a/native/bindings/ops_bindings.cpp b/native/bindings/ops_bindings.cpp index 1ffa95a..1db2ee4 100644 --- a/native/bindings/ops_bindings.cpp +++ b/native/bindings/ops_bindings.cpp @@ -1,76 +1,77 @@ -/** - * PyGPUkit Operations Bindings - Main Entry Point - * - * This file calls all init functions from the modular binding files. 
- * Each category is in its own subdirectory for better organization. - */ -#include "bindings_common.hpp" - -void init_ops_bindings(py::module_& m) { - // Elementwise operations - init_elementwise_binary(m); - init_elementwise_inplace(m); - init_elementwise_compare(m); - - // Unary operations - init_unary_math(m); - init_unary_trig(m); - - // Reduction operations - init_reduction_basic(m); - init_reduction_argmax(m); - init_reduction_softmax(m); - - // Tensor operations - init_tensor_cast(m); - init_tensor_transpose(m); - init_tensor_reshape(m); - init_tensor_repeat(m); - - // Neural network operations - init_nn_activation(m); - init_nn_norm(m); - init_nn_attention(m); - init_nn_rope(m); - - // Embedding operations - init_embedding_lookup(m); - init_embedding_kv_cache(m); - - // GEMM operations (by dtype combination) - init_gemm_generic(m); - init_gemm_fp8xfp8_bf16(m); - init_gemm_fp8xfp8_fp8(m); - init_gemm_fp8xbf16_bf16(m); - init_gemm_nvf4xbf16_bf16(m); - init_gemm_grouped(m); - init_gemm_int(m); - - // GEMV operations - init_gemv_generic(m); - init_gemv_fp8xfp8_bf16(m); - init_gemv_nvf4xbf16_bf16(m); - - // Sampling operations - init_sampling_basic(m); - init_sampling_topk(m); - init_sampling_seed(m); - - // Quantization operations - init_quantize(m); - - // Attention operations - init_paged_attention(m); - - // Continuous batching operations - init_continuous_batching(m); - - // Audio processing operations - init_audio(m); - - // cuBLASLt utility functions - init_cublaslt(m); - - // MoE (Mixture of Experts) operations - init_moe(m); -} +/** + * PyGPUkit Operations Bindings - Main Entry Point + * + * This file calls all init functions from the modular binding files. + * Each category is in its own subdirectory for better organization. 
+ */ +#include "bindings_common.hpp" + +void init_ops_bindings(py::module_& m) { + // Elementwise operations + init_elementwise_binary(m); + init_elementwise_inplace(m); + init_elementwise_compare(m); + + // Unary operations + init_unary_math(m); + init_unary_trig(m); + + // Reduction operations + init_reduction_basic(m); + init_reduction_argmax(m); + init_reduction_softmax(m); + + // Tensor operations + init_tensor_cast(m); + init_tensor_transpose(m); + init_tensor_reshape(m); + init_tensor_repeat(m); + + // Neural network operations + init_nn_activation(m); + init_nn_norm(m); + init_nn_attention(m); + init_nn_rope(m); + init_nn_recurrent(m); + + // Embedding operations + init_embedding_lookup(m); + init_embedding_kv_cache(m); + + // GEMM operations (by dtype combination) + init_gemm_generic(m); + init_gemm_fp8xfp8_bf16(m); + init_gemm_fp8xfp8_fp8(m); + init_gemm_fp8xbf16_bf16(m); + init_gemm_nvf4xbf16_bf16(m); + init_gemm_grouped(m); + init_gemm_int(m); + + // GEMV operations + init_gemv_generic(m); + init_gemv_fp8xfp8_bf16(m); + init_gemv_nvf4xbf16_bf16(m); + + // Sampling operations + init_sampling_basic(m); + init_sampling_topk(m); + init_sampling_seed(m); + + // Quantization operations + init_quantize(m); + + // Attention operations + init_paged_attention(m); + + // Continuous batching operations + init_continuous_batching(m); + + // Audio processing operations + init_audio(m); + + // cuBLASLt utility functions + init_cublaslt(m); + + // MoE (Mixture of Experts) operations + init_moe(m); +} diff --git a/native/ops/nn/nn.cu b/native/ops/nn/nn.cu index fe22915..05d7e77 100644 --- a/native/ops/nn/nn.cu +++ b/native/ops/nn/nn.cu @@ -36,3 +36,4 @@ #include "embedding/embedding.inl" #include "elementwise/inplace.inl" #include "cast/cast.inl" +#include "recurrent/lstm.inl" diff --git a/native/ops/nn/recurrent/lstm.inl b/native/ops/nn/recurrent/lstm.inl new file mode 100644 index 0000000..4063358 --- /dev/null +++ b/native/ops/nn/recurrent/lstm.inl @@ -0,0 +1,254 @@ 
/**
 * LSTM dispatch implementation
 *
 * Provides high-level LSTM operations:
 * - lstm_forward: unidirectional LSTM
 * - lstm_bidirectional: bidirectional LSTM
 *
 * NOTE: Uses kernel-based copies instead of cudaMemcpy for Driver API compatibility.
 */

#include "lstm_kernels.cuh"
#include "../../../core/cuda_graph.hpp"

namespace pygpukit {
namespace ops {

using namespace nn;

// ============================================================================
// LSTM Forward - Unidirectional
// ============================================================================

/**
 * LSTM forward pass (unidirectional).
 *
 * Args:
 *   x:       input [batch, seq_len, input_size], float32 only
 *   W_ih:    input-to-hidden weights [4*hidden_size, input_size]
 *   W_hh:    hidden-to-hidden weights [4*hidden_size, hidden_size]
 *   b_ih:    input bias [4*hidden_size]
 *   b_hh:    hidden bias [4*hidden_size]
 *   h0:      initial hidden [batch, hidden_size], or zero-sized for zeros
 *   c0:      initial cell   [batch, hidden_size], or zero-sized for zeros
 *   reverse: process the sequence back-to-front (used by the backward half
 *            of lstm_bidirectional)
 *
 * Returns:
 *   output: [batch, seq_len, hidden_size]
 *   h_n:    [batch, hidden_size]
 *   c_n:    [batch, hidden_size]
 */
std::tuple<GPUArray, GPUArray, GPUArray> lstm_forward(
    const GPUArray& x,
    const GPUArray& W_ih,
    const GPUArray& W_hh,
    const GPUArray& b_ih,
    const GPUArray& b_hh,
    const GPUArray& h0,
    const GPUArray& c0,
    bool reverse
) {
    // Validate inputs.
    if (x.ndim() != 3) {
        throw std::runtime_error("lstm_forward: x must be 3D [batch, seq_len, input_size]");
    }
    if (x.dtype() != DataType::Float32) {
        throw std::runtime_error("lstm_forward: only float32 supported currently");
    }

    const int batch_size  = static_cast<int>(x.shape()[0]);
    const int seq_len     = static_cast<int>(x.shape()[1]);
    const int input_size  = static_cast<int>(x.shape()[2]);
    const int hidden_size = static_cast<int>(W_hh.shape()[1]);

    // Allocate outputs.
    GPUArray output({static_cast<size_t>(batch_size), static_cast<size_t>(seq_len),
                     static_cast<size_t>(hidden_size)}, DataType::Float32);
    GPUArray h_n({static_cast<size_t>(batch_size), static_cast<size_t>(hidden_size)},
                 DataType::Float32);
    GPUArray c_n({static_cast<size_t>(batch_size), static_cast<size_t>(hidden_size)},
                 DataType::Float32);

    // Intermediates: all input-side gate pre-activations for every timestep,
    // plus double-buffered hidden/cell state.
    GPUArray gates({static_cast<size_t>(batch_size), static_cast<size_t>(seq_len),
                    static_cast<size_t>(4 * hidden_size)}, DataType::Float32);
    GPUArray h_curr({static_cast<size_t>(batch_size), static_cast<size_t>(hidden_size)},
                    DataType::Float32);
    GPUArray c_curr({static_cast<size_t>(batch_size), static_cast<size_t>(hidden_size)},
                    DataType::Float32);
    GPUArray h_next({static_cast<size_t>(batch_size), static_cast<size_t>(hidden_size)},
                    DataType::Float32);
    GPUArray c_next({static_cast<size_t>(batch_size), static_cast<size_t>(hidden_size)},
                    DataType::Float32);

    // Launch on the capture stream for CUDA Graph compatibility.
    cudaStream_t stream = internal::get_capture_stream();

    // Initialize h0, c0 (zero-sized input means "start from zeros").
    const int state_size = batch_size * hidden_size;
    const int block_init = 256;
    const int grid_init  = (state_size + block_init - 1) / block_init;

    if (h0.size() > 0) {
        copy_f32_kernel<<<grid_init, block_init, 0, stream>>>(
            static_cast<const float*>(h0.data()),
            static_cast<float*>(h_curr.data()), state_size);
    } else {
        zero_init_f32_kernel<<<grid_init, block_init, 0, stream>>>(
            static_cast<float*>(h_curr.data()), state_size);
    }

    if (c0.size() > 0) {
        copy_f32_kernel<<<grid_init, block_init, 0, stream>>>(
            static_cast<const float*>(c0.data()),
            static_cast<float*>(c_curr.data()), state_size);
    } else {
        zero_init_f32_kernel<<<grid_init, block_init, 0, stream>>>(
            static_cast<float*>(c_curr.data()), state_size);
    }

    // Precompute W_ih @ x_t + b_ih + b_hh for all (batch, timestep) at once.
    {
        const int gate_size = 4 * hidden_size;
        dim3 block(256);
        dim3 grid((gate_size + 255) / 256, seq_len, batch_size);

        lstm_precompute_gates_f32_kernel<<<grid, block, 0, stream>>>(
            static_cast<const float*>(x.data()),
            static_cast<const float*>(W_ih.data()),
            static_cast<const float*>(b_ih.data()),
            static_cast<const float*>(b_hh.data()),
            static_cast<float*>(gates.data()),
            batch_size, seq_len, input_size, hidden_size);
    }

    sync_and_check("lstm_precompute_gates failed");

    // Process the sequence one timestep at a time.
    dim3 block_step(256);
    dim3 grid_step((hidden_size + 255) / 256, batch_size);  // full-batch launches
    dim3 grid_one((hidden_size + 255) / 256, 1);            // single-batch launches

    for (int t = 0; t < seq_len; ++t) {
        const int seq_idx = reverse ? (seq_len - 1 - t) : t;

        // BUGFIX: `gates` is laid out [batch, seq_len, 4*hidden], but
        // lstm_step_f32_kernel indexes its gates argument with a batch stride
        // of 4*hidden (it expects a contiguous [batch, 4*hidden] slab).
        // Offsetting the base pointer by seq_idx alone made batch element b
        // at timestep t read the gates of timestep (b + t) of batch 0,
        // corrupting every result with batch_size > 1.  Launch the step
        // kernel once per batch element with per-batch base pointers so its
        // internal batch index 0 addresses the correct slice.  (Alternative:
        // pass seq_len into the kernel and fix its stride — that would touch
        // lstm_kernels.cuh as well.)
        for (int b = 0; b < batch_size; ++b) {
            const size_t gates_offset =
                (static_cast<size_t>(b) * seq_len + seq_idx) * 4 * hidden_size;
            const size_t state_offset = static_cast<size_t>(b) * hidden_size;

            lstm_step_f32_kernel<<<grid_one, block_step, 0, stream>>>(
                static_cast<const float*>(gates.data()) + gates_offset,
                static_cast<const float*>(h_curr.data()) + state_offset,
                static_cast<const float*>(c_curr.data()) + state_offset,
                static_cast<const float*>(W_hh.data()),
                static_cast<float*>(h_next.data()) + state_offset,
                static_cast<float*>(c_next.data()) + state_offset,
                /*batch_size=*/1, hidden_size);
        }

        // Scatter h_next[batch, hidden] into output[:, seq_idx, :] (strided copy).
        lstm_copy_to_output_f32_kernel<<<grid_step, block_step, 0, stream>>>(
            static_cast<const float*>(h_next.data()),
            static_cast<float*>(output.data()),
            batch_size, seq_len, hidden_size, seq_idx);

        // Ping-pong the state buffers; only device pointers swap, no copies.
        std::swap(h_curr, h_next);
        std::swap(c_curr, c_next);
    }

    sync_and_check("lstm_forward failed");

    // Copy final states out via kernel (Driver API compatibility — no cudaMemcpy).
    copy_f32_kernel<<<grid_init, block_init, 0, stream>>>(
        static_cast<const float*>(h_curr.data()),
        static_cast<float*>(h_n.data()), state_size);
    copy_f32_kernel<<<grid_init, block_init, 0, stream>>>(
        static_cast<const float*>(c_curr.data()),
        static_cast<float*>(c_n.data()), state_size);

    sync_and_check("lstm_forward final copy failed");

    return std::make_tuple(std::move(output), std::move(h_n), std::move(c_n));
}

// ============================================================================
// LSTM Bidirectional
// ============================================================================

/**
 * Bidirectional LSTM.
+ * + * Args: + * x: input [batch, seq_len, input_size] + * W_ih_fwd, W_hh_fwd, b_ih_fwd, b_hh_fwd: forward LSTM weights + * W_ih_bwd, W_hh_bwd, b_ih_bwd, b_hh_bwd: backward LSTM weights + * + * Returns: + * output: [batch, seq_len, 2*hidden_size] (concatenated forward and backward) + * h_n: [2, batch, hidden_size] + * c_n: [2, batch, hidden_size] + */ +std::tuple lstm_bidirectional( + const GPUArray& x, + const GPUArray& W_ih_fwd, const GPUArray& W_hh_fwd, + const GPUArray& b_ih_fwd, const GPUArray& b_hh_fwd, + const GPUArray& W_ih_bwd, const GPUArray& W_hh_bwd, + const GPUArray& b_ih_bwd, const GPUArray& b_hh_bwd +) { + int batch_size = static_cast(x.shape()[0]); + int seq_len = static_cast(x.shape()[1]); + int hidden_size = static_cast(W_hh_fwd.shape()[1]); + + // Get stream for CUDA Graph compatibility + cudaStream_t stream = internal::get_capture_stream(); + + // Empty initial states (zero-sized arrays) + GPUArray empty_h0_fwd({0}, DataType::Float32); + GPUArray empty_c0_fwd({0}, DataType::Float32); + GPUArray empty_h0_bwd({0}, DataType::Float32); + GPUArray empty_c0_bwd({0}, DataType::Float32); + + // Forward pass + auto [out_fwd, h_fwd, c_fwd] = lstm_forward( + x, W_ih_fwd, W_hh_fwd, b_ih_fwd, b_hh_fwd, empty_h0_fwd, empty_c0_fwd, false); + + // Backward pass + auto [out_bwd, h_bwd, c_bwd] = lstm_forward( + x, W_ih_bwd, W_hh_bwd, b_ih_bwd, b_hh_bwd, empty_h0_bwd, empty_c0_bwd, true); + + // Concatenate outputs: [batch, seq_len, 2*hidden] + GPUArray output({static_cast(batch_size), static_cast(seq_len), static_cast(2 * hidden_size)}, DataType::Float32); + + // Use concatenation kernel (single kernel launch instead of nested loops) + { + dim3 block(256); + dim3 grid((hidden_size + 255) / 256, seq_len, batch_size); + + lstm_concat_bidirectional_f32_kernel<<>>( + static_cast(out_fwd.data()), + static_cast(out_bwd.data()), + static_cast(output.data()), + batch_size, seq_len, hidden_size); + } + + // Stack h_n, c_n: [2, batch, hidden] + GPUArray h_n({2, 
static_cast(batch_size), static_cast(hidden_size)}, DataType::Float32); + GPUArray c_n({2, static_cast(batch_size), static_cast(hidden_size)}, DataType::Float32); + + int state_size = batch_size * hidden_size; + int block_copy = 256; + int grid_copy = (state_size + block_copy - 1) / block_copy; + + // Copy h_n[0] = h_fwd, h_n[1] = h_bwd + copy_f32_kernel<<>>( + static_cast(h_fwd.data()), + static_cast(h_n.data()), state_size); + copy_f32_kernel<<>>( + static_cast(h_bwd.data()), + static_cast(h_n.data()) + state_size, state_size); + + // Copy c_n[0] = c_fwd, c_n[1] = c_bwd + copy_f32_kernel<<>>( + static_cast(c_fwd.data()), + static_cast(c_n.data()), state_size); + copy_f32_kernel<<>>( + static_cast(c_bwd.data()), + static_cast(c_n.data()) + state_size, state_size); + + sync_and_check("lstm_bidirectional failed"); + + return std::make_tuple(std::move(output), std::move(h_n), std::move(c_n)); +} + +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/nn/recurrent/lstm_kernels.cuh b/native/ops/nn/recurrent/lstm_kernels.cuh new file mode 100644 index 0000000..e2e4864 --- /dev/null +++ b/native/ops/nn/recurrent/lstm_kernels.cuh @@ -0,0 +1,370 @@ +/** + * LSTM Kernel Definitions + * + * Implements LSTM cell computation for TTS and other sequence models. + * Supports unidirectional and bidirectional modes. 
+ * + * LSTM equations: + * i_t = sigmoid(W_ii @ x_t + b_ii + W_hi @ h_{t-1} + b_hi) + * f_t = sigmoid(W_if @ x_t + b_if + W_hf @ h_{t-1} + b_hf) + * g_t = tanh(W_ig @ x_t + b_ig + W_hg @ h_{t-1} + b_hg) + * o_t = sigmoid(W_io @ x_t + b_io + W_ho @ h_{t-1} + b_ho) + * c_t = f_t * c_{t-1} + i_t * g_t + * h_t = o_t * tanh(c_t) + * + * PyTorch packs weights as: + * W_ih: [4*hidden_size, input_size] (i, f, g, o gates) + * W_hh: [4*hidden_size, hidden_size] + * b_ih: [4*hidden_size] + * b_hh: [4*hidden_size] + */ + +#pragma once + +#include +#include +#include +#include + +namespace pygpukit { +namespace ops { +namespace nn { + +// ============================================================================ +// Device functions (prefixed to avoid collision with activation_kernels.cuh) +// ============================================================================ + +__device__ __forceinline__ float lstm_sigmoid(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +__device__ __forceinline__ float lstm_tanh(float x) { + return tanhf(x); +} + +// ============================================================================ +// LSTM Cell Kernel - Single timestep, single batch element +// ============================================================================ + +/** + * Compute LSTM gates for a single timestep. 
+ * + * Input: + * gates_precomputed: W_ih @ x_t + b_ih + b_hh [4*hidden_size] + * h_prev: previous hidden state [hidden_size] + * c_prev: previous cell state [hidden_size] + * W_hh: hidden-to-hidden weights [4*hidden_size, hidden_size] + * + * Output: + * h_out: new hidden state [hidden_size] + * c_out: new cell state [hidden_size] + */ +__global__ void lstm_cell_f32_kernel( + const float* __restrict__ gates_precomputed, // [batch, 4*hidden] + const float* __restrict__ h_prev, // [batch, hidden] + const float* __restrict__ c_prev, // [batch, hidden] + const float* __restrict__ W_hh, // [4*hidden, hidden] + float* __restrict__ h_out, // [batch, hidden] + float* __restrict__ c_out, // [batch, hidden] + int batch_size, + int hidden_size +) { + int batch_idx = blockIdx.y; + int h_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (batch_idx >= batch_size || h_idx >= hidden_size) return; + + // Compute W_hh @ h_prev contribution for this hidden unit + float gate_i = gates_precomputed[batch_idx * 4 * hidden_size + h_idx]; + float gate_f = gates_precomputed[batch_idx * 4 * hidden_size + hidden_size + h_idx]; + float gate_g = gates_precomputed[batch_idx * 4 * hidden_size + 2 * hidden_size + h_idx]; + float gate_o = gates_precomputed[batch_idx * 4 * hidden_size + 3 * hidden_size + h_idx]; + + // Add W_hh @ h_prev + for (int k = 0; k < hidden_size; ++k) { + float h_k = h_prev[batch_idx * hidden_size + k]; + gate_i += W_hh[h_idx * hidden_size + k] * h_k; + gate_f += W_hh[(hidden_size + h_idx) * hidden_size + k] * h_k; + gate_g += W_hh[(2 * hidden_size + h_idx) * hidden_size + k] * h_k; + gate_o += W_hh[(3 * hidden_size + h_idx) * hidden_size + k] * h_k; + } + + // Apply activations + float i = lstm_sigmoid(gate_i); + float f = lstm_sigmoid(gate_f); + float g = lstm_tanh(gate_g); + float o = lstm_sigmoid(gate_o); + + // Update cell state + float c_prev_val = c_prev[batch_idx * hidden_size + h_idx]; + float c_new = f * c_prev_val + i * g; + + // Compute hidden state + float 
h_new = o * lstm_tanh(c_new); + + // Store outputs + c_out[batch_idx * hidden_size + h_idx] = c_new; + h_out[batch_idx * hidden_size + h_idx] = h_new; +} + +// ============================================================================ +// Optimized LSTM Cell - Uses shared memory for W_hh @ h_prev +// ============================================================================ + +template +__global__ void lstm_cell_tiled_f32_kernel( + const float* __restrict__ gates_precomputed, // [batch, 4*hidden] + const float* __restrict__ h_prev, // [batch, hidden] + const float* __restrict__ c_prev, // [batch, hidden] + const float* __restrict__ W_hh, // [4*hidden, hidden] + float* __restrict__ h_out, // [batch, hidden] + float* __restrict__ c_out, // [batch, hidden] + int batch_size, + int hidden_size +) { + extern __shared__ float smem[]; + float* h_shared = smem; // [TILE_K] + + int batch_idx = blockIdx.y; + int h_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (batch_idx >= batch_size) return; + + // Initialize gate accumulators from precomputed values + float gate_i = 0.0f, gate_f = 0.0f, gate_g = 0.0f, gate_o = 0.0f; + + if (h_idx < hidden_size) { + gate_i = gates_precomputed[batch_idx * 4 * hidden_size + h_idx]; + gate_f = gates_precomputed[batch_idx * 4 * hidden_size + hidden_size + h_idx]; + gate_g = gates_precomputed[batch_idx * 4 * hidden_size + 2 * hidden_size + h_idx]; + gate_o = gates_precomputed[batch_idx * 4 * hidden_size + 3 * hidden_size + h_idx]; + } + + // Tiled computation of W_hh @ h_prev + for (int tile_start = 0; tile_start < hidden_size; tile_start += TILE_K) { + // Load h_prev tile to shared memory + int load_idx = tile_start + threadIdx.x; + if (threadIdx.x < TILE_K && load_idx < hidden_size) { + h_shared[threadIdx.x] = h_prev[batch_idx * hidden_size + load_idx]; + } + __syncthreads(); + + // Compute partial sums + if (h_idx < hidden_size) { + int tile_end = min(TILE_K, hidden_size - tile_start); + for (int k = 0; k < tile_end; ++k) { + float 
h_k = h_shared[k]; + int k_global = tile_start + k; + gate_i += W_hh[h_idx * hidden_size + k_global] * h_k; + gate_f += W_hh[(hidden_size + h_idx) * hidden_size + k_global] * h_k; + gate_g += W_hh[(2 * hidden_size + h_idx) * hidden_size + k_global] * h_k; + gate_o += W_hh[(3 * hidden_size + h_idx) * hidden_size + k_global] * h_k; + } + } + __syncthreads(); + } + + if (h_idx >= hidden_size) return; + + // Apply activations + float i = lstm_sigmoid(gate_i); + float f = lstm_sigmoid(gate_f); + float g = lstm_tanh(gate_g); + float o = lstm_sigmoid(gate_o); + + // Update cell state + float c_prev_val = c_prev[batch_idx * hidden_size + h_idx]; + float c_new = f * c_prev_val + i * g; + + // Compute hidden state + float h_new = o * lstm_tanh(c_new); + + // Store outputs + c_out[batch_idx * hidden_size + h_idx] = c_new; + h_out[batch_idx * hidden_size + h_idx] = h_new; +} + +// ============================================================================ +// LSTM Forward - Process full sequence +// ============================================================================ + +/** + * LSTM forward pass for full sequence. + * + * Processes sequence timestep by timestep. + * For bidirectional, call twice (forward and reverse). 
+ * + * Input: + * x: input sequence [batch, seq_len, input_size] + * W_ih: input-to-hidden weights [4*hidden_size, input_size] + * W_hh: hidden-to-hidden weights [4*hidden_size, hidden_size] + * b_ih: input bias [4*hidden_size] + * b_hh: hidden bias [4*hidden_size] + * h0: initial hidden state [batch, hidden_size] (can be nullptr for zeros) + * c0: initial cell state [batch, hidden_size] (can be nullptr for zeros) + * reverse: if true, process sequence in reverse order + * + * Output: + * output: hidden states for all timesteps [batch, seq_len, hidden_size] + * h_n: final hidden state [batch, hidden_size] + * c_n: final cell state [batch, hidden_size] + */ + +// Kernel to precompute W_ih @ x + b_ih + b_hh for all timesteps +// This is a batched GEMM: [4*H, I] @ [B, S, I]^T -> [B, S, 4*H] +__global__ void lstm_precompute_gates_f32_kernel( + const float* __restrict__ x, // [batch, seq_len, input_size] + const float* __restrict__ W_ih, // [4*hidden, input_size] + const float* __restrict__ b_ih, // [4*hidden] + const float* __restrict__ b_hh, // [4*hidden] + float* __restrict__ gates, // [batch, seq_len, 4*hidden] + int batch_size, + int seq_len, + int input_size, + int hidden_size +) { + int batch_idx = blockIdx.z; + int seq_idx = blockIdx.y; + int gate_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (batch_idx >= batch_size || seq_idx >= seq_len || gate_idx >= 4 * hidden_size) return; + + // Compute W_ih @ x[batch, seq, :] + float sum = 0.0f; + const float* x_ptr = x + batch_idx * seq_len * input_size + seq_idx * input_size; + const float* w_ptr = W_ih + gate_idx * input_size; + + for (int i = 0; i < input_size; ++i) { + sum += w_ptr[i] * x_ptr[i]; + } + + // Add biases + sum += b_ih[gate_idx] + b_hh[gate_idx]; + + // Store + gates[batch_idx * seq_len * 4 * hidden_size + seq_idx * 4 * hidden_size + gate_idx] = sum; +} + +// Fused LSTM cell that operates on precomputed gates +__global__ void lstm_step_f32_kernel( + const float* __restrict__ gates, // [batch, 
4*hidden] precomputed for this timestep + const float* __restrict__ h_prev, // [batch, hidden] + const float* __restrict__ c_prev, // [batch, hidden] + const float* __restrict__ W_hh, // [4*hidden, hidden] + float* __restrict__ h_out, // [batch, hidden] + float* __restrict__ c_out, // [batch, hidden] + int batch_size, + int hidden_size +) { + int batch_idx = blockIdx.y; + int h_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (batch_idx >= batch_size || h_idx >= hidden_size) return; + + // Load precomputed gates + int base = batch_idx * 4 * hidden_size; + float gate_i = gates[base + h_idx]; + float gate_f = gates[base + hidden_size + h_idx]; + float gate_g = gates[base + 2 * hidden_size + h_idx]; + float gate_o = gates[base + 3 * hidden_size + h_idx]; + + // Add W_hh @ h_prev contribution + const float* W_hh_i = W_hh + h_idx * hidden_size; + const float* W_hh_f = W_hh + (hidden_size + h_idx) * hidden_size; + const float* W_hh_g = W_hh + (2 * hidden_size + h_idx) * hidden_size; + const float* W_hh_o = W_hh + (3 * hidden_size + h_idx) * hidden_size; + const float* h_ptr = h_prev + batch_idx * hidden_size; + + for (int k = 0; k < hidden_size; ++k) { + float h_k = h_ptr[k]; + gate_i += W_hh_i[k] * h_k; + gate_f += W_hh_f[k] * h_k; + gate_g += W_hh_g[k] * h_k; + gate_o += W_hh_o[k] * h_k; + } + + // Apply activations + float i = lstm_sigmoid(gate_i); + float f = lstm_sigmoid(gate_f); + float g = lstm_tanh(gate_g); + float o = lstm_sigmoid(gate_o); + + // Update states + float c_new = f * c_prev[batch_idx * hidden_size + h_idx] + i * g; + float h_new = o * lstm_tanh(c_new); + + c_out[batch_idx * hidden_size + h_idx] = c_new; + h_out[batch_idx * hidden_size + h_idx] = h_new; +} + +// Zero initialization kernel +__global__ void zero_init_f32_kernel(float* data, int size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + data[idx] = 0.0f; + } +} + +// Simple copy kernel (replaces cudaMemcpy DtoD) +__global__ void copy_f32_kernel( + const float* 
__restrict__ src, + float* __restrict__ dst, + int size +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + dst[idx] = src[idx]; + } +} + +// Strided copy kernel for LSTM output +// Copies h_next[batch, hidden] to output[batch, seq_idx, hidden] +// output layout: [batch, seq_len, hidden_size] +__global__ void lstm_copy_to_output_f32_kernel( + const float* __restrict__ h_next, // [batch, hidden_size] + float* __restrict__ output, // [batch, seq_len, hidden_size] + int batch_size, + int seq_len, + int hidden_size, + int seq_idx +) { + int batch_idx = blockIdx.y; + int h_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (batch_idx >= batch_size || h_idx >= hidden_size) return; + + // src: h_next[batch_idx, h_idx] + int src_offset = batch_idx * hidden_size + h_idx; + // dst: output[batch_idx, seq_idx, h_idx] + int dst_offset = batch_idx * seq_len * hidden_size + seq_idx * hidden_size + h_idx; + + output[dst_offset] = h_next[src_offset]; +} + +// Concatenation kernel for bidirectional LSTM output +// Copies fwd[batch, seq, hidden] and bwd[batch, seq, hidden] to output[batch, seq, 2*hidden] +__global__ void lstm_concat_bidirectional_f32_kernel( + const float* __restrict__ fwd_out, // [batch, seq_len, hidden_size] + const float* __restrict__ bwd_out, // [batch, seq_len, hidden_size] + float* __restrict__ output, // [batch, seq_len, 2*hidden_size] + int batch_size, + int seq_len, + int hidden_size +) { + int batch_idx = blockIdx.z; + int seq_idx = blockIdx.y; + int h_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (batch_idx >= batch_size || seq_idx >= seq_len || h_idx >= hidden_size) return; + + int src_offset = batch_idx * seq_len * hidden_size + seq_idx * hidden_size + h_idx; + int dst_offset_fwd = batch_idx * seq_len * 2 * hidden_size + seq_idx * 2 * hidden_size + h_idx; + int dst_offset_bwd = dst_offset_fwd + hidden_size; + + output[dst_offset_fwd] = fwd_out[src_offset]; + output[dst_offset_bwd] = bwd_out[src_offset]; +} + +} // 
namespace nn +} // namespace ops +} // namespace pygpukit diff --git a/native/ops/ops.cuh b/native/ops/ops.cuh index bf58f9e..01a55b1 100644 --- a/native/ops/ops.cuh +++ b/native/ops/ops.cuh @@ -10,6 +10,7 @@ #pragma once #include "../core/memory.hpp" +#include namespace pygpukit { namespace ops { @@ -525,5 +526,37 @@ int sample_token_gpu( // Set random seed for reproducible sampling void set_sampling_seed(unsigned int seed); +// ============================================================================ +// LSTM (Long Short-Term Memory) +// ============================================================================ + +// LSTM forward pass (unidirectional) +// x: [batch, seq_len, input_size] +// W_ih: [4*hidden_size, input_size], W_hh: [4*hidden_size, hidden_size] +// b_ih, b_hh: [4*hidden_size] +// h0, c0: [batch, hidden_size] or empty for zeros +// reverse: process sequence in reverse order +// Returns: (output[batch, seq_len, hidden], h_n[batch, hidden], c_n[batch, hidden]) +std::tuple lstm_forward( + const GPUArray& x, + const GPUArray& W_ih, + const GPUArray& W_hh, + const GPUArray& b_ih, + const GPUArray& b_hh, + const GPUArray& h0, + const GPUArray& c0, + bool reverse = false +); + +// Bidirectional LSTM +// Returns: (output[batch, seq_len, 2*hidden], h_n[2, batch, hidden], c_n[2, batch, hidden]) +std::tuple lstm_bidirectional( + const GPUArray& x, + const GPUArray& W_ih_fwd, const GPUArray& W_hh_fwd, + const GPUArray& b_ih_fwd, const GPUArray& b_hh_fwd, + const GPUArray& W_ih_bwd, const GPUArray& W_hh_bwd, + const GPUArray& b_ih_bwd, const GPUArray& b_hh_bwd +); + } // namespace ops } // namespace pygpukit diff --git a/src/pygpukit/__init__.py b/src/pygpukit/__init__.py index 42553f8..cf8e7b1 100644 --- a/src/pygpukit/__init__.py +++ b/src/pygpukit/__init__.py @@ -53,6 +53,8 @@ layernorm, linear_bias_gelu, log, + lstm_bidirectional, + lstm_forward, matmul, max, mean, @@ -164,6 +166,8 @@ "gelu", "layernorm", "log", + "lstm_bidirectional", + "lstm_forward", 
def lstm_forward(
    x: GPUArray,
    W_ih: GPUArray,
    W_hh: GPUArray,
    b_ih: GPUArray,
    b_hh: GPUArray,
    h0: GPUArray | None = None,
    c0: GPUArray | None = None,
    reverse: bool = False,
) -> tuple[GPUArray, GPUArray, GPUArray]:
    """Run a single-direction LSTM over a packed sequence.

    Standard LSTM cell (PyTorch gate order i, f, g, o):
        i_t = sigmoid(W_ii x_t + b_ii + W_hi h_{t-1} + b_hi)
        f_t = sigmoid(W_if x_t + b_if + W_hf h_{t-1} + b_hf)
        g_t = tanh   (W_ig x_t + b_ig + W_hg h_{t-1} + b_hg)
        o_t = sigmoid(W_io x_t + b_io + W_ho h_{t-1} + b_ho)
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)

    Args:
        x: Input sequence [batch, seq_len, input_size].
        W_ih: Input-to-hidden weights [4*hidden_size, input_size].
        W_hh: Hidden-to-hidden weights [4*hidden_size, hidden_size].
        b_ih: Input bias [4*hidden_size].
        b_hh: Hidden bias [4*hidden_size].
        h0: Optional initial hidden state [batch, hidden_size]; zeros when None.
        c0: Optional initial cell state [batch, hidden_size]; zeros when None.
        reverse: Process the sequence back-to-front (used for the backward
            half of a bidirectional LSTM).

    Returns:
        Tuple of (output, h_n, c_n):
            output: Hidden states [batch, seq_len, hidden_size]
            h_n: Final hidden state [batch, hidden_size]
            c_n: Final cell state [batch, hidden_size]

    Raises:
        RuntimeError: If no GPU backend is available.
    """
    from pygpukit.core.backend import get_backend, get_native_module

    if not get_backend().is_available():
        raise RuntimeError("lstm_forward requires GPU backend")

    native = get_native_module()

    def _state_handle(state: GPUArray | None):
        # The native layer interprets a zero-length array as "start from zeros".
        if state is None:
            return native.GPUArray([0], native.Float32)
        return state._get_native()

    raw = native.lstm_forward(
        x._get_native(),
        W_ih._get_native(),
        W_hh._get_native(),
        b_ih._get_native(),
        b_hh._get_native(),
        _state_handle(h0),
        _state_handle(c0),
        reverse,
    )
    output, h_n, c_n = (GPUArray._wrap_native(r) for r in raw)
    return output, h_n, c_n


def lstm_bidirectional(
    x: GPUArray,
    W_ih_fwd: GPUArray,
    W_hh_fwd: GPUArray,
    b_ih_fwd: GPUArray,
    b_hh_fwd: GPUArray,
    W_ih_bwd: GPUArray,
    W_hh_bwd: GPUArray,
    b_ih_bwd: GPUArray,
    b_hh_bwd: GPUArray,
) -> tuple[GPUArray, GPUArray, GPUArray]:
    """Run forward and backward LSTM sweeps and concatenate per-step outputs.

    Args:
        x: Input sequence [batch, seq_len, input_size].
        W_ih_fwd, W_hh_fwd, b_ih_fwd, b_hh_fwd: Forward-direction weights.
        W_ih_bwd, W_hh_bwd, b_ih_bwd, b_hh_bwd: Backward-direction weights.

    Returns:
        Tuple of (output, h_n, c_n):
            output: Concatenated hidden states [batch, seq_len, 2*hidden_size]
            h_n: Stacked final hidden states [2, batch, hidden_size]
            c_n: Stacked final cell states [2, batch, hidden_size]

    Raises:
        RuntimeError: If no GPU backend is available.
    """
    from pygpukit.core.backend import get_backend, get_native_module

    if not get_backend().is_available():
        raise RuntimeError("lstm_bidirectional requires GPU backend")

    native = get_native_module()
    handles = (
        t._get_native()
        for t in (
            x,
            W_ih_fwd, W_hh_fwd, b_ih_fwd, b_hh_fwd,
            W_ih_bwd, W_hh_bwd, b_ih_bwd, b_hh_bwd,
        )
    )
    output, h_n, c_n = (
        GPUArray._wrap_native(r) for r in native.lstm_bidirectional(*handles)
    )
    return output, h_n, c_n
+ +Supported Models: + - Kokoro-82M: StyleTTS2-based model with 82M parameters + +Example: + >>> from pygpukit.tts import KokoroModel + >>> model = KokoroModel.from_pretrained("hexgrad/Kokoro-82M") + >>> audio = model.synthesize("Hello, this is PyGPUkit TTS!") + >>> audio.to_wav("output.wav") +""" + +from pygpukit.tts.kokoro import ( + KokoroConfig, + KokoroModel, + KokoroTokenizer, + SynthesisResult, + concatenate_audio, + from_wav, + list_available_voices, + load_kokoro_weights, + load_voice_embedding, + resample_audio, + to_wav, +) + +__all__ = [ + # Model + "KokoroModel", + "SynthesisResult", + # Config + "KokoroConfig", + # Tokenizer + "KokoroTokenizer", + # Loader + "load_kokoro_weights", + "load_voice_embedding", + "list_available_voices", + # Audio + "to_wav", + "from_wav", + "resample_audio", + "concatenate_audio", +] diff --git a/src/pygpukit/tts/kokoro/__init__.py b/src/pygpukit/tts/kokoro/__init__.py new file mode 100644 index 0000000..a905a73 --- /dev/null +++ b/src/pygpukit/tts/kokoro/__init__.py @@ -0,0 +1,41 @@ +"""Kokoro-82M TTS model implementation. + +Kokoro is a StyleTTS2-based text-to-speech model with 82M parameters. +It achieves high-quality speech synthesis with a compact architecture. 
+ +Example: + >>> from pygpukit.tts.kokoro import KokoroModel + >>> model = KokoroModel.from_pretrained("hexgrad/Kokoro-82M") + >>> audio = model.synthesize("Hello, world!") + >>> audio.to_wav("output.wav") +""" + +from pygpukit.tts.kokoro.audio import concatenate_audio, from_wav, resample_audio, to_wav +from pygpukit.tts.kokoro.config import KokoroConfig +from pygpukit.tts.kokoro.loader import ( + list_available_voices, + load_kokoro_weights, + load_voice_embedding, +) +from pygpukit.tts.kokoro.model import KokoroModel, SynthesisResult +from pygpukit.tts.kokoro.text import KokoroTokenizer, TokenizerOutput + +__all__ = [ + # Model + "KokoroModel", + "SynthesisResult", + # Config + "KokoroConfig", + # Tokenizer + "KokoroTokenizer", + "TokenizerOutput", + # Loader + "load_kokoro_weights", + "load_voice_embedding", + "list_available_voices", + # Audio + "to_wav", + "from_wav", + "resample_audio", + "concatenate_audio", +] diff --git a/src/pygpukit/tts/kokoro/audio.py b/src/pygpukit/tts/kokoro/audio.py new file mode 100644 index 0000000..6698923 --- /dev/null +++ b/src/pygpukit/tts/kokoro/audio.py @@ -0,0 +1,255 @@ +"""Audio utilities for Kokoro TTS. + +Provides: +- WAV file export +- Audio format conversion +- Playback utilities +""" + +from __future__ import annotations + +import struct +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np + +from pygpukit.core.array import GPUArray + +if TYPE_CHECKING: + from pygpukit.ops.audio import AudioBuffer + + +def to_wav( + audio: AudioBuffer | GPUArray | np.ndarray, + path: str | Path, + sample_rate: int = 24000, + normalize: bool = True, +) -> None: + """Export audio to WAV file. + + Writes a standard 16-bit PCM WAV file. 
+ + Args: + audio: Audio data (AudioBuffer, GPUArray, or numpy array) + path: Output file path + sample_rate: Sample rate in Hz (default: 24000 for Kokoro) + normalize: Whether to normalize audio to prevent clipping + + Example: + >>> from pygpukit.tts.kokoro import KokoroModel + >>> model = KokoroModel.from_pretrained("hexgrad/Kokoro-82M") + >>> result = model.synthesize("Hello!") + >>> to_wav(result.audio, "output.wav") + """ + # Convert to numpy array + if hasattr(audio, "data") and hasattr(audio, "sample_rate"): + # AudioBuffer + samples = audio.data.to_numpy() # type: ignore + sample_rate = audio.sample_rate # type: ignore + elif isinstance(audio, GPUArray): + samples = audio.to_numpy() + elif isinstance(audio, np.ndarray): + samples = audio + else: + raise TypeError(f"Unsupported audio type: {type(audio)}") + + # Ensure float32 + samples = samples.astype(np.float32) + + # Flatten if needed + if samples.ndim > 1: + samples = samples.flatten() + + # Normalize to prevent clipping + if normalize: + max_val = np.abs(samples).max() + if max_val > 0: + samples = samples / max_val * 0.95 + + # Convert to 16-bit PCM + samples_int16: np.ndarray = (samples * 32767).astype(np.int16) + + # Write WAV file + path = Path(path) + with open(path, "wb") as f: + # RIFF header + f.write(b"RIFF") + # File size (will be filled later) + file_size_pos = f.tell() + f.write(struct.pack(" tuple[np.ndarray, int]: + """Load audio from WAV file. 
+ + Args: + path: Path to WAV file + + Returns: + Tuple of (samples as float32, sample_rate) + + Example: + >>> samples, sr = from_wav("input.wav") + >>> print(f"Duration: {len(samples) / sr:.2f}s") + """ + path = Path(path) + + with open(path, "rb") as f: + # Read RIFF header + riff = f.read(4) + if riff != b"RIFF": + raise ValueError("Not a valid WAV file (missing RIFF header)") + + f.read(4) # File size + wave = f.read(4) + if wave != b"WAVE": + raise ValueError("Not a valid WAV file (missing WAVE header)") + + # Read chunks + sample_rate = 44100 + num_channels = 1 + bits_per_sample = 16 + audio_data = None + + while True: + chunk_id = f.read(4) + if len(chunk_id) < 4: + break + + chunk_size = struct.unpack(" 0: + f.read(extra) + + if audio_format != 1: + raise ValueError(f"Unsupported audio format: {audio_format}") + + elif chunk_id == b"data": + audio_data = f.read(chunk_size) + + else: + # Skip unknown chunks + f.read(chunk_size) + + if audio_data is None: + raise ValueError("No audio data found in WAV file") + + # Convert to numpy array + if bits_per_sample == 16: + samples = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32767 + elif bits_per_sample == 8: + samples = (np.frombuffer(audio_data, dtype=np.uint8).astype(np.float32) - 128) / 128 + elif bits_per_sample == 32: + samples = np.frombuffer(audio_data, dtype=np.int32).astype(np.float32) / 2147483647 + else: + raise ValueError(f"Unsupported bits per sample: {bits_per_sample}") + + # Convert stereo to mono if needed + if num_channels == 2: + samples = samples.reshape(-1, 2).mean(axis=1) + elif num_channels > 2: + samples = samples.reshape(-1, num_channels).mean(axis=1) + + return samples, sample_rate + + +def resample_audio( + samples: np.ndarray, + orig_sr: int, + target_sr: int, +) -> np.ndarray: + """Resample audio to target sample rate. + + Simple linear interpolation resampling. + For high-quality resampling, use scipy or librosa. 
+ + Args: + samples: Audio samples + orig_sr: Original sample rate + target_sr: Target sample rate + + Returns: + Resampled audio + """ + if orig_sr == target_sr: + return samples + + # Calculate new length + duration = len(samples) / orig_sr + new_length = int(duration * target_sr) + + # Linear interpolation + old_indices = np.linspace(0, len(samples) - 1, new_length) + new_samples = np.interp(old_indices, np.arange(len(samples)), samples) + + return new_samples.astype(np.float32) + + +def concatenate_audio( + audio_list: list[np.ndarray | GPUArray], + gap_samples: int = 0, +) -> np.ndarray: + """Concatenate multiple audio segments. + + Args: + audio_list: List of audio arrays + gap_samples: Number of silence samples between segments + + Returns: + Concatenated audio + """ + segments = [] + for audio in audio_list: + if isinstance(audio, GPUArray): + audio = audio.to_numpy() + segments.append(audio.flatten()) + + if gap_samples > 0: + segments.append(np.zeros(gap_samples, dtype=np.float32)) + + # Remove trailing gap + if gap_samples > 0 and segments: + segments = segments[:-1] + + return np.concatenate(segments) if segments else np.array([], dtype=np.float32) + + +__all__ = [ + "to_wav", + "from_wav", + "resample_audio", + "concatenate_audio", +] diff --git a/src/pygpukit/tts/kokoro/config.py b/src/pygpukit/tts/kokoro/config.py new file mode 100644 index 0000000..d1fda62 --- /dev/null +++ b/src/pygpukit/tts/kokoro/config.py @@ -0,0 +1,233 @@ +"""Kokoro-82M TTS model configuration. + +Kokoro is a StyleTTS2-based TTS model with 82M parameters. +Architecture: PLBERT -> Style Encoder -> Decoder -> ISTFTNet Vocoder +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class KokoroConfig: + """Configuration for Kokoro-82M TTS model. 
"""Kokoro-82M TTS model configuration.

Kokoro is a StyleTTS2-based TTS model with 82M parameters.
Architecture: PLBERT -> Style Encoder -> Decoder -> ISTFTNet Vocoder
"""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any


@dataclass
class KokoroConfig:
    """Configuration for the Kokoro-82M TTS model.

    Field names mirror the upstream ``config.json`` layout; the nested
    "istftnet" and "plbert" sections are flattened with prefixes.
    """

    # Core dimensions
    dim_in: int = 64
    hidden_dim: int = 512
    style_dim: int = 128
    n_mels: int = 80
    n_layer: int = 3
    n_token: int = 178
    max_dur: int = 50
    dropout: float = 0.2
    max_conv_dim: int = 512
    text_encoder_kernel_size: int = 5
    multispeaker: bool = True

    # Audio
    sample_rate: int = 24000

    # ISTFTNet vocoder
    upsample_rates: tuple[int, ...] = (10, 6)
    upsample_kernel_sizes: tuple[int, ...] = (20, 12)
    resblock_kernel_sizes: tuple[int, ...] = (3, 7, 11)
    resblock_dilation_sizes: tuple[tuple[int, ...], ...] = (
        (1, 3, 5),
        (1, 3, 5),
        (1, 3, 5),
    )
    upsample_initial_channel: int = 512
    gen_istft_n_fft: int = 20
    gen_istft_hop_size: int = 5

    # PLBERT text encoder
    plbert_hidden_size: int = 768
    plbert_num_attention_heads: int = 12
    plbert_intermediate_size: int = 2048
    plbert_max_position_embeddings: int = 512
    plbert_num_hidden_layers: int = 12
    plbert_dropout: float = 0.1

    # Phoneme vocabulary (symbol -> token id), loaded from config.json
    vocab: dict[str, int] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> KokoroConfig:
        """Create a config from a dictionary in the config.json layout.

        Fix: ``sample_rate`` is now read from the dict as well, so
        ``KokoroConfig.from_dict(cfg.to_dict()) == cfg`` round-trips
        losslessly (previously the key was silently dropped).

        Args:
            config_dict: Configuration dictionary (from config.json).

        Returns:
            KokoroConfig instance.
        """
        istftnet = config_dict.get("istftnet", {})
        plbert = config_dict.get("plbert", {})

        # JSON stores dilation sizes as lists; normalize to hashable tuples.
        dilations = istftnet.get(
            "resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
        )
        dilations_tuple = tuple(tuple(d) for d in dilations)

        return cls(
            # Core
            dim_in=config_dict.get("dim_in", 64),
            hidden_dim=config_dict.get("hidden_dim", 512),
            style_dim=config_dict.get("style_dim", 128),
            n_mels=config_dict.get("n_mels", 80),
            n_layer=config_dict.get("n_layer", 3),
            n_token=config_dict.get("n_token", 178),
            max_dur=config_dict.get("max_dur", 50),
            dropout=config_dict.get("dropout", 0.2),
            max_conv_dim=config_dict.get("max_conv_dim", 512),
            text_encoder_kernel_size=config_dict.get("text_encoder_kernel_size", 5),
            multispeaker=config_dict.get("multispeaker", True),
            # Audio (previously dropped -- see docstring)
            sample_rate=config_dict.get("sample_rate", 24000),
            # ISTFTNet
            upsample_rates=tuple(istftnet.get("upsample_rates", [10, 6])),
            upsample_kernel_sizes=tuple(istftnet.get("upsample_kernel_sizes", [20, 12])),
            resblock_kernel_sizes=tuple(istftnet.get("resblock_kernel_sizes", [3, 7, 11])),
            resblock_dilation_sizes=dilations_tuple,
            upsample_initial_channel=istftnet.get("upsample_initial_channel", 512),
            gen_istft_n_fft=istftnet.get("gen_istft_n_fft", 20),
            gen_istft_hop_size=istftnet.get("gen_istft_hop_size", 5),
            # PLBERT
            plbert_hidden_size=plbert.get("hidden_size", 768),
            plbert_num_attention_heads=plbert.get("num_attention_heads", 12),
            plbert_intermediate_size=plbert.get("intermediate_size", 2048),
            plbert_max_position_embeddings=plbert.get("max_position_embeddings", 512),
            plbert_num_hidden_layers=plbert.get("num_hidden_layers", 12),
            plbert_dropout=plbert.get("dropout", 0.1),
            # Vocabulary
            vocab=config_dict.get("vocab", {}),
        )

    @classmethod
    def from_json(cls, json_path: str | Path) -> KokoroConfig:
        """Load config from a JSON file.

        Args:
            json_path: Path to config.json.

        Returns:
            KokoroConfig instance.
        """
        with open(json_path, encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)

    @classmethod
    def from_pretrained(cls, model_path: str | Path) -> KokoroConfig:
        """Load config from a local model directory or a HuggingFace repo id.

        Args:
            model_path: Path to a directory containing config.json, or a
                HuggingFace repo id (e.g. "hexgrad/Kokoro-82M").

        Returns:
            KokoroConfig instance.

        Raises:
            ImportError: If a hub download is needed but huggingface_hub
                is not installed.
        """
        model_path = Path(model_path)

        # Prefer a local config.json when one exists.
        if model_path.is_dir():
            config_path = model_path / "config.json"
            if config_path.exists():
                return cls.from_json(config_path)

        # Fall back to the HuggingFace hub.
        try:
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(repo_id=str(model_path), filename="config.json")
            return cls.from_json(config_path)
        except ImportError as err:
            raise ImportError(
                "huggingface_hub is required to download from HuggingFace. "
                "Install with: pip install huggingface_hub"
            ) from err

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the nested config.json layout (inverse of from_dict)."""
        return {
            "dim_in": self.dim_in,
            "hidden_dim": self.hidden_dim,
            "style_dim": self.style_dim,
            "n_mels": self.n_mels,
            "n_layer": self.n_layer,
            "n_token": self.n_token,
            "max_dur": self.max_dur,
            "dropout": self.dropout,
            "max_conv_dim": self.max_conv_dim,
            "text_encoder_kernel_size": self.text_encoder_kernel_size,
            "multispeaker": self.multispeaker,
            "sample_rate": self.sample_rate,
            "istftnet": {
                "upsample_rates": list(self.upsample_rates),
                "upsample_kernel_sizes": list(self.upsample_kernel_sizes),
                "resblock_kernel_sizes": list(self.resblock_kernel_sizes),
                "resblock_dilation_sizes": [list(d) for d in self.resblock_dilation_sizes],
                "upsample_initial_channel": self.upsample_initial_channel,
                "gen_istft_n_fft": self.gen_istft_n_fft,
                "gen_istft_hop_size": self.gen_istft_hop_size,
            },
            "plbert": {
                "hidden_size": self.plbert_hidden_size,
                "num_attention_heads": self.plbert_num_attention_heads,
                "intermediate_size": self.plbert_intermediate_size,
                "max_position_embeddings": self.plbert_max_position_embeddings,
                "num_hidden_layers": self.plbert_num_hidden_layers,
                "dropout": self.plbert_dropout,
            },
            "vocab": self.vocab,
        }

    @property
    def plbert_head_dim(self) -> int:
        """PLBERT attention head dimension (hidden_size / num_heads)."""
        return self.plbert_hidden_size // self.plbert_num_attention_heads

    @property
    def hop_length(self) -> int:
        """Audio hop length: istft hop size times the product of upsample rates."""
        hop = self.gen_istft_hop_size
        for rate in self.upsample_rates:
            hop *= rate
        return hop

    def __repr__(self) -> str:
        return (
            f"KokoroConfig(\n"
            f"  dim_in={self.dim_in}, hidden_dim={self.hidden_dim},\n"
            f"  style_dim={self.style_dim}, n_mels={self.n_mels},\n"
            f"  n_layer={self.n_layer}, n_token={self.n_token},\n"
            f"  sample_rate={self.sample_rate},\n"
            f"  plbert_layers={self.plbert_num_hidden_layers},\n"
            f"  upsample_rates={self.upsample_rates}\n"
            f")"
        )


__all__ = ["KokoroConfig"]
class Linear:
    """Affine map ``y = x @ W.T + b``.

    Weights are stored as [out_features, in_features] (PyTorch layout).
    """

    def __init__(self, weight: GPUArray, bias: GPUArray | None = None):
        self.weight = weight
        self.bias = bias
        self.out_features = weight.shape[0]
        self.in_features = weight.shape[1]

    def __call__(self, x: GPUArray) -> GPUArray:
        """Apply the layer; the last dim of ``x`` must equal in_features."""
        from pygpukit.ops.basic import bias_add_inplace, matmul, transpose

        result = matmul(x, transpose(self.weight))
        if self.bias is not None:
            bias_add_inplace(result, self.bias)
        return result


class LayerNorm:
    """Layer normalization with learned affine parameters."""

    def __init__(self, weight: GPUArray, bias: GPUArray | None = None, eps: float = 1e-5):
        self.weight = weight
        self.bias = bias
        self.eps = eps
        self.normalized_shape = weight.shape[0]

    def __call__(self, x: GPUArray) -> GPUArray:
        """Normalize ``x`` over its last dimension and rescale."""
        from pygpukit.ops.basic import layernorm

        return layernorm(x, self.weight, self.bias, self.eps)


class Conv1d:
    """1D convolution implemented as im2col + matmul.

    Input:  [batch, in_channels, length]
    Output: [batch, out_channels, new_length]
    Weight: [out_channels, in_channels, kernel_size]
    """

    def __init__(
        self,
        weight: GPUArray,
        bias: GPUArray | None = None,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
    ):
        self.weight = weight
        self.bias = bias
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        self.out_channels = weight.shape[0]
        self.in_channels = weight.shape[1]
        self.kernel_size = weight.shape[2]

    def __call__(self, x: GPUArray) -> GPUArray:
        """Forward pass (CPU reference path; candidate for a native kernel)."""
        n_batch = x.shape[0]
        in_len = x.shape[2]

        # Effective receptive field of a dilated kernel.
        span = self.dilation * (self.kernel_size - 1) + 1
        out_len = (in_len + 2 * self.padding - span) // self.stride + 1

        data = x.to_numpy()
        kern = self.weight.to_numpy()

        if self.padding > 0:
            pad_spec = ((0, 0), (0, 0), (self.padding, self.padding))
            data = np.pad(data, pad_spec, mode="constant")

        # im2col: gather one column per output position and kernel tap.
        patches = np.zeros(
            (n_batch, self.in_channels, self.kernel_size, out_len), dtype=np.float32
        )
        for tap in range(self.kernel_size):
            offset = tap * self.dilation
            for pos in range(out_len):
                patches[:, :, tap, pos] = data[:, :, pos * self.stride + offset]

        # Collapse (channel, tap) and contract against the flattened kernel.
        patches = patches.reshape(n_batch, -1, out_len)
        flat_kern = kern.reshape(self.out_channels, -1)
        result = np.einsum("bkl,ok->bol", patches, flat_kern)

        if self.bias is not None:
            result = result + self.bias.to_numpy().reshape(1, -1, 1)

        return from_numpy(result.astype(np.float32))
class LSTM:
    """LSTM layer backed by the native CUDA kernel.

    Weight layout is PyTorch-compatible:
        W_ih: [4*hidden_size, input_size]
        W_hh: [4*hidden_size, hidden_size]
        b_ih, b_hh: [4*hidden_size]

    Pass ``bidirectional=True`` together with the ``*_reverse`` weights to run
    forward and backward sweeps whose outputs are concatenated.
    """

    def __init__(
        self,
        W_ih: GPUArray,
        W_hh: GPUArray,
        b_ih: GPUArray,
        b_hh: GPUArray,
        bidirectional: bool = False,
        W_ih_reverse: GPUArray | None = None,
        W_hh_reverse: GPUArray | None = None,
        b_ih_reverse: GPUArray | None = None,
        b_hh_reverse: GPUArray | None = None,
    ):
        self.W_ih = W_ih
        self.W_hh = W_hh
        self.b_ih = b_ih
        self.b_hh = b_hh
        self.bidirectional = bidirectional
        self.W_ih_reverse = W_ih_reverse
        self.W_hh_reverse = W_hh_reverse
        self.b_ih_reverse = b_ih_reverse
        self.b_hh_reverse = b_hh_reverse

        # Dimensions are implied by the weight shapes.
        self.hidden_size = W_hh.shape[1]
        self.input_size = W_ih.shape[1]

    def __call__(
        self,
        x: GPUArray,
        h0: GPUArray | None = None,
        c0: GPUArray | None = None,
    ) -> tuple[GPUArray, tuple[GPUArray, GPUArray]]:
        """Run the LSTM over ``x`` ([batch, seq_len, input_size]).

        Args:
            x: Input sequence [batch, seq_len, input_size].
            h0: Optional initial hidden state (unidirectional path only).
            c0: Optional initial cell state (unidirectional path only).

        Returns:
            ``(output, (h_n, c_n))`` where output is
            [batch, seq_len, hidden_size * num_directions].

        Raises:
            ValueError: If bidirectional but reverse weights were not given.
        """
        from pygpukit.ops.nn import lstm_bidirectional, lstm_forward

        if self.bidirectional:
            if self.W_ih_reverse is None:
                raise ValueError("Bidirectional LSTM requires reverse weights")
            out, h_n, c_n = lstm_bidirectional(
                x,
                self.W_ih, self.W_hh, self.b_ih, self.b_hh,
                self.W_ih_reverse, self.W_hh_reverse,
                self.b_ih_reverse, self.b_hh_reverse,
            )
        else:
            out, h_n, c_n = lstm_forward(
                x, self.W_ih, self.W_hh, self.b_ih, self.b_hh, h0, c0
            )
        return out, (h_n, c_n)
class ConvTranspose1d:
    """1D transposed convolution (deconvolution).

    Used for upsampling in the vocoder.
    Weight layout follows PyTorch: [in_channels, out_channels, kernel_size].
    """

    def __init__(
        self,
        weight: GPUArray,
        bias: GPUArray | None = None,
        stride: int = 1,
        padding: int = 0,
        output_padding: int = 0,
    ):
        self.weight = weight
        self.bias = bias
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding

        self.in_channels = weight.shape[0]
        self.out_channels = weight.shape[1]
        self.kernel_size = weight.shape[2]

    def __call__(self, x: GPUArray) -> GPUArray:
        """Scatter-add reference implementation on CPU."""
        n_batch = x.shape[0]
        in_len = x.shape[2]

        out_len = (
            (in_len - 1) * self.stride
            - 2 * self.padding
            + self.kernel_size
            + self.output_padding
        )

        data = x.to_numpy()
        kern = self.weight.to_numpy()
        result = np.zeros((n_batch, self.out_channels, out_len), dtype=np.float32)

        # Each input position contributes to kernel_size output positions.
        for pos in range(in_len):
            for tap in range(self.kernel_size):
                dst = pos * self.stride - self.padding + tap
                if 0 <= dst < out_len:
                    result[:, :, dst] += np.einsum(
                        "bi,io->bo", data[:, :, pos], kern[:, :, tap]
                    )

        if self.bias is not None:
            result = result + self.bias.to_numpy().reshape(1, -1, 1)

        return from_numpy(result)


# =============================================================================
# Activation Functions
# =============================================================================


def leaky_relu(x: GPUArray, negative_slope: float = 0.1) -> GPUArray:
    """Leaky ReLU (CPU path): negative inputs are scaled by *negative_slope*."""
    arr = x.to_numpy()
    return from_numpy(np.where(arr > 0, arr, negative_slope * arr).astype(np.float32))


def tanh(x: GPUArray) -> GPUArray:
    """Elementwise tanh on the GPU."""
    from pygpukit.ops.basic import tanh as gpu_tanh

    return gpu_tanh(x)
PLBERT Text Encoder +# ============================================================================= + + +class BertSelfAttention: + """BERT self-attention layer.""" + + def __init__( + self, + query: Linear, + key: Linear, + value: Linear, + num_attention_heads: int, + attention_head_size: int, + ): + self.query = query + self.key = key + self.value = value + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.all_head_size = num_attention_heads * attention_head_size + + def transpose_for_scores(self, x: GPUArray) -> GPUArray: + """Reshape for multi-head attention.""" + # x: [batch, seq_len, all_head_size] + # output: [batch, num_heads, seq_len, head_size] + batch_size = x.shape[0] + seq_len = x.shape[1] + + x_np = x.to_numpy() + x_reshaped = x_np.reshape( + batch_size, seq_len, self.num_attention_heads, self.attention_head_size + ) + x_transposed = x_reshaped.transpose(0, 2, 1, 3) + return from_numpy(x_transposed.astype(np.float32)) + + def __call__(self, hidden_states: GPUArray, attention_mask: GPUArray | None = None) -> GPUArray: + """Forward pass.""" + # Compute Q, K, V + query_layer = self.transpose_for_scores(self.query(hidden_states)) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + # Compute attention scores + q_np = query_layer.to_numpy() + k_np = key_layer.to_numpy() + v_np = value_layer.to_numpy() + + # Scaled dot-product attention + attention_scores = np.matmul(q_np, k_np.transpose(0, 1, 3, 2)) + attention_scores = attention_scores / np.sqrt(self.attention_head_size) + + if attention_mask is not None: + mask_np = attention_mask.to_numpy() + attention_scores = attention_scores + mask_np + + attention_probs = np.exp(attention_scores - attention_scores.max(axis=-1, keepdims=True)) + attention_probs = attention_probs / attention_probs.sum(axis=-1, keepdims=True) + + context = np.matmul(attention_probs, v_np) + + # 
Reshape back
        batch_size = context.shape[0]
        # context is [batch, num_heads, seq_len, head_size] here, so axis 2 is seq.
        seq_len = context.shape[2]
        # Merge the heads back into a single [batch, seq_len, heads*head_size] axis.
        context = context.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, self.all_head_size)

        return from_numpy(context.astype(np.float32))


class BertLayer:
    """Single BERT encoder layer.

    Post-LN transformer block: self-attention and feed-forward sub-layers,
    each followed by a residual add and LayerNorm (norm applied *after* the
    residual sum, as in the original BERT).
    """

    def __init__(
        self,
        attention: BertSelfAttention,
        attention_output: Linear,      # projection after multi-head attention
        attention_norm: LayerNorm,     # LayerNorm over (attention + residual)
        intermediate: Linear,          # FFN up-projection (GELU applied in __call__)
        output_dense: Linear,          # FFN down-projection
        output_norm: LayerNorm,        # LayerNorm over (FFN + residual)
    ):
        self.attention = attention
        self.attention_output = attention_output
        self.attention_norm = attention_norm
        self.intermediate = intermediate
        self.output_dense = output_dense
        self.output_norm = output_norm

    def __call__(self, hidden_states: GPUArray, attention_mask: GPUArray | None = None) -> GPUArray:
        """Forward pass.

        Args:
            hidden_states: [batch, seq_len, hidden_size] input activations.
            attention_mask: Optional additive attention mask, passed through
                to the self-attention sub-layer.

        Returns:
            [batch, seq_len, hidden_size] transformed activations.
        """
        from pygpukit.ops.basic import add, gelu

        # Self-attention
        attention_output = self.attention(hidden_states, attention_mask)
        attention_output = self.attention_output(attention_output)
        # Residual connection, then LayerNorm (post-LN ordering).
        hidden_states = self.attention_norm(add(attention_output, hidden_states))

        # Feed-forward
        intermediate_output = gelu(self.intermediate(hidden_states))
        layer_output = self.output_dense(intermediate_output)
        hidden_states = self.output_norm(add(layer_output, hidden_states))

        return hidden_states


class PLBERTEncoder:
    """PLBERT text encoder for Kokoro TTS.

    BERT-style transformer encoder that converts phoneme tokens to
    contextualized embeddings.
+ """ + + def __init__( + self, + config: KokoroConfig, + embeddings: GPUArray, # [vocab_size, hidden_size] + position_embeddings: GPUArray, # [max_position, hidden_size] + layers: list[BertLayer], + final_norm: LayerNorm | None = None, + ): + self.config = config + self.embeddings = embeddings + self.position_embeddings = position_embeddings + self.layers = layers + self.final_norm = final_norm + + def __call__( + self, + input_ids: GPUArray, # [batch, seq_len] + attention_mask: GPUArray | None = None, + ) -> GPUArray: + """Forward pass. + + Args: + input_ids: Token IDs [batch, seq_len] + attention_mask: Attention mask [batch, seq_len] (optional) + + Returns: + Hidden states [batch, seq_len, hidden_size] + """ + from pygpukit.ops.basic import add + + batch_size = input_ids.shape[0] + seq_len = input_ids.shape[1] + + # Token embeddings (numpy-based for simplicity) + input_ids_np: np.ndarray = input_ids.to_numpy().astype(np.int32) + embeddings_np = self.embeddings.to_numpy() + token_embeds_np = embeddings_np[input_ids_np.flatten()].reshape(batch_size, seq_len, -1) + token_embeds = from_numpy(token_embeds_np.astype(np.float32)) + + # Position embeddings + positions = np.arange(seq_len, dtype=np.int32) + pos_embeds_np = self.position_embeddings.to_numpy() + pos_embeds_np = pos_embeds_np[positions].reshape(1, seq_len, -1) + pos_embeds = from_numpy(pos_embeds_np.astype(np.float32)) + + # Combine embeddings + hidden_states = add(token_embeds, pos_embeds) + + # Create attention mask if needed + if attention_mask is not None: + # Convert [batch, seq_len] to [batch, 1, 1, seq_len] + mask_np = attention_mask.to_numpy() + extended_mask = mask_np[:, np.newaxis, np.newaxis, :] + extended_mask = (1.0 - extended_mask) * -10000.0 + attention_mask = from_numpy(extended_mask.astype(np.float32)) + + # Apply transformer layers + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask) + + if self.final_norm is not None: + hidden_states = 
self.final_norm(hidden_states) + + return hidden_states + + +# ============================================================================= +# Style Encoder +# ============================================================================= + + +class StyleEncoder: + """Style encoder for speaker conditioning. + + Converts text features and speaker embedding to style vector. + """ + + def __init__( + self, + convs: list[Conv1d], + norm: LayerNorm | None = None, + output_dim: int = 128, + ): + self.convs = convs + self.norm = norm + self.output_dim = output_dim + + def __call__( + self, + text_features: GPUArray, # [batch, seq_len, hidden_dim] + speaker_embedding: GPUArray | None = None, # [batch, style_dim] + ) -> GPUArray: + """Forward pass. + + Args: + text_features: Text encoder output [batch, seq_len, hidden_dim] + speaker_embedding: Optional speaker style [batch, style_dim] + + Returns: + Style conditioning [batch, style_dim] + """ + # Transpose for conv1d: [batch, hidden_dim, seq_len] + x = text_features.to_numpy().transpose(0, 2, 1) + x = from_numpy(x.astype(np.float32)) + + # Apply convolutions + for conv in self.convs: + x = leaky_relu(conv(x)) + + # Global average pooling: [batch, channels] + x_np = x.to_numpy() + x_pooled = x_np.mean(axis=-1) + + result = from_numpy(x_pooled.astype(np.float32)) + + # Combine with speaker embedding if provided + if speaker_embedding is not None: + from pygpukit.ops.basic import add + + result = add(result, speaker_embedding) + + return result + + +# ============================================================================= +# Decoder +# ============================================================================= + + +class ResBlock1d: + """1D Residual block with dilated convolutions.""" + + def __init__( + self, + convs: list[Conv1d], + ): + self.convs = convs + + def __call__(self, x: GPUArray) -> GPUArray: + """Forward pass with residual connection.""" + from pygpukit.ops.basic import add + + residual = x + for i, conv 
in enumerate(self.convs): + x = leaky_relu(x) if i > 0 else x + x = conv(x) + return add(x, residual) + + +class Decoder: + """Mel spectrogram decoder. + + Converts text features + style to mel spectrogram. + """ + + def __init__( + self, + input_proj: Linear, + layers: list[ResBlock1d | Conv1d], + output_proj: Linear, + n_mels: int = 80, + ): + self.input_proj = input_proj + self.layers = layers + self.output_proj = output_proj + self.n_mels = n_mels + + def __call__( + self, + text_features: GPUArray, # [batch, seq_len, hidden_dim] + style: GPUArray, # [batch, style_dim] + durations: GPUArray | None = None, # [batch, seq_len] + ) -> GPUArray: + """Forward pass. + + Args: + text_features: Text encoder output + style: Style conditioning + durations: Duration per phoneme (optional, for alignment) + + Returns: + Mel spectrogram [batch, n_mels, mel_len] + """ + # Project input + x = self.input_proj(text_features) + + # Add style conditioning (broadcast over sequence) + # Note: style conditioning is applied after duration expansion + _ = style # Reserved for future use + x_np = x.to_numpy() + + # Simple duration expansion (repeat each frame by duration) + if durations is not None: + dur_np: np.ndarray = durations.to_numpy().astype(np.int32) + expanded: list[np.ndarray] = [] + for b in range(x_np.shape[0]): + frames: list[np.ndarray] = [] + for t in range(x_np.shape[1]): + dur = max(1, int(dur_np[b, t])) + frames.extend([x_np[b, t]] * dur) + expanded.append(np.stack(frames)) + x_np = np.stack(expanded) + + x = from_numpy(x_np.astype(np.float32)) + + # Transpose for conv: [batch, hidden, seq_len] + x = from_numpy(x.to_numpy().transpose(0, 2, 1).astype(np.float32)) + + # Apply decoder layers + for layer in self.layers: + x = layer(x) + + # Transpose back and project to mel + x = from_numpy(x.to_numpy().transpose(0, 2, 1).astype(np.float32)) + mel = self.output_proj(x) + + # Transpose to [batch, n_mels, mel_len] + mel = from_numpy(mel.to_numpy().transpose(0, 2, 
1).astype(np.float32)) + + return mel + + +# ============================================================================= +# ISTFTNet Vocoder +# ============================================================================= + + +class ISTFTNet: + """ISTFTNet vocoder for waveform synthesis. + + Converts mel spectrogram to audio waveform using upsampling + and inverse STFT. + """ + + def __init__( + self, + config: KokoroConfig, + ups: list[ConvTranspose1d], + resblocks: list[list[ResBlock1d]], + output_conv: Conv1d, + ): + self.config = config + self.ups = ups + self.resblocks = resblocks + self.output_conv = output_conv + + # ISTFT parameters + self.n_fft = config.gen_istft_n_fft + self.hop_size = config.gen_istft_hop_size + + def __call__(self, mel: GPUArray) -> GPUArray: + """Forward pass. + + Args: + mel: Mel spectrogram [batch, n_mels, mel_len] + + Returns: + Audio waveform [batch, audio_len] + """ + x = mel + + # Upsampling stages + for _i, (up, resblock_group) in enumerate(zip(self.ups, self.resblocks)): + x = leaky_relu(x) + x = up(x) + + # Apply residual blocks and sum + if resblock_group: + xs = None + for resblock in resblock_group: + if xs is None: + xs = resblock(x) + else: + xs_np = xs.to_numpy() + resblock(x).to_numpy() + xs = from_numpy(xs_np.astype(np.float32)) + x = from_numpy((xs.to_numpy() / len(resblock_group)).astype(np.float32)) + + x = leaky_relu(x) + x = self.output_conv(x) + x = tanh(x) + + # ISTFT to convert to waveform + # Output conv produces [batch, n_fft, frames] + # We need to apply ISTFT + x_np = x.to_numpy() + + # Simple overlap-add reconstruction + batch_size = x_np.shape[0] + frames = x_np.shape[2] + audio_len = frames * self.hop_size + self.n_fft - self.hop_size + + audio = np.zeros((batch_size, audio_len), dtype=np.float32) + window = np.hanning(self.n_fft).astype(np.float32) + + for i in range(frames): + start = i * self.hop_size + audio[:, start : start + self.n_fft] += x_np[:, :, i] * window + + # Normalize by window sum + 
window_sum = np.zeros(audio_len, dtype=np.float32) + for i in range(frames): + start = i * self.hop_size + window_sum[start : start + self.n_fft] += window**2 + window_sum = np.maximum(window_sum, 1e-8) + audio = audio / window_sum + + return from_numpy(audio) + + +# ============================================================================= +# Layer Building Utilities +# ============================================================================= + + +def build_plbert_from_weights( + config: KokoroConfig, + weights: dict[str, GPUArray], + prefix: str = "bert", +) -> PLBERTEncoder: + """Build PLBERT encoder from weight dictionary. + + Args: + config: Model configuration + weights: Dictionary of weight tensors + prefix: Weight name prefix + + Returns: + PLBERTEncoder instance + """ + # Build embeddings + embeddings = weights.get(f"{prefix}.embeddings.word_embeddings.weight") + position_embeddings = weights.get(f"{prefix}.embeddings.position_embeddings.weight") + + if embeddings is None or position_embeddings is None: + raise ValueError(f"Missing embedding weights with prefix '{prefix}'") + + # Build transformer layers + layers = [] + for i in range(config.plbert_num_hidden_layers): + layer_prefix = f"{prefix}.encoder.layer.{i}" + + # Check if layer exists + q_weight = weights.get(f"{layer_prefix}.attention.self.query.weight") + if q_weight is None: + break + + # Self-attention + attention = BertSelfAttention( + query=Linear( + weights[f"{layer_prefix}.attention.self.query.weight"], + weights.get(f"{layer_prefix}.attention.self.query.bias"), + ), + key=Linear( + weights[f"{layer_prefix}.attention.self.key.weight"], + weights.get(f"{layer_prefix}.attention.self.key.bias"), + ), + value=Linear( + weights[f"{layer_prefix}.attention.self.value.weight"], + weights.get(f"{layer_prefix}.attention.self.value.bias"), + ), + num_attention_heads=config.plbert_num_attention_heads, + attention_head_size=config.plbert_hidden_size // config.plbert_num_attention_heads, + ) + + 
layer = BertLayer( + attention=attention, + attention_output=Linear( + weights[f"{layer_prefix}.attention.output.dense.weight"], + weights.get(f"{layer_prefix}.attention.output.dense.bias"), + ), + attention_norm=LayerNorm( + weights[f"{layer_prefix}.attention.output.LayerNorm.weight"], + weights.get(f"{layer_prefix}.attention.output.LayerNorm.bias"), + ), + intermediate=Linear( + weights[f"{layer_prefix}.intermediate.dense.weight"], + weights.get(f"{layer_prefix}.intermediate.dense.bias"), + ), + output_dense=Linear( + weights[f"{layer_prefix}.output.dense.weight"], + weights.get(f"{layer_prefix}.output.dense.bias"), + ), + output_norm=LayerNorm( + weights[f"{layer_prefix}.output.LayerNorm.weight"], + weights.get(f"{layer_prefix}.output.LayerNorm.bias"), + ), + ) + layers.append(layer) + + return PLBERTEncoder( + config=config, + embeddings=embeddings, + position_embeddings=position_embeddings, + layers=layers, + ) + + +__all__ = [ + # Basic layers + "Linear", + "LayerNorm", + "Conv1d", + "ConvTranspose1d", + "ResBlock1d", + # Activations + "leaky_relu", + "tanh", + # Components + "BertSelfAttention", + "BertLayer", + "PLBERTEncoder", + "StyleEncoder", + "Decoder", + "ISTFTNet", + # Utilities + "build_plbert_from_weights", +] diff --git a/src/pygpukit/tts/kokoro/loader.py b/src/pygpukit/tts/kokoro/loader.py new file mode 100644 index 0000000..8a35567 --- /dev/null +++ b/src/pygpukit/tts/kokoro/loader.py @@ -0,0 +1,334 @@ +"""Model loading utilities for Kokoro TTS. + +Handles loading weights from SafeTensors or PyTorch (.pth) format. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np + +from pygpukit.core.array import GPUArray +from pygpukit.core.factory import from_numpy + +if TYPE_CHECKING: + pass + + +def _download_model(repo_id: str, local_dir: Path | None = None) -> Path: + """Download model from HuggingFace Hub. 
+ + Args: + repo_id: HuggingFace repository ID (e.g., "hexgrad/Kokoro-82M") + local_dir: Local directory to save files (optional) + + Returns: + Path to downloaded model directory + """ + try: + from huggingface_hub import snapshot_download + + return Path( + snapshot_download( + repo_id=repo_id, + local_dir=local_dir, + allow_patterns=["*.json", "*.pth", "*.safetensors", "voices/*.pt"], + ) + ) + except ImportError as err: + raise ImportError( + "huggingface_hub is required to download models. " + "Install with: pip install huggingface_hub" + ) from err + + +def _load_pytorch_weights(path: Path) -> dict[str, np.ndarray]: + """Load weights from PyTorch .pth file. + + Args: + path: Path to .pth file + + Returns: + Dictionary mapping weight names to numpy arrays + """ + try: + import torch + + # Load with CPU mapping + checkpoint = torch.load(path, map_location="cpu", weights_only=False) + + # Handle different checkpoint formats + if isinstance(checkpoint, dict): + if "model" in checkpoint: + state_dict = checkpoint["model"] + elif "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + else: + raise ValueError(f"Unexpected checkpoint format: {type(checkpoint)}") + + # Convert to numpy + weights = {} + for name, tensor in state_dict.items(): + if isinstance(tensor, torch.Tensor): + weights[name] = tensor.numpy() + elif isinstance(tensor, np.ndarray): + weights[name] = tensor + + return weights + except ImportError as err: + raise ImportError( + "PyTorch is required to load .pth files. Install with: pip install torch" + ) from err + + +def _load_safetensors_weights(path: Path) -> dict[str, np.ndarray]: + """Load weights from SafeTensors file. 
+ + Args: + path: Path to .safetensors file + + Returns: + Dictionary mapping weight names to numpy arrays + """ + try: + from safetensors import safe_open + + weights = {} + with safe_open(path, framework="numpy") as f: + for name in f.keys(): + weights[name] = f.get_tensor(name) + + return weights + except ImportError as err: + raise ImportError( + "safetensors is required to load .safetensors files. " + "Install with: pip install safetensors" + ) from err + + +def _convert_to_gpu( + weights: dict[str, np.ndarray], + dtype: str = "bfloat16", +) -> dict[str, GPUArray]: + """Convert numpy weights to GPUArrays. + + Args: + weights: Dictionary of numpy arrays + dtype: Target dtype ("bfloat16" or "float32") + + Returns: + Dictionary of GPUArrays + """ + gpu_weights = {} + for name, array in weights.items(): + # Convert to float32 first if needed + if array.dtype not in (np.float32, np.float16): + array = array.astype(np.float32) + + # Create GPUArray + gpu_array = from_numpy(array) + + # Cast to target dtype if needed + if dtype == "bfloat16" and array.dtype == np.float32: + # Cast on GPU + from pygpukit.ops.tensor import cast_f32_to_bf16 + + gpu_array = cast_f32_to_bf16(gpu_array) + + gpu_weights[name] = gpu_array + + return gpu_weights + + +def load_voice_embedding( + voice_path: Path, +) -> GPUArray: + """Load speaker/voice embedding from .pt file. 
+ + Args: + voice_path: Path to voice .pt file + + Returns: + GPUArray containing voice embedding + """ + try: + import torch + + voice_data = torch.load(voice_path, map_location="cpu", weights_only=False) + + # Voice files contain style embedding tensor + if isinstance(voice_data, torch.Tensor): + embedding = voice_data.numpy() + elif isinstance(voice_data, dict) and "style" in voice_data: + embedding = voice_data["style"].numpy() + else: + raise ValueError(f"Unexpected voice file format: {type(voice_data)}") + + return from_numpy(embedding.astype(np.float32)) + except ImportError as err: + raise ImportError( + "PyTorch is required to load voice files. Install with: pip install torch" + ) from err + + +def list_available_voices(model_path: Path) -> list[str]: + """List available voice embeddings in model directory. + + Args: + model_path: Path to model directory + + Returns: + List of voice names (without .pt extension) + """ + voices_dir = model_path / "voices" + if not voices_dir.exists(): + return [] + + voices = [] + for pt_file in voices_dir.glob("*.pt"): + voices.append(pt_file.stem) + + return sorted(voices) + + +def load_kokoro_weights( + model_path: str | Path, + dtype: str = "bfloat16", + device: str = "cuda", +) -> tuple[dict[str, GPUArray], dict[str, Any]]: + """Load Kokoro model weights and config. 
+ + Args: + model_path: Path to model directory or HuggingFace repo ID + dtype: Weight dtype ("bfloat16" or "float32") + device: Target device (currently only "cuda" supported) + + Returns: + Tuple of (weights dict, config dict) + """ + model_path = Path(model_path) + + # Download from HuggingFace if needed + if not model_path.exists(): + model_path = _download_model(str(model_path)) + + # Find weight file + safetensors_path = model_path / "kokoro-v1_0.safetensors" + pth_path = model_path / "kokoro-v1_0.pth" + + if safetensors_path.exists(): + weights = _load_safetensors_weights(safetensors_path) + elif pth_path.exists(): + weights = _load_pytorch_weights(pth_path) + else: + # Try other common names + for pattern in ["*.safetensors", "*.pth"]: + files = list(model_path.glob(pattern)) + if files: + if pattern == "*.safetensors": + weights = _load_safetensors_weights(files[0]) + else: + weights = _load_pytorch_weights(files[0]) + break + else: + raise FileNotFoundError( + f"No weight file found in {model_path}. " + "Expected kokoro-v1_0.safetensors or kokoro-v1_0.pth" + ) + + # Load config + config_path = model_path / "config.json" + if config_path.exists(): + with open(config_path, encoding="utf-8") as f: + config_dict = json.load(f) + else: + config_dict = {} + + # Convert to GPU + gpu_weights = _convert_to_gpu(weights, dtype=dtype) + + return gpu_weights, config_dict + + +def get_weight_info(weights: dict[str, GPUArray | np.ndarray]) -> dict[str, dict[str, Any]]: + """Get information about model weights. 
+ + Args: + weights: Dictionary of weights + + Returns: + Dictionary with shape and dtype info for each weight + """ + info = {} + for name, w in weights.items(): + if isinstance(w, GPUArray): + info[name] = { + "shape": w.shape, + "dtype": str(w.dtype), + "size_mb": w.nbytes / (1024 * 1024), + } + elif isinstance(w, np.ndarray): + info[name] = { + "shape": w.shape, + "dtype": str(w.dtype), + "size_mb": w.nbytes / (1024 * 1024), + } + return info + + +def print_weight_summary(weights: dict[str, GPUArray | np.ndarray]) -> None: + """Print summary of model weights. + + Args: + weights: Dictionary of weights + """ + info = get_weight_info(weights) + + total_params = 0 + total_size_mb = 0.0 + + print("=" * 60) + print("Kokoro Model Weight Summary") + print("=" * 60) + + # Group by prefix + prefixes: dict[str, list[str]] = {} + for name in sorted(info.keys()): + prefix = name.split(".")[0] if "." in name else name + if prefix not in prefixes: + prefixes[prefix] = [] + prefixes[prefix].append(name) + + for prefix, names in prefixes.items(): + prefix_params = 0 + prefix_size = 0.0 + + for name in names: + shape = info[name]["shape"] + params = 1 + for dim in shape: + params *= dim + prefix_params += params + prefix_size += info[name]["size_mb"] + + print(f"{prefix}: {prefix_params:,} params ({prefix_size:.2f} MB)") + total_params += prefix_params + total_size_mb += prefix_size + + print("-" * 60) + print(f"Total: {total_params:,} params ({total_size_mb:.2f} MB)") + print("=" * 60) + + +__all__ = [ + "load_kokoro_weights", + "load_voice_embedding", + "list_available_voices", + "get_weight_info", + "print_weight_summary", +] diff --git a/src/pygpukit/tts/kokoro/model.py b/src/pygpukit/tts/kokoro/model.py new file mode 100644 index 0000000..4d88e8b --- /dev/null +++ b/src/pygpukit/tts/kokoro/model.py @@ -0,0 +1,388 @@ +"""Kokoro TTS Model. + +High-level API for text-to-speech synthesis using Kokoro-82M. 
+ +Example: + >>> from pygpukit.tts.kokoro import KokoroModel + >>> model = KokoroModel.from_pretrained("hexgrad/Kokoro-82M") + >>> audio = model.synthesize("Hello, world!", voice="af_heart") + >>> audio.to_wav("output.wav") +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np + +from pygpukit.core.array import GPUArray +from pygpukit.core.factory import from_numpy +from pygpukit.ops.audio import AudioBuffer +from pygpukit.tts.kokoro.config import KokoroConfig +from pygpukit.tts.kokoro.loader import ( + list_available_voices, + load_kokoro_weights, + load_voice_embedding, + print_weight_summary, +) +from pygpukit.tts.kokoro.text import KokoroTokenizer, normalize_text, split_sentences + +if TYPE_CHECKING: + from pygpukit.tts.kokoro.layers import Decoder, ISTFTNet, PLBERTEncoder, StyleEncoder + + +@dataclass +class SynthesisResult: + """Result from TTS synthesis. + + Attributes: + audio: Generated audio buffer (24kHz) + text: Original input text + phonemes: Phoneme representation + duration_sec: Audio duration in seconds + """ + + audio: AudioBuffer + text: str + phonemes: str + duration_sec: float + + def to_wav(self, path: str | Path) -> None: + """Save audio to WAV file. + + Args: + path: Output file path + """ + from pygpukit.tts.kokoro.audio import to_wav + + to_wav(self.audio, str(path)) + + def to_numpy(self) -> np.ndarray: + """Get audio as numpy array. + + Returns: + Float32 audio samples + """ + return self.audio.to_numpy() + + +class KokoroModel: + """Kokoro-82M Text-to-Speech Model. + + A StyleTTS2-based TTS model that generates natural-sounding speech + from text input. 
+ + Args: + config: Model configuration + weights: Model weights dictionary + tokenizer: Text tokenizer + voice_embeddings: Dictionary of voice embeddings + + Example: + >>> model = KokoroModel.from_pretrained("hexgrad/Kokoro-82M") + >>> result = model.synthesize("Hello, this is a test.") + >>> result.to_wav("output.wav") + """ + + def __init__( + self, + config: KokoroConfig, + weights: dict[str, GPUArray], + tokenizer: KokoroTokenizer, + voice_embeddings: dict[str, GPUArray] | None = None, + ): + self.config = config + self.weights = weights + self.tokenizer = tokenizer + self.voice_embeddings = voice_embeddings or {} + + # Build model components lazily + self._plbert: PLBERTEncoder | None = None + self._style_encoder: StyleEncoder | None = None + self._decoder: Decoder | None = None + self._vocoder: ISTFTNet | None = None + + # Default voice + self._current_voice: str | None = None + self._current_voice_embedding: GPUArray | None = None + + @classmethod + def from_pretrained( + cls, + model_path: str | Path, + voice: str = "af_heart", + dtype: str = "bfloat16", + load_all_voices: bool = False, + ) -> KokoroModel: + """Load model from pretrained checkpoint. 
+ + Args: + model_path: Path to model directory or HuggingFace repo ID + voice: Default voice to use (e.g., "af_heart") + dtype: Weight dtype ("bfloat16" or "float32") + load_all_voices: Whether to load all voice embeddings + + Returns: + KokoroModel instance + """ + model_path = Path(model_path) + + # Load weights and config + weights, config_dict = load_kokoro_weights(model_path, dtype=dtype) + + # Create config + config = KokoroConfig.from_dict(config_dict) + + # Create tokenizer + tokenizer = KokoroTokenizer.from_config(config, use_misaki=True) + + # Load voice embeddings + voice_embeddings = {} + + if model_path.exists(): + available_voices = list_available_voices(model_path) + + if load_all_voices: + for voice_name in available_voices: + voice_path = model_path / "voices" / f"{voice_name}.pt" + if voice_path.exists(): + voice_embeddings[voice_name] = load_voice_embedding(voice_path) + elif voice in available_voices: + voice_path = model_path / "voices" / f"{voice}.pt" + if voice_path.exists(): + voice_embeddings[voice] = load_voice_embedding(voice_path) + + model = cls( + config=config, + weights=weights, + tokenizer=tokenizer, + voice_embeddings=voice_embeddings, + ) + + # Set default voice + if voice in voice_embeddings: + model.set_voice(voice) + elif voice_embeddings: + model.set_voice(list(voice_embeddings.keys())[0]) + + return model + + def set_voice(self, voice: str) -> None: + """Set the current voice for synthesis. + + Args: + voice: Voice name (e.g., "af_heart", "bf_emma") + """ + if voice not in self.voice_embeddings: + available = list(self.voice_embeddings.keys()) + raise ValueError(f"Voice '{voice}' not loaded. Available: {available}") + + self._current_voice = voice + self._current_voice_embedding = self.voice_embeddings[voice] + + def load_voice(self, voice_path: str | Path) -> str: + """Load a voice embedding from file. 
+ + Args: + voice_path: Path to voice .pt file + + Returns: + Voice name (file stem) + """ + voice_path = Path(voice_path) + voice_name = voice_path.stem + self.voice_embeddings[voice_name] = load_voice_embedding(voice_path) + return voice_name + + @property + def available_voices(self) -> list[str]: + """List of loaded voice names.""" + return list(self.voice_embeddings.keys()) + + @property + def current_voice(self) -> str | None: + """Currently selected voice.""" + return self._current_voice + + def _build_components(self) -> None: + """Build model components from weights (lazy initialization).""" + if self._plbert is not None: + return # Already built + + from pygpukit.tts.kokoro.layers import build_plbert_from_weights + + # Build PLBERT encoder + # Note: Actual weight prefix may vary depending on checkpoint format + # This is a placeholder - actual implementation needs weight inspection + try: + self._plbert = build_plbert_from_weights(self.config, self.weights, prefix="bert") + except (KeyError, ValueError): + # Weights might use different naming + self._plbert = None + + # TODO: Build other components (style encoder, decoder, vocoder) + # These require inspecting actual Kokoro weight structure + + def _forward_simple( + self, + tokens: list[int], + voice_embedding: GPUArray | None = None, + ) -> GPUArray: + """Simple forward pass without full model components. + + This is a placeholder implementation that demonstrates the API. + Full implementation requires matching Kokoro's exact weight structure. + """ + # For now, generate placeholder audio + # Actual implementation would: + # 1. Embed tokens + # 2. Run through PLBERT + # 3. Apply style + # 4. Decode to mel + # 5. 
Vocode to audio + + # Placeholder: generate silence with some noise + duration_per_token = 0.1 # 100ms per token + total_duration = len(tokens) * duration_per_token + num_samples = int(total_duration * self.config.sample_rate) + + # Generate placeholder audio (sine wave for testing) + t = np.linspace(0, total_duration, num_samples, dtype=np.float32) + frequency = 440.0 # A4 note + audio = 0.1 * np.sin(2 * np.pi * frequency * t) + + return from_numpy(audio) + + def synthesize( + self, + text: str, + voice: str | None = None, + speed: float = 1.0, + normalize: bool = True, + ) -> SynthesisResult: + """Synthesize speech from text. + + Args: + text: Input text to synthesize + voice: Voice to use (None for current voice) + speed: Speech speed multiplier (1.0 = normal) + normalize: Whether to normalize input text + + Returns: + SynthesisResult containing audio and metadata + """ + # Set voice if specified + if voice is not None and voice != self._current_voice: + self.set_voice(voice) + + # Normalize text + if normalize: + text = normalize_text(text) + + # Tokenize + tokenizer_output = self.tokenizer.encode(text) + tokens = tokenizer_output.tokens + phonemes = tokenizer_output.phonemes + + if not tokens: + raise ValueError("No tokens generated from input text") + + # Forward pass + audio_gpu = self._forward_simple(tokens, self._current_voice_embedding) + + # Create AudioBuffer + audio_np = audio_gpu.to_numpy() + audio_buffer = AudioBuffer( + data=audio_gpu, + sample_rate=self.config.sample_rate, + channels=1, + ) + + duration_sec = len(audio_np) / self.config.sample_rate + + return SynthesisResult( + audio=audio_buffer, + text=text, + phonemes=phonemes, + duration_sec=duration_sec, + ) + + def __call__( + self, + text: str, + voice: str | None = None, + **kwargs, + ) -> SynthesisResult: + """Synthesize speech (callable interface). 
+ + Args: + text: Input text + voice: Voice to use + **kwargs: Additional arguments for synthesize() + + Returns: + SynthesisResult + """ + return self.synthesize(text, voice=voice, **kwargs) + + def generate_stream( + self, + text: str, + voice: str | None = None, + chunk_size: int = 4800, # 200ms at 24kHz + ): + """Generate audio in chunks for streaming. + + Args: + text: Input text + voice: Voice to use + chunk_size: Audio chunk size in samples + + Yields: + AudioBuffer chunks + """ + # Split into sentences for chunked generation + sentences = split_sentences(text) + + for sentence in sentences: + result = self.synthesize(sentence, voice=voice) + audio_np = result.audio.to_numpy() + + # Yield in chunks + for i in range(0, len(audio_np), chunk_size): + chunk = audio_np[i : i + chunk_size] + chunk_gpu = from_numpy(chunk) + yield AudioBuffer( + data=chunk_gpu, + sample_rate=self.config.sample_rate, + channels=1, + ) + + def print_info(self) -> None: + """Print model information.""" + print("=" * 60) + print("Kokoro-82M TTS Model") + print("=" * 60) + print(f"Config: {self.config}") + print(f"Voices: {self.available_voices}") + print(f"Current voice: {self._current_voice}") + print(f"Tokenizer: {self.tokenizer}") + print("-" * 60) + print_weight_summary(self.weights) + + def __repr__(self) -> str: + return ( + f"KokoroModel(\n" + f" config={self.config!r},\n" + f" voices={self.available_voices},\n" + f" current_voice={self._current_voice!r}\n" + f")" + ) + + +__all__ = [ + "KokoroModel", + "SynthesisResult", +] diff --git a/src/pygpukit/tts/kokoro/text.py b/src/pygpukit/tts/kokoro/text.py new file mode 100644 index 0000000..1428177 --- /dev/null +++ b/src/pygpukit/tts/kokoro/text.py @@ -0,0 +1,349 @@ +"""Text processing for Kokoro TTS. + +Handles grapheme-to-phoneme (G2P) conversion and tokenization. + +Kokoro uses phoneme-based input with a vocabulary of 178 tokens. +For best quality, use the misaki G2P library. A basic fallback is provided. 
+""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pygpukit.tts.kokoro.config import KokoroConfig + + +# Default phoneme vocabulary from Kokoro config.json +# This maps phoneme symbols to token IDs +DEFAULT_VOCAB: dict[str, int] = { + # Padding + "$": 0, + # Punctuation (1-15) + ";": 1, + ":": 2, + ",": 3, + ".": 4, + "!": 5, + "?": 6, + "\u2014": 9, # em-dash + "\u2026": 10, # ellipsis + '"': 11, + "(": 12, + ")": 13, + "\u201c": 14, # left double quote + "\u201d": 15, # right double quote + # Space + " ": 16, + # IPA vowels and consonants (17-177) + # This is a subset - full vocab loaded from config.json + "a": 17, + "b": 18, + "d": 19, + "e": 20, + "f": 21, + "h": 22, + "i": 23, + "j": 24, + "k": 25, + "l": 26, + "m": 27, + "n": 28, + "o": 29, + "p": 30, + "s": 31, + "t": 32, + "u": 33, + "v": 34, + "w": 35, + "z": 36, + # IPA special characters + "\u0251": 69, # open back unrounded vowel + "\u0259": 83, # schwa + "\u014b": 112, # eng (ng) + "\u03b8": 119, # theta (th) + "\u0283": 131, # esh (sh) + "\u02c8": 145, # primary stress + "\u02cc": 146, # secondary stress +} + + +@dataclass +class TokenizerOutput: + """Output from tokenizer. + + Attributes: + tokens: List of token IDs + phonemes: Phoneme string (for debugging) + text: Original text + """ + + tokens: list[int] + phonemes: str + text: str + + def __len__(self) -> int: + return len(self.tokens) + + +class KokoroTokenizer: + """Tokenizer for Kokoro TTS model. + + Converts text to phoneme token sequences for the model. + + Args: + vocab: Phoneme to token ID mapping (from config.json) + lang: Language code for G2P ('a' for American English, etc.) + use_misaki: Whether to use misaki G2P library (recommended) + + Example: + >>> tokenizer = KokoroTokenizer.from_config(config) + >>> output = tokenizer.encode("Hello, world!") + >>> print(output.tokens) # [22, 83, 26, 29, 33, 3, 16, 35, ...] 
+ """ + + def __init__( + self, + vocab: dict[str, int] | None = None, + lang: str = "a", + use_misaki: bool = True, + ): + self.vocab = vocab or DEFAULT_VOCAB + self.lang = lang + self.use_misaki = use_misaki + self._misaki_pipeline = None + + # Create reverse mapping for decoding + self.id_to_phoneme = {v: k for k, v in self.vocab.items()} + + # Padding token + self.pad_token = "$" + self.pad_id = self.vocab.get(self.pad_token, 0) + + # Try to initialize misaki if requested + if use_misaki: + self._init_misaki() + + def _init_misaki(self) -> bool: + """Initialize misaki G2P pipeline.""" + try: + from misaki import en + + # Create pipeline for the specified language + if self.lang in ("a", "en-us"): + self._misaki_pipeline = en.G2P(trf=False) # Fast mode + else: + # Fallback to basic + self._misaki_pipeline = None + return self._misaki_pipeline is not None + except ImportError: + self._misaki_pipeline = None + return False + + @classmethod + def from_config(cls, config: KokoroConfig, **kwargs) -> KokoroTokenizer: + """Create tokenizer from KokoroConfig. + + Args: + config: KokoroConfig with vocab mapping + **kwargs: Additional arguments for tokenizer + + Returns: + KokoroTokenizer instance + """ + vocab = config.vocab if config.vocab else DEFAULT_VOCAB + return cls(vocab=vocab, **kwargs) + + def _text_to_phonemes_basic(self, text: str) -> str: + """Basic text to phoneme conversion (fallback). + + This is a simple character-level conversion for testing. + For production, use misaki G2P. 
+ """ + # Normalize text + text = text.lower() + + # Simple replacements for common patterns + replacements = [ + (r"th", "\u03b8"), # theta + (r"sh", "\u0283"), # esh + (r"ng", "\u014b"), # eng + (r"ch", "t\u0283"), # ch + ] + + for pattern, replacement in replacements: + text = re.sub(pattern, replacement, text) + + return text + + def _text_to_phonemes_misaki(self, text: str) -> str: + """Convert text to phonemes using misaki G2P.""" + if self._misaki_pipeline is None: + return self._text_to_phonemes_basic(text) + + try: + # misaki returns a generator of (grapheme, phoneme) tuples + result = self._misaki_pipeline(text) + + # Collect all phonemes from the generator + phoneme_parts = [] + for item in result: + if isinstance(item, tuple) and len(item) >= 2: + # (grapheme, phoneme) tuple + phoneme_parts.append(str(item[1]) if item[1] else "") + elif isinstance(item, str): + phoneme_parts.append(item) + + # Join phonemes with space separator + phonemes = " ".join(p for p in phoneme_parts if p) + return phonemes if phonemes else self._text_to_phonemes_basic(text) + except Exception: + # Fallback on error + return self._text_to_phonemes_basic(text) + + def text_to_phonemes(self, text: str) -> str: + """Convert text to phoneme string. + + Args: + text: Input text + + Returns: + Phoneme string + """ + if self.use_misaki and self._misaki_pipeline is not None: + return self._text_to_phonemes_misaki(text) + return self._text_to_phonemes_basic(text) + + def phonemes_to_tokens(self, phonemes: str) -> list[int]: + """Convert phoneme string to token IDs. 
+ + Args: + phonemes: Phoneme string + + Returns: + List of token IDs + """ + tokens = [] + i = 0 + while i < len(phonemes): + # Try to match longest sequence first + matched = False + + # Check for multi-character phonemes (up to 3 chars) + for length in [3, 2, 1]: + if i + length <= len(phonemes): + substr = phonemes[i : i + length] + if substr in self.vocab: + tokens.append(self.vocab[substr]) + i += length + matched = True + break + + if not matched: + # Unknown character - skip or use padding + i += 1 + + return tokens + + def encode(self, text: str) -> TokenizerOutput: + """Encode text to token sequence. + + Args: + text: Input text + + Returns: + TokenizerOutput with tokens, phonemes, and original text + """ + phonemes = self.text_to_phonemes(text) + tokens = self.phonemes_to_tokens(phonemes) + + return TokenizerOutput( + tokens=tokens, + phonemes=phonemes, + text=text, + ) + + def decode(self, tokens: list[int]) -> str: + """Decode token sequence to phoneme string. + + Args: + tokens: List of token IDs + + Returns: + Phoneme string + """ + phonemes = [] + for token_id in tokens: + if token_id in self.id_to_phoneme: + phonemes.append(self.id_to_phoneme[token_id]) + return "".join(phonemes) + + def __call__(self, text: str) -> TokenizerOutput: + """Encode text (callable interface).""" + return self.encode(text) + + def __repr__(self) -> str: + return ( + f"KokoroTokenizer(vocab_size={len(self.vocab)}, " + f"lang='{self.lang}', misaki={self._misaki_pipeline is not None})" + ) + + +def normalize_text(text: str) -> str: + """Normalize text for TTS processing. 
+
+    - Strips leading and trailing whitespace
+    - Normalizes whitespace
+    - Expands common abbreviations
+
+    Args:
+        text: Input text
+
+    Returns:
+        Normalized text
+    """
+    # Normalize whitespace
+    text = re.sub(r"\s+", " ", text).strip()
+
+    # Expand common abbreviations
+    abbreviations = {
+        "Mr.": "Mister",
+        "Mrs.": "Missus",
+        "Dr.": "Doctor",
+        "Jr.": "Junior",
+        "Sr.": "Senior",
+        "vs.": "versus",
+        "etc.": "etcetera",
+        "e.g.": "for example",
+        "i.e.": "that is",
+    }
+
+    for abbr, expansion in abbreviations.items():
+        text = text.replace(abbr, expansion)
+
+    return text
+
+
+def split_sentences(text: str) -> list[str]:
+    """Split text into sentences for chunked processing.
+
+    Args:
+        text: Input text
+
+    Returns:
+        List of sentences
+    """
+    # Split on sentence-ending punctuation
+    sentences = re.split(r"(?<=[.!?])\s+", text)
+    return [s.strip() for s in sentences if s.strip()]
+
+
+__all__ = [
+    "KokoroTokenizer",
+    "TokenizerOutput",
+    "normalize_text",
+    "split_sentences",
+    "DEFAULT_VOCAB",
+]