From fba7ae8f0c1f2641ef279bf76017b05d24571793 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Wed, 10 Sep 2025 18:26:05 +0800 Subject: [PATCH 1/8] start point of ace step, maybe need an extra day to implement all features and improve code style --- diffsynth_engine/configs/__init__.py | 2 + diffsynth_engine/configs/pipeline.py | 10 + diffsynth_engine/models/ace_step/ace_dit.py | 470 +++++++++++++ .../models/ace_step/ace_lyric_encoder.py | 283 ++++++++ diffsynth_engine/models/ace_step/vae/dcae.py | 625 ++++++++++++++++++ .../models/ace_step/vae/hifi_gan.py | 623 +++++++++++++++++ .../models/ace_step/vae/music_dcae.py | 89 +++ diffsynth_engine/pipelines/ace_step.py | 29 + 8 files changed, 2131 insertions(+) create mode 100644 diffsynth_engine/models/ace_step/ace_dit.py create mode 100644 diffsynth_engine/models/ace_step/ace_lyric_encoder.py create mode 100644 diffsynth_engine/models/ace_step/vae/dcae.py create mode 100644 diffsynth_engine/models/ace_step/vae/hifi_gan.py create mode 100644 diffsynth_engine/models/ace_step/vae/music_dcae.py create mode 100644 diffsynth_engine/pipelines/ace_step.py diff --git a/diffsynth_engine/configs/__init__.py b/diffsynth_engine/configs/__init__.py index a8783869..bb81790d 100644 --- a/diffsynth_engine/configs/__init__.py +++ b/diffsynth_engine/configs/__init__.py @@ -10,6 +10,7 @@ WanSpeech2VideoPipelineConfig, QwenImagePipelineConfig, HunyuanPipelineConfig, + ACEStepPipelineConfig, BaseStateDicts, SDStateDicts, SDXLStateDicts, @@ -32,6 +33,7 @@ "WanSpeech2VideoPipelineConfig", "QwenImagePipelineConfig", "HunyuanPipelineConfig", + "ACEStepPipelineConfig", "BaseStateDicts", "SDStateDicts", "SDXLStateDicts", diff --git a/diffsynth_engine/configs/pipeline.py b/diffsynth_engine/configs/pipeline.py index cc67ade1..49a667bb 100644 --- a/diffsynth_engine/configs/pipeline.py +++ b/diffsynth_engine/configs/pipeline.py @@ -266,6 +266,16 @@ class HunyuanPipelineConfig(BaseConfig): image_encoder_dtype: torch.dtype = torch.float16 +@dataclass +class ACEStepPipelineConfig(BaseConfig): + model_path: str | os.PathLike | List[str | os.PathLike] + model_dtype: torch.dtype = torch.float16 + vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None + vae_dtype: torch.dtype = torch.float16 + image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None + image_encoder_dtype: torch.dtype = torch.float16 + + @dataclass class BaseStateDicts: pass diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py new file mode 100644 index 00000000..0fb46811 --- /dev/null +++ b/diffsynth_engine/models/ace_step/ace_dit.py @@ -0,0 +1,470 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Any, Dict, List, Optional +from einops import rearrange + +from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel +from diffsynth_engine.models.basic.timestep import TimestepEmbeddings +from diffsynth_engine.models.basic.attention import attention +from diffsynth_engine.models.basic.transformer_helper import RMSNorm +from diffsynth_engine.models.wan.wan_dit import rope_apply, modulate +from diffsynth_engine.models.ace_step.ace_lyric_encoder import ConformerEncoder + + +class Qwen2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cuda:0"): + super().__init__() + self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)) + 
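        # inv_freq[i] = base^(-2i/dim) is the standard RoPE frequency
        # schedule: each channel pair rotates with its own wavelength, from
        # 2*pi at i=0 up to roughly 2*pi*base for the last pair, so nearby
        # positions differ in high-frequency phases and distant positions in
        # low-frequency ones. The cos/sin tables are precomputed once below
        # and lazily re-extended in forward() when a longer sequence arrives.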
self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.int64).float()
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()

    def forward(self, x: torch.Tensor):
        seq_len = x.shape[1]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len)

        return (
            self.cos_cached[:seq_len].to(x.dtype),
            self.sin_cached[:seq_len].to(x.dtype),
        )


class SelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        attn_kwargs: Optional[Dict[str, Any]] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim, device=device, dtype=dtype)
        self.k = nn.Linear(dim, dim, device=device, dtype=dtype)
        self.v = nn.Linear(dim, dim, device=device, dtype=dtype)
        self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

    def forward(self, x, freqs, attn_mask):  # x: (b, s, d), attn_mask: (b, s)
        q, k, v = self.q(x), self.k(x), self.v(x)
        num_heads = q.shape[2] // self.head_dim
        attn_mask = attn_mask[:, :, None, None]
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        q = F.relu(rope_apply(q, freqs) * attn_mask)
        k = F.relu(rope_apply(k, freqs) * attn_mask)
        v = v * attn_mask
        q = rearrange(q, "b s n d -> b n d s")
        k = rearrange(k, "b s n d -> b n s d")
        v = rearrange(v, "b s n d -> b n d s")
        v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1.0)  # b n (d+1) s
        x = torch.matmul(torch.matmul(v, k), q)  # inner: b n (d+1) d
        x = x[:, :, :-1] / (x[:, :, -1:] + 1e-15)  # b n d s
        x = rearrange(x, "b n d s -> b s (n d)")
        return self.o(x)


class CrossAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        attn_kwargs: Optional[Dict[str, Any]] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim, device=device, dtype=dtype)
        self.k = nn.Linear(dim, dim, device=device, dtype=dtype)
        self.v = nn.Linear(dim, dim, device=device, dtype=dtype)
        self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

    def forward(
        self,
        x: torch.Tensor,
        ctx: torch.Tensor,
        freqs: torch.Tensor,
        freqs_ctx: torch.Tensor,
        attn_mask: torch.Tensor,
        attn_mask_ctx: torch.Tensor,
    ) -> torch.Tensor:
        q, k, v = self.q(x), self.k(ctx), self.v(ctx)
        num_heads = q.shape[2] // self.head_dim
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        q = rope_apply(q, freqs)
        k = rope_apply(k, freqs_ctx)
        mask = attn_mask[:, :, None] * attn_mask_ctx[:, None, :]
        mask = torch.where(mask == 1, 0.0, -torch.inf)[:, None]
        mask = mask.expand(-1, num_heads, -1, -1).to(q.dtype)
        x = attention(q, k, v, attn_mask=mask, **self.attn_kwargs)
        x = rearrange(x, "b s n d -> b s (n d)")
        return self.o(x)


class ConvLayer(nn.Module):
    def __init__(
        self,
        in_dim: int,
out_dim: int, + kernel_size: int = 3, + groups: int = 1, + use_bias: bool = False, + act: str | None = None, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.conv = nn.Conv1d( + in_dim, + out_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + bias=use_bias, + device=device, + dtype=dtype, + ) + self.act = nn.SiLU(inplace=True) if act else None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + if self.act: + x = self.act(x) + return x + + +class GLUMBConv(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + + self.glu_act = nn.SiLU(inplace=False) + self.inverted_conv = ConvLayer( + in_features, + hidden_features * 2, + kernel_size=1, + use_bias=True, + act="silu", + device=device, + dtype=dtype, + ) + self.depth_conv = ConvLayer( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + groups=hidden_features * 2, + use_bias=True, + act="silu", + device=device, + dtype=dtype, + ) + self.point_conv = ConvLayer( + hidden_features, + in_features, + kernel_size=1, + use_bias=False, + act=None, + device=device, + dtype=dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.transpose(1, 2) + x = self.inverted_conv(x) + x = self.depth_conv(x) + x, gate = torch.chunk(x, 2, dim=1) + x *= self.glu_act(gate) + x = self.point_conv(x) + x = x.transpose(1, 2) + return x + + +class DiTBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + attn_kwargs: Optional[Dict[str, Any]] = None, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.norm1 = RMSNorm(dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.norm2 = RMSNorm(dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.attn = SelfAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype) + self.cross_attn = CrossAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype) + self.ff = GLUMBConv(in_features=dim, hidden_features=int(dim * mlp_ratio), device=device, dtype=dtype) + self.scale_shift_table = nn.Parameter(torch.randn(6, dim, device=device, dtype=dtype) / dim**0.5) + + def forward(self, x, context, t_mod, freqs, freqs_ctx, attn_mask, attn_mask_ctx): + # msa: multi-head self-attention mlp: multi-layer perceptron + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [ + t.squeeze(1) for t in (self.scale_shift_table[None] + t_mod).chunk(6, dim=1) + ] + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x += gate_msa * self.attn(input_x, freqs, attn_mask) + x += self.cross_attn(x, context, freqs, freqs_ctx, attn_mask, attn_mask_ctx) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x += gate_mlp * self.ff(input_x) + return x + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + patch_size=(16, 1), + in_channels=8, + embed_dim=1152, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.early_conv_layers = nn.Sequential( + nn.Conv2d( + in_channels, + in_channels * 256, + kernel_size=patch_size, + stride=patch_size, + device=device, + dtype=dtype, + ), + nn.GroupNorm(num_groups=32, num_channels=in_channels * 256, eps=1e-6, device=device, dtype=dtype), + nn.Conv2d( + in_channels * 256, + embed_dim, + kernel_size=1, + 
device=device,
                dtype=dtype,
            ),
        )

    def forward(self, latent):
        # early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
        latent = self.early_conv_layers(latent)
        return rearrange(latent, "b c h w -> b (h w) c")


class FinalLayer(nn.Module):
    """Similar to `Head` in Wan2.1."""

    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
        super().__init__()
        self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)

    def forward(self, x, t):
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class ACEStepDiTStateDictConverter(StateDictConverter):
    def convert(self, state_dict):  # TODO
        return state_dict


class ACEStepDiT(PreTrainedModel):
    converter = ACEStepDiTStateDictConverter()
    _supports_parallelization = True

    def __init__(
        self,
        num_layers: int = 24,
        head_dim: int = 64,
        num_heads: int = 20,
        mlp_ratio: float = 2.5,
        out_channels: int = 8,
        max_position: int = 32768,
        rope_theta: float = 1000000.0,
        speaker_embedding_dim: int = 512,
        text_embedding_dim: int = 768,
        ssl_latent_dims: List[int] = [1024, 768],
        lyric_encoder_vocab_size: int = 6693,
        lyric_hidden_size: int = 1024,
        patch_size: List[int] = [16, 1],
        max_height: int = 16,
        max_width: int = 32768,
        attn_kwargs: Optional[Dict[str, Any]] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.rotary_emb = Qwen2RotaryEmbedding(
            dim=head_dim,
            max_position_embeddings=max_position,
            base=rope_theta,
        )

        inner_dim = num_heads * head_dim
        self.transformer_blocks = nn.ModuleList(
            [
                DiTBlock(
                    dim=inner_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_kwargs=attn_kwargs,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        )

        self.time_embedder = TimestepEmbeddings(dim_in=256, dim_out=inner_dim, device=device, dtype=dtype)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(inner_dim, 6 * inner_dim, device=device, dtype=dtype))

        self.speaker_embedder = nn.Linear(speaker_embedding_dim, inner_dim, device=device, dtype=dtype)
        self.genre_embedder = nn.Linear(text_embedding_dim, inner_dim, device=device, dtype=dtype)
        self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, device=device, dtype=dtype)
        self.lyric_encoder = ConformerEncoder(input_size=lyric_hidden_size)
        self.lyric_proj = nn.Linear(lyric_hidden_size, inner_dim, device=device, dtype=dtype)

        projector_dim = 2 * inner_dim
        self.projectors = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(inner_dim, projector_dim, device=device, dtype=dtype),
                    nn.SiLU(),
                    nn.Linear(projector_dim, projector_dim, device=device, dtype=dtype),
                    nn.SiLU(),
                    nn.Linear(projector_dim, ssl_dim, device=device, dtype=dtype),
                )
                for ssl_dim in ssl_latent_dims
            ]
        )

        self.proj_in = PatchEmbed(
            patch_size=patch_size,
            embed_dim=inner_dim,
            device=device,
            dtype=dtype,
        )

        self.final_layer = FinalLayer(inner_dim, patch_size=patch_size, out_channels=out_channels)

    def forward_lyric_encoder(self, context_lyric: torch.LongTensor, attn_mask_lyric: torch.LongTensor):
        lyric_embs = 
self.lyric_embs(context_lyric) # N x T x D + prompt_prenet_out = self.lyric_encoder(lyric_embs, attn_mask_lyric) + prompt_prenet_out = self.lyric_proj(prompt_prenet_out) + return prompt_prenet_out + + def unpatchify(self, x: torch.Tensor) -> torch.Tensor: + return rearrange( + x, + "b (h w) (x y c) -> b c (h x) (w y)", + h=1, + x=self.patch_size[0], + y=self.patch_size[1], + ) + + def encode( + self, + context_prompt: torch.Tensor, + context_speaker: torch.FloatTensor, + context_lyric: torch.LongTensor, + attn_mask_prompt: torch.LongTensor, + attn_mask_lyric: torch.LongTensor, + ): + context_speaker = self.speaker_embedder(context_speaker).unsqueeze(1) + context_prompt = self.genre_embedder(context_prompt) + context_lyric = self.forward_lyric_encoder(context_lyric=context_lyric, attn_mask_lyric=attn_mask_lyric) + context = torch.cat([context_speaker, context_prompt, context_lyric], dim=1) + + attn_mask_speaker = torch.ones(context_speaker.shape[0], 1, device=context_speaker.device) + context_mask = torch.cat([attn_mask_speaker, attn_mask_prompt, attn_mask_lyric], dim=1) + return context, context_mask + + def decode( + self, + x: torch.Tensor, + timestep: Optional[torch.Tensor], + context: torch.Tensor, + attn_mask: torch.Tensor, + attn_mask_ctx: torch.Tensor, + ): + t = self.time_embedder(timestep) + t_mod = self.t_block(t) + + x = self.proj_in(x) + + freqs = self.rotary_emb(x) + freqs_ctx = self.rotary_emb(context) + + for block in self.transformer_blocks: + x = block( + x=x, + context=context, + t_mod=t_mod, + freqs=freqs, + freqs_ctx=freqs_ctx, + attn_mask=attn_mask, + attn_mask_ctx=attn_mask_ctx, + ) + + x = self.final_layer(x, t) + x = self.unpatchify(x) + return x + + def forward( + self, + x: torch.Tensor, + timestep: torch.Tensor, + context_prompt: torch.Tensor, + context_speaker: torch.FloatTensor, + context_lyric: torch.LongTensor, + attn_mask: torch.Tensor, + attn_mask_prompt: torch.LongTensor, + attn_mask_lyric: torch.LongTensor, + ): + context, attn_mask_ctx = self.encode( + context_prompt=context_prompt, + attn_mask_prompt=attn_mask_prompt, + context_speaker=context_speaker, + context_lyric=context_lyric, + attn_mask_lyric=attn_mask_lyric, + ) + output_length = x.shape[-1] + + output = self.decode( + x=x, + timestep=timestep, + context=context, + attn_mask=attn_mask, + attn_mask_ctx=attn_mask_ctx, + ) + + if output_length > output.shape[-1]: + output = F.pad(output, (0, output_length - output.shape[-1], 0, 0), "constant", 0) + elif output_length < output.shape[-1]: + output = output[:, :, :, :output_length] + return output diff --git a/diffsynth_engine/models/ace_step/ace_lyric_encoder.py b/diffsynth_engine/models/ace_step/ace_lyric_encoder.py new file mode 100644 index 00000000..275c6a3c --- /dev/null +++ b/diffsynth_engine/models/ace_step/ace_lyric_encoder.py @@ -0,0 +1,283 @@ +from typing import Tuple +import math +from einops import rearrange +import torch +import torch.nn as nn + + +class PositionwiseFeedForward(nn.Module): + def __init__(self, idim: int, hidden_units: int): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(idim, hidden_units) + self.activation = nn.SiLU() + self.w_2 = nn.Linear(hidden_units, idim) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + return self.w_2(self.activation(self.w_1(xs))) + + +class RelPositionMultiHeadedAttention(nn.Module): + """Multi-Head Attention layer with relative position encoding. + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. 
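
    Note:
        Scores follow the Transformer-XL decomposition
        (https://arxiv.org/abs/1901.02860):
        score(i, j) = (q_i + u) . k_j + (q_i + v) . p_{i-j},
        where `u`/`v` are the learned `pos_bias_u`/`pos_bias_v` and `p` is
        the projected relative positional embedding; the second term is
        realigned with `rel_shift` before being added to the first.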
+ """ + + def __init__(self, n_head: int, n_feat: int): + self.n_head = n_head + self.d_k = n_feat // n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + self.pos_bias_u = nn.Parameter(torch.Tensor(n_head, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(n_head, self.d_k)) + + def rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + + Returns: + torch.Tensor: Output tensor. + + """ + b, h, t, d = x.shape + zero_pad = torch.zeros((b, h, t, 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) # b h t (d+1) + x_padded = x_padded.view(b, h, d + 1, t) + # only keep the positions from 0 to time2 + x = x_padded[:, :, 1:].view_as(x)[..., : d // 2 + 1] + return x + + def forward(self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor) -> torch.Tensor: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + x (torch.Tensor): Query tensor (#batch, time1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + """ + q = self.linear_q(x) + k = self.linear_k(x) + v = self.linear_v(x) + p = self.linear_pos(pos_emb) + q = rearrange(q, "b t (h d) -> b h t d", h=self.n_head) + k = rearrange(k, "b t (h d) -> b h d t", h=self.n_head) + v = rearrange(v, "b t (h d) -> b h t d", h=self.n_head) + p = rearrange(p, "b t (h d) -> b h d t", h=self.n_head) + + q_with_bias_u = q + self.pos_bias_u[None, :, None, :] + q_with_bias_v = q + self.pos_bias_v[None, :, None, :] + matrix_ac = torch.matmul(q_with_bias_u, k) + matrix_bd = torch.matmul(q_with_bias_v, p) + if matrix_ac.shape != matrix_bd.shape: + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + mask = mask.eq(0)[:, None, :, : scores.shape[-1]] + scores = scores.masked_fill(mask, -float("inf")) + attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) + x = torch.matmul(attn, v) + x = rearrange(x, "b h t d -> b t (h d)") + return self.linear_out(x) + + +class ConformerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + """ + + def __init__( + self, + size: int, + num_heads: int, + linear_units: int, + ): + super().__init__() + self.norm_mha = nn.LayerNorm(size) # for the MHA module + self.self_attn = RelPositionMultiHeadedAttention(num_heads, size) + self.norm_ff = nn.LayerNorm(size) # for the FNN module + self.feed_forward = PositionwiseFeedForward(size, linear_units) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. 
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
        """
        x += self.self_attn(self.norm_mha(x), mask, pos_emb)
        x += self.feed_forward(self.norm_ff(x))
        return x, mask


class EspnetRelPositionalEncoding(nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See: Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """Reset the positional encodings."""
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` is the
        # position of the key vector. We use positive relative positions when
        # keys are to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.int64).float().unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.int64).float() * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1))
        return x, pos_emb

    def position_encoding(self, size: int) -> torch.Tensor:
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
        ]
        return pos_emb


class LinearEmbed(nn.Module):
    """Linear transform of the input without subsampling.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.

    """

    def __init__(self, idim: int, odim: int):
        super().__init__()
        self.out = nn.Sequential(
            nn.Linear(idim, odim),
            nn.LayerNorm(odim, eps=1e-5),
        )
        self.pos_enc = EspnetRelPositionalEncoding(odim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).

        Returns:
            torch.Tensor: Embedded tensor (#batch, time, odim).
            torch.Tensor: Relative positional embedding (1, 2*time-1, odim).
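
        Example (shapes only): for `x` of shape (2, 100, 1024) with
        `odim=1024`, returns the embedded tensor of shape (2, 100, 1024)
        scaled by sqrt(odim), plus a relative positional embedding of shape
        (1, 199, 1024) covering offsets +99..-99.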
+ + """ + x = self.out(x) + x, pos_emb = self.pos_enc(x) + return x, pos_emb + + +class ConformerEncoder(nn.Module): + def __init__( + self, + input_size: int, + output_size: int = 1024, + attention_heads: int = 16, + linear_units: int = 4096, + num_blocks: int = 6, + ): + super().__init__() + self.embed = LinearEmbed( + input_size, + output_size, + ) + self.encoders = nn.ModuleList( + [ + ConformerEncoderLayer( + output_size, + attention_heads, + linear_units, + ) + for _ in range(num_blocks) + ] + ) + self.after_norm = nn.LayerNorm(output_size) + + def forward( + self, + xs: torch.Tensor, + pad_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + masks = pad_mask.to(torch.bool)[None] # (B, 1, T) + xs, pos_emb = self.embed(xs) + for layer in self.encoders: + xs, masks = layer(xs, masks, pos_emb) + xs = self.after_norm(xs) + return xs diff --git a/diffsynth_engine/models/ace_step/vae/dcae.py b/diffsynth_engine/models/ace_step/vae/dcae.py new file mode 100644 index 00000000..822c3e57 --- /dev/null +++ b/diffsynth_engine/models/ace_step/vae/dcae.py @@ -0,0 +1,625 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Union + + +class RMSNorm(nn.RMSNorm): + def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False): + super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine) + if elementwise_affine: + self.bias = nn.Parameter(torch.empty(dim)) if bias else None + + def forward(self, x): + x = super().forward(x) + if self.elementwise_affine: + if self.bias is not None: + x += self.bias, + return x + + +def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5): + if norm_type == "batch_norm": + return nn.BatchNorm2d(num_features) + elif norm_type == "group_norm": + return nn.GroupNorm(num_groups, num_features) + elif norm_type == "layer_norm": + return nn.LayerNorm(num_features) + elif norm_type == "rms_norm": + return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True) + else: + raise ValueError(f"Unknown normalization type: {norm_type}") + + +def get_activation(activation_type): + if activation_type == "relu": + return nn.ReLU() + elif activation_type == "relu6": + return nn.ReLU6() + elif activation_type == "silu": + return nn.SiLU() + elif activation_type == "leaky_relu": + return nn.LeakyReLU(0.2) + else: + raise ValueError(f"Unknown activation type: {activation_type}") + + +class ResBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + norm_type: str = "batch_norm", + act_fn: str = "relu6", + ) -> None: + super().__init__() + + self.norm_type = norm_type + self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity() + self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) + self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False) + self.norm = get_normalization(norm_type, out_channels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.norm_type == "rms_norm": + # move channel to the last dimension so we apply RMSnorm across channel dimension + hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1) + else: + hidden_states = self.norm(hidden_states) + + return hidden_states + residual + +class SanaMultiscaleAttentionProjection(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: 
int, + kernel_size: int, + ) -> None: + super().__init__() + + channels = 3 * in_channels + self.proj_in = nn.Conv2d( + channels, + channels, + kernel_size, + padding=kernel_size // 2, + groups=channels, + bias=False, + ) + self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_out(hidden_states) + return hidden_states + +class SanaMultiscaleLinearAttention(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_attention_heads: int = None, + attention_head_dim: int = 8, + mult: float = 1.0, + norm_type: str = "batch_norm", + kernel_sizes: tuple = (5,), + eps: float = 1e-15, + residual_connection: bool = False, + ): + super().__init__() + + self.eps = eps + self.attention_head_dim = attention_head_dim + self.norm_type = norm_type + self.residual_connection = residual_connection + + num_attention_heads = ( + int(in_channels // attention_head_dim * mult) + if num_attention_heads is None + else num_attention_heads + ) + inner_dim = num_attention_heads * attention_head_dim + + self.to_q = nn.Linear(in_channels, inner_dim, bias=False) + self.to_k = nn.Linear(in_channels, inner_dim, bias=False) + self.to_v = nn.Linear(in_channels, inner_dim, bias=False) + + self.to_qkv_multiscale = nn.ModuleList() + for kernel_size in kernel_sizes: + self.to_qkv_multiscale.append( + SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size) + ) + + self.nonlinearity = nn.ReLU() + self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False) + self.norm_out = get_normalization(norm_type, out_channels) + + def apply_linear_attention(self, query, key, value): + value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) + scores = torch.matmul(value, key.transpose(-1, -2)) + hidden_states = torch.matmul(scores, query) + + hidden_states = hidden_states.to(dtype=torch.float32) + hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps) + return hidden_states + + def apply_quadratic_attention(self, query, key, value): + scores = torch.matmul(key.transpose(-1, -2), query) + scores = scores.to(dtype=torch.float32) + scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps) + hidden_states = torch.matmul(value, scores.to(value.dtype)) + return hidden_states + + def forward(self, hidden_states): + height, width = hidden_states.shape[-2:] + if height * width > self.attention_head_dim: + use_linear_attention = True + else: + use_linear_attention = False + + residual = hidden_states + + batch_size, _, height, width = list(hidden_states.size()) + original_dtype = hidden_states.dtype + + hidden_states = hidden_states.movedim(1, -1) + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + hidden_states = torch.cat([query, key, value], dim=3) + hidden_states = hidden_states.movedim(-1, 1) + + multi_scale_qkv = [hidden_states] + for block in self.to_qkv_multiscale: + multi_scale_qkv.append(block(hidden_states)) + + hidden_states = torch.cat(multi_scale_qkv, dim=1) + + if use_linear_attention: + # for linear attention upcast hidden_states to float32 + hidden_states = hidden_states.to(dtype=torch.float32) + + hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width) + + query, key, value = hidden_states.chunk(3, dim=2) + query = self.nonlinearity(query) + 
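        # ReLU acts as the non-negative kernel feature map for linear
        # attention: out = (v @ k'^T) @ q', normalized by the same quantity
        # with v replaced by all-ones, which apply_linear_attention obtains
        # in a single pass by padding v with a row of ones. This reduces the
        # cost from O(L^2 * d) to O(L * d^2) in the sequence length L = H * W.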
key = self.nonlinearity(key) + + if use_linear_attention: + hidden_states = self.apply_linear_attention(query, key, value) + hidden_states = hidden_states.to(dtype=original_dtype) + else: + hidden_states = self.apply_quadratic_attention(query, key, value) + + hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width)) + hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.norm_type == "rms_norm": + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + else: + hidden_states = self.norm_out(hidden_states) + + if self.residual_connection: + hidden_states = hidden_states + residual + + return hidden_states + + +class EfficientViTBlock(nn.Module): + def __init__( + self, + in_channels: int, + mult: float = 1.0, + attention_head_dim: int = 32, + qkv_multiscales: tuple = (5,), + norm_type: str = "batch_norm", + ) -> None: + super().__init__() + + self.attn = SanaMultiscaleLinearAttention( + in_channels=in_channels, + out_channels=in_channels, + mult=mult, + attention_head_dim=attention_head_dim, + norm_type=norm_type, + kernel_sizes=qkv_multiscales, + residual_connection=True, + ) + + self.conv_out = GLUMBConv( + in_channels=in_channels, + out_channels=in_channels, + norm_type="rms_norm", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.attn(x) + x = self.conv_out(x) + return x + + +class GLUMBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + expand_ratio: float = 4, + norm_type: str = None, + residual_connection: bool = True, + ) -> None: + super().__init__() + + hidden_channels = int(expand_ratio * in_channels) + self.norm_type = norm_type + self.residual_connection = residual_connection + + self.nonlinearity = nn.SiLU() + self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0) + self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2) + self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False) + + self.norm = None + if norm_type == "rms_norm": + self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.residual_connection: + residual = hidden_states + + hidden_states = self.conv_inverted(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv_depth(hidden_states) + hidden_states, gate = torch.chunk(hidden_states, 2, dim=1) + hidden_states = hidden_states * self.nonlinearity(gate) + + hidden_states = self.conv_point(hidden_states) + + if self.norm_type == "rms_norm": + # move channel to the last dimension so we apply RMSnorm across channel dimension + hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.residual_connection: + hidden_states = hidden_states + residual + + return hidden_states + + +def get_block( + block_type: str, + in_channels: int, + out_channels: int, + attention_head_dim: int, + norm_type: str, + act_fn: str, + qkv_mutliscales: tuple = (), +): + if block_type == "ResBlock": + block = ResBlock(in_channels, out_channels, norm_type, act_fn) + elif block_type == "EfficientViTBlock": + block = EfficientViTBlock( + in_channels, + attention_head_dim=attention_head_dim, + norm_type=norm_type, + qkv_multiscales=qkv_mutliscales + ) + else: + raise ValueError(f"Block with {block_type=} is not supported.") + + return block + + +class DCDownBlock2d(nn.Module): + def __init__(self, in_channels: int, 
out_channels: int, downsample: bool = False, shortcut: bool = True) -> None: + super().__init__() + + self.downsample = downsample + self.factor = 2 + self.stride = 1 if downsample else 2 + self.group_size = in_channels * self.factor**2 // out_channels + self.shortcut = shortcut + + out_ratio = self.factor**2 + if downsample: + assert out_channels % out_ratio == 0 + out_channels = out_channels // out_ratio + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=self.stride, + padding=1, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x = self.conv(hidden_states) + if self.downsample: + x = F.pixel_unshuffle(x, self.factor) + + if self.shortcut: + y = F.pixel_unshuffle(hidden_states, self.factor) + y = y.unflatten(1, (-1, self.group_size)) + y = y.mean(dim=2) + hidden_states = x + y + else: + hidden_states = x + + return hidden_states + + +class DCUpBlock2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + interpolate: bool = False, + shortcut: bool = True, + interpolation_mode: str = "nearest", + ) -> None: + super().__init__() + + self.interpolate = interpolate + self.interpolation_mode = interpolation_mode + self.shortcut = shortcut + self.factor = 2 + self.repeats = out_channels * self.factor**2 // in_channels + + out_ratio = self.factor**2 + if not interpolate: + out_channels = out_channels * out_ratio + + self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.interpolate: + x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode) + x = self.conv(x) + else: + x = self.conv(hidden_states) + x = F.pixel_shuffle(x, self.factor) + + if self.shortcut: + y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats) + y = F.pixel_shuffle(y, self.factor) + hidden_states = x + y + else: + hidden_states = x + + return hidden_states + + +class Encoder(nn.Module): + def __init__( + self, + in_channels: int, + latent_channels: int, + attention_head_dim: int = 32, + block_type: str | tuple = "ResBlock", + block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024), + layers_per_block: tuple = (2, 2, 2, 2, 2, 2), + qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)), + downsample_block_type: str = "pixel_unshuffle", + out_shortcut: bool = True, + ): + super().__init__() + + num_blocks = len(block_out_channels) + if isinstance(block_type, str): + block_type = (block_type,) * num_blocks + + if layers_per_block[0] > 0: + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1], + kernel_size=3, + stride=1, + padding=1, + ) + else: + self.conv_in = DCDownBlock2d( + in_channels=in_channels, + out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1], + downsample=downsample_block_type == "pixel_unshuffle", + shortcut=False, + ) + + down_blocks = [] + for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)): + down_block_list = [] + + for _ in range(num_layers): + block = get_block( + block_type[i], + out_channel, + out_channel, + attention_head_dim=attention_head_dim, + norm_type="rms_norm", + act_fn="silu", + qkv_mutliscales=qkv_multiscales[i], + ) + down_block_list.append(block) + + if i < num_blocks - 1 and num_layers > 0: + downsample_block = DCDownBlock2d( + in_channels=out_channel, + out_channels=block_out_channels[i + 1], + 
downsample=downsample_block_type == "pixel_unshuffle", + shortcut=True, + ) + down_block_list.append(downsample_block) + + down_blocks.append(nn.Sequential(*down_block_list)) + + self.down_blocks = nn.ModuleList(down_blocks) + + self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1) + + self.out_shortcut = out_shortcut + if out_shortcut: + self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + if self.out_shortcut: + x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size)) + x = x.mean(dim=2) + hidden_states = self.conv_out(hidden_states) + x + else: + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class Decoder(nn.Module): + def __init__( + self, + in_channels: int, + latent_channels: int, + attention_head_dim: int = 32, + block_type: str | tuple = "ResBlock", + block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024), + layers_per_block: tuple = (2, 2, 2, 2, 2, 2), + qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)), + norm_type: str | tuple = "rms_norm", + act_fn: str | tuple = "silu", + upsample_block_type: str = "pixel_shuffle", + in_shortcut: bool = True, + ): + super().__init__() + + num_blocks = len(block_out_channels) + + if isinstance(block_type, str): + block_type = (block_type,) * num_blocks + if isinstance(norm_type, str): + norm_type = (norm_type,) * num_blocks + if isinstance(act_fn, str): + act_fn = (act_fn,) * num_blocks + + self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1) + + self.in_shortcut = in_shortcut + if in_shortcut: + self.in_shortcut_repeats = block_out_channels[-1] // latent_channels + + up_blocks = [] + for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))): + up_block_list = [] + + if i < num_blocks - 1 and num_layers > 0: + upsample_block = DCUpBlock2d( + block_out_channels[i + 1], + out_channel, + interpolate=upsample_block_type == "interpolate", + shortcut=True, + ) + up_block_list.append(upsample_block) + + for _ in range(num_layers): + block = get_block( + block_type[i], + out_channel, + out_channel, + attention_head_dim=attention_head_dim, + norm_type=norm_type[i], + act_fn=act_fn[i], + qkv_mutliscales=qkv_multiscales[i], + ) + up_block_list.append(block) + + up_blocks.insert(0, nn.Sequential(*up_block_list)) + + self.up_blocks = nn.ModuleList(up_blocks) + + channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1] + + self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True) + self.conv_act = nn.ReLU() + self.conv_out = None + + if layers_per_block[0] > 0: + self.conv_out = nn.Conv2d(channels, in_channels, 3, 1, 1) + else: + self.conv_out = DCUpBlock2d( + channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.in_shortcut: + x = hidden_states.repeat_interleave( + self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats + ) + hidden_states = self.conv_in(hidden_states) + x + else: + hidden_states = self.conv_in(hidden_states) + + for up_block in reversed(self.up_blocks): + hidden_states = up_block(hidden_states) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 
1) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class DCAE(nn.Module): + def __init__( + self, + in_channels: int = 2, + latent_channels: int = 8, + attention_head_dim: int = 32, + encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"], + decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"], + encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024), + decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024), + encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3), + decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3), + encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)), + decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)), + upsample_block_type: str = "interpolate", + downsample_block_type: str = "Conv", + decoder_norm_types: Union[str, Tuple[str]] = "rms_norm", + decoder_act_fns: Union[str, Tuple[str]] = "silu", + ) -> None: + super().__init__() + + self.encoder = Encoder( + in_channels=in_channels, + latent_channels=latent_channels, + attention_head_dim=attention_head_dim, + block_type=encoder_block_types, + block_out_channels=encoder_block_out_channels, + layers_per_block=encoder_layers_per_block, + qkv_multiscales=encoder_qkv_multiscales, + downsample_block_type=downsample_block_type, + ) + + self.decoder = Decoder( + in_channels=in_channels, + latent_channels=latent_channels, + attention_head_dim=attention_head_dim, + block_type=decoder_block_types, + block_out_channels=decoder_block_out_channels, + layers_per_block=decoder_layers_per_block, + qkv_multiscales=decoder_qkv_multiscales, + norm_type=decoder_norm_types, + act_fn=decoder_act_fns, + upsample_block_type=upsample_block_type, + ) + + + def encode(self, x: torch.Tensor) -> torch.Tensor: + return self.encoder(x) + + def decode(self, z: torch.Tensor) -> torch.Tensor: + return self.decoder(z) diff --git a/diffsynth_engine/models/ace_step/vae/hifi_gan.py b/diffsynth_engine/models/ace_step/vae/hifi_gan.py new file mode 100644 index 00000000..a1936a2d --- /dev/null +++ b/diffsynth_engine/models/ace_step/vae/hifi_gan.py @@ -0,0 +1,623 @@ +from typing import Callable, List, Tuple +import math +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.parametrizations import weight_norm +from torchaudio.transforms import MelScale + + +class LinearSpectrogram(nn.Module): + def __init__( + self, + n_fft=2048, + win_length=2048, + hop_length=512, + center=False, + mode="pow2_sqrt", + ): + super().__init__() + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.mode = mode + + self.register_buffer("window", torch.hann_window(win_length)) + + def forward(self, y: torch.Tensor) -> torch.Tensor: + if y.ndim == 3: + y = y.squeeze(1) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + ( + (self.win_length - self.hop_length) // 2, + (self.win_length - self.hop_length + 1) // 2, + ), + mode="reflect", + ).squeeze(1) + dtype = y.dtype + spec = torch.stft( + y.float(), + self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + + if self.mode == "pow2_sqrt": + spec = 
torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = spec.to(dtype) + return spec + + +class LogMelSpectrogram(nn.Module): + def __init__( + self, + sample_rate=44100, + n_fft=2048, + win_length=2048, + hop_length=512, + n_mels=128, + center=False, + f_min=0.0, + f_max=None, + ): + super().__init__() + + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.n_mels = n_mels + self.f_min = f_min + self.f_max = f_max or sample_rate // 2 + + self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) + self.mel_scale = MelScale( + self.n_mels, + self.sample_rate, + self.f_min, + self.f_max, + self.n_fft // 2 + 1, + "slaney", + "slaney", + ) + + def compress(self, x: torch.Tensor) -> torch.Tensor: + return torch.log(torch.clamp(x, min=1e-5)) + + def decompress(self, x: torch.Tensor) -> torch.Tensor: + return torch.exp(x) + + def forward(self, x: torch.Tensor, return_linear: bool = False) -> torch.Tensor: + linear = self.spectrogram(x) + x = self.mel_scale(linear) + x = self.compress(x) + # print(x.shape) + if return_linear: + return x, self.compress(linear) + + return x + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ # noqa: E501 + + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501 + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob, 3):0.3f}" + + +class LayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). 
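
    In this file it is applied to 1-D (batch_size, channels, length) feature
    maps: channels_first normalizes each time step across its channels, e.g.
    a (2, 128, 1000) mel-feature map is normalized over the 128 channels at
    each of the 1000 frames.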
+ """ # noqa: E501 + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None] * x + self.bias[:, None] + return x + + +class ConvNeXtBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + kernel_size (int): Kernel size for depthwise conv. Default: 7. + dilation (int): Dilation for depthwise conv. Default: 1. + """ # noqa: E501 + + def __init__( + self, + dim: int, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-6, + mlp_ratio: float = 4.0, + kernel_size: int = 7, + dilation: int = 1, + ): + super().__init__() + + self.dwconv = nn.Conv1d( + dim, + dim, + kernel_size=kernel_size, + padding=int(dilation * (kernel_size - 1) / 2), + groups=dim, + ) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim)) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x, apply_residual: bool = True): + input = x + + x = self.dwconv(x) + x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + if self.gamma is not None: + x = self.gamma * x + + x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) + x = self.drop_path(x) + + if apply_residual: + x = input + x + + return x + + +class ParallelConvNeXtBlock(nn.Module): + def __init__(self, kernel_sizes: List[int], *args, **kwargs): + super().__init__() + self.blocks = nn.ModuleList( + [ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs) for kernel_size in kernel_sizes] + ) + + def forward(self, x: torch.torch.Tensor) -> torch.torch.Tensor: + return torch.stack( + [block(x, apply_residual=False) for block in self.blocks] + [x], + dim=1, + ).sum(dim=1) + + +class ConvNeXtEncoder(nn.Module): + def __init__( + self, + input_channels=3, + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + drop_path_rate=0.0, + layer_scale_init_value=1e-6, + kernel_sizes: Tuple[int] = (7,), + ): + super().__init__() + assert len(depths) == len(dims) + + self.channel_layers = nn.ModuleList() + stem = nn.Sequential( + 
nn.Conv1d( + input_channels, + dims[0], + kernel_size=7, + padding=3, + padding_mode="replicate", + ), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.channel_layers.append(stem) + + for i in range(len(depths) - 1): + mid_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv1d(dims[i], dims[i + 1], kernel_size=1), + ) + self.channel_layers.append(mid_layer) + + block_fn = ( + partial(ConvNeXtBlock, kernel_size=kernel_sizes[0]) + if len(kernel_sizes) == 1 + else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes) + ) + + self.stages = nn.ModuleList() + drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + cur = 0 + for i in range(len(depths)): + stage = nn.Sequential( + *[ + block_fn( + dim=dims[i], + drop_path=drop_path_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward( + self, + x: torch.torch.Tensor, + ) -> torch.torch.Tensor: + for channel_layer, stage in zip(self.channel_layers, self.stages): + x = channel_layer(x) + x = stage(x) + + return self.norm(x) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.silu(x) + xt = c1(xt) + xt = F.silu(xt) + xt = c2(xt) + x = xt + x + return x + + +class HiFiGANGenerator(nn.Module): + def __init__( + self, + *, + hop_length: int = 512, + upsample_rates: Tuple[int] = (8, 8, 2, 2, 2), + upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2), + resblock_kernel_sizes: Tuple[int] = (3, 7, 11), + resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + num_mels: int = 128, + upsample_initial_channel: int = 512, + use_template: bool = True, + pre_conv_kernel_size: int = 7, + post_conv_kernel_size: int = 7, + post_activation: Callable = partial(nn.SiLU, 
inplace=True), + ): + super().__init__() + + assert math.prod(upsample_rates) == hop_length, f"hop_length must be {math.prod(upsample_rates)}" + + self.conv_pre = weight_norm( + nn.Conv1d( + num_mels, + upsample_initial_channel, + pre_conv_kernel_size, + 1, + padding=get_padding(pre_conv_kernel_size), + ) + ) + + self.num_upsamples = len(upsample_rates) + self.num_kernels = len(resblock_kernel_sizes) + + self.noise_convs = nn.ModuleList() + self.use_template = use_template + self.ups = nn.ModuleList() + + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + c_cur = upsample_initial_channel // (2 ** (i + 1)) + self.ups.append( + weight_norm( + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + if not use_template: + continue + + if i + 1 < len(upsample_rates): + stride_f0 = np.prod(upsample_rates[i + 1 :]) + self.noise_convs.append( + nn.Conv1d( + 1, + c_cur, + kernel_size=stride_f0 * 2, + stride=stride_f0, + padding=stride_f0 // 2, + ) + ) + else: + self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=1)) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes): + self.resblocks.append(ResBlock1(ch, k, d)) + + self.activation_post = post_activation() + self.conv_post = weight_norm( + nn.Conv1d( + ch, + 1, + post_conv_kernel_size, + 1, + padding=get_padding(post_conv_kernel_size), + ) + ) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x, template=None): + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + x = F.silu(x, inplace=True) + x = self.ups[i](x) + + if self.use_template: + x = x + self.noise_convs[i](template) + + xs = None + + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + + x = xs / self.num_kernels + + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + +class ADaMoSHiFiGANV1(nn.Module): + def __init__( + self, + input_channels: int = 128, + depths: List[int] = [3, 3, 9, 3], + dims: List[int] = [128, 256, 384, 512], + drop_path_rate: float = 0.0, + kernel_sizes: Tuple[int] = (7,), + upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2), + upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4), + resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13), + resblock_dilation_sizes: Tuple[Tuple[int]] = ( + (1, 3, 5), + (1, 3, 5), + (1, 3, 5), + (1, 3, 5), + ), + num_mels: int = 512, + upsample_initial_channel: int = 1024, + use_template: bool = False, + pre_conv_kernel_size: int = 13, + post_conv_kernel_size: int = 13, + sampling_rate: int = 44100, + n_fft: int = 2048, + win_length: int = 2048, + hop_length: int = 512, + f_min: int = 40, + f_max: int = 16000, + n_mels: int = 128, + ): + super().__init__() + + self.backbone = ConvNeXtEncoder( + input_channels=input_channels, + depths=depths, + dims=dims, + drop_path_rate=drop_path_rate, + kernel_sizes=kernel_sizes, + ) + + self.head = HiFiGANGenerator( + hop_length=hop_length, + upsample_rates=upsample_rates, + upsample_kernel_sizes=upsample_kernel_sizes, + resblock_kernel_sizes=resblock_kernel_sizes, + resblock_dilation_sizes=resblock_dilation_sizes, + num_mels=num_mels, + upsample_initial_channel=upsample_initial_channel, + use_template=use_template, + 
+class ADaMoSHiFiGANV1(nn.Module):
+    def __init__(
+        self,
+        input_channels: int = 128,
+        depths: List[int] = [3, 3, 9, 3],
+        dims: List[int] = [128, 256, 384, 512],
+        drop_path_rate: float = 0.0,
+        kernel_sizes: Tuple[int] = (7,),
+        upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
+        upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
+        resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
+        resblock_dilation_sizes: Tuple[Tuple[int]] = (
+            (1, 3, 5),
+            (1, 3, 5),
+            (1, 3, 5),
+            (1, 3, 5),
+        ),
+        num_mels: int = 512,
+        upsample_initial_channel: int = 1024,
+        use_template: bool = False,
+        pre_conv_kernel_size: int = 13,
+        post_conv_kernel_size: int = 13,
+        sampling_rate: int = 44100,
+        n_fft: int = 2048,
+        win_length: int = 2048,
+        hop_length: int = 512,
+        f_min: int = 40,
+        f_max: int = 16000,
+        n_mels: int = 128,
+    ):
+        super().__init__()
+
+        self.backbone = ConvNeXtEncoder(
+            input_channels=input_channels,
+            depths=depths,
+            dims=dims,
+            drop_path_rate=drop_path_rate,
+            kernel_sizes=kernel_sizes,
+        )
+
+        self.head = HiFiGANGenerator(
+            hop_length=hop_length,
+            upsample_rates=upsample_rates,
+            upsample_kernel_sizes=upsample_kernel_sizes,
+            resblock_kernel_sizes=resblock_kernel_sizes,
+            resblock_dilation_sizes=resblock_dilation_sizes,
+            num_mels=num_mels,
+            upsample_initial_channel=upsample_initial_channel,
+            use_template=use_template,
+            pre_conv_kernel_size=pre_conv_kernel_size,
+            post_conv_kernel_size=post_conv_kernel_size,
+        )
+        self.sampling_rate = sampling_rate
+        self.mel_transform = LogMelSpectrogram(
+            sample_rate=sampling_rate,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            f_min=f_min,
+            f_max=f_max,
+            n_mels=n_mels,
+        )
+        self.eval()
+
+    @torch.no_grad()
+    def decode(self, mel):
+        y = self.backbone(mel)
+        y = self.head(y)
+        return y
+
+    @torch.no_grad()
+    def encode(self, x):
+        return self.mel_transform(x)
+
+    def forward(self, mel):
+        y = self.backbone(mel)
+        y = self.head(y)
+        return y
diff --git a/diffsynth_engine/models/ace_step/vae/music_dcae.py b/diffsynth_engine/models/ace_step/vae/music_dcae.py
new file mode 100644
index 00000000..8264e2e8
--- /dev/null
+++ b/diffsynth_engine/models/ace_step/vae/music_dcae.py
@@ -0,0 +1,89 @@
+import torch
+from torchvision import transforms
+import torchaudio
+
+from diffsynth_engine.models.base import PreTrainedModel
+
+from .hifi_gan import ADaMoSHiFiGANV1
+from .dcae import DCAE  # TODO: rewrite above 2 files
+
+
+class MusicDCAE(PreTrainedModel):
+    def __init__(
+        self,
+        dcae: DCAE,
+        vocoder: ADaMoSHiFiGANV1,
+        source_sample_rate: int = 48000,
+    ):
+        super(MusicDCAE, self).__init__()
+
+        self.dcae = dcae
+        self.vocoder = vocoder
+        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
+        self.transform = transforms.Compose([transforms.Normalize(0.5, 0.5)])
+        self.min_mel_value = -11.0
+        self.max_mel_value = 3.0
+        self.time_dimension_multiple = 8
+        self.scale_factor = 0.1786
+        self.shift_factor = -1.9091
+
+    def load_audio(self, audio_path):
+        audio, sr = torchaudio.load(audio_path)
+        if audio.shape[0] == 1:
+            audio = audio.repeat(2, 1)
+        return audio, sr
+
+    def forward_mel(self, audios):
+        mels = []
+        for i in range(len(audios)):
+            image = self.vocoder.mel_transform(audios[i])
+            mels.append(image)
+        mels = torch.stack(mels)
+        return mels
+
+    @torch.no_grad()
+    def encode(self, audios: torch.Tensor, sr: int):
+        audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
+        audio_lengths = audio_lengths.to(audios.device)
+
+        # audios: N x 2 x T, 48kHz
+        resampler = torchaudio.transforms.Resample(sr, 44100)
+        resampler = resampler.to(device=audios.device, dtype=audios.dtype)
+        audio = resampler(audios)
+
+        max_audio_len = audio.shape[-1]
+        if max_audio_len % (8 * 512) != 0:
+            audio = torch.nn.functional.pad(audio, (0, 8 * 512 - max_audio_len % (8 * 512)))
+
+        mels = self.forward_mel(audio)
+        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
+        mels = self.transform(mels)
+        latents = []
+        for mel in mels:
+            latent = self.dcae.encoder(mel.unsqueeze(0))
+            latents.append(latent)
+        latents = torch.cat(latents, dim=0)
+        latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimension_multiple).long()
+        latents = (latents - self.shift_factor) * self.scale_factor
+        return latents, latent_lengths
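encode normalizes mels into [-1, 1] in two steps (min/max scaling to [0, 1], then Normalize(0.5, 0.5)) and standardizes latents with the shift/scale pair; decode below inverts both. A small sanity sketch of the round trip, using the constants from MusicDCAE.__init__:

    import torch

    min_v, max_v = -11.0, 3.0                               # min_mel_value / max_mel_value
    mel = torch.tensor([-4.0])
    norm = ((mel - min_v) / (max_v - min_v) - 0.5) / 0.5    # encode path
    back = (norm * 0.5 + 0.5) * (max_v - min_v) + min_v     # decode path
    assert torch.allclose(back, mel)

    z = torch.randn(4)
    z_enc = (z - (-1.9091)) * 0.1786                        # (latents - shift_factor) * scale_factor
    assert torch.allclose(z_enc / 0.1786 + (-1.9091), z)    # decode undoes it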
+
+    @torch.no_grad()
+    def decode(self, latents: torch.Tensor, sr: int):
+        latents = latents / self.scale_factor + self.shift_factor
+
+        pred_wavs = []
+        for latent in latents:
+            mels = self.dcae.decoder(latent.unsqueeze(0))
+            mels = mels * 0.5 + 0.5
+            mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
+
+            # decode the waveform for each channel separately to reduce VRAM footprint
+            wav_ch1 = self.vocoder.decode(mels[:, 0, :, :]).squeeze(1).cpu()
+            wav_ch2 = self.vocoder.decode(mels[:, 1, :, :]).squeeze(1).cpu()
+            wav = torch.cat([wav_ch1, wav_ch2], dim=0)
+
+            resampler = torchaudio.transforms.Resample(44100, sr)
+            wav = resampler(wav.cpu().float())
+            pred_wavs.append(wav)
+
+        return sr, pred_wavs
diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py
new file mode 100644
index 00000000..8d6b7b76
--- /dev/null
+++ b/diffsynth_engine/pipelines/ace_step.py
@@ -0,0 +1,29 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+from torchvision.transforms.functional import pil_to_tensor
+from typing import Callable, List, Optional
+from tqdm import tqdm
+from PIL import Image
+
+from diffsynth_engine.configs import ACEStepPipelineConfig
+from diffsynth_engine.models.ace_step.ace_dit import ACEStepDiT
+from diffsynth_engine.models.ace_step.vae.music_dcae import MusicDCAE
+from diffsynth_engine.pipelines import BasePipeline
+from diffsynth_engine.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class ACEStepMusicPipeline(BasePipeline):
+    def __init__(
+        self,
+        config: ACEStepPipelineConfig,
+        dit: ACEStepDiT,
+        vae: MusicDCAE,
+    ):
+        pass
\ No newline at end of file
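The pipeline is still a stub at this point, but the pieces it imports already suggest the eventual flow: conditioning goes through ACEStepDiT.encode, denoising through its forward, and MusicDCAE.decode turns latents back into stereo waveforms. A rough, purely illustrative sketch of that wiring; the function name, the stand-in Euler step, and the argument packing are all hypothetical, and patch 2 below supplies the real implementation:

    import torch

    @torch.no_grad()
    def generate_sketch(dit, vae, latents, timesteps, cond_kwargs, sr=48000):
        # hypothetical denoising loop; the real scheduler / CFG logic lands in patch 2
        for t in timesteps:
            pred = dit(latents, timestep=t, **cond_kwargs)
            latents = latents - pred * (1.0 / len(timesteps))  # stand-in Euler step
        return vae.decode(latents, sr=sr)                      # -> (sr, [waveforms])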
From 2e4556926774168a48ae21aebbe0e2a2c7214158 Mon Sep 17 00:00:00 2001
From: continue-revolution
Date: Thu, 11 Sep 2025 18:30:43 +0800
Subject: [PATCH 2/8] almost finish, lefts are: test / example / save video / edit pipelines / config

---
 diffsynth_engine/conf/models/ace/dit.json          |    33 +
 diffsynth_engine/conf/models/ace/t5.json           |    25 +
 .../conf/models/ace/vae/dcae.json                  |    66 +
 .../conf/models/ace/vae/vocoder.json               |    77 +
 diffsynth_engine/configs/__init__.py               |     2 +
 diffsynth_engine/configs/pipeline.py               |    26 +-
 diffsynth_engine/models/ace_step/ace_dit.py        |    54 +-
 .../models/ace_step/ace_text_encoder.py            |    30 +
 .../ace_step/lyric_tokenizer/lang_segment.py       |  1071 ++
 .../lyric_tokenizer/lyric_tokenizer.py             |   690 +
 .../ace_step/lyric_tokenizer/num2str.py            |   325 +
 .../ace_step/lyric_tokenizer/vocab.json            | 15535 ++++++++++++++++
 .../ace_step/lyric_tokenizer/zh_num2words.py       |  1113 ++
 diffsynth_engine/models/ace_step/vae/dcae.py       |    30 +-
 .../models/ace_step/vae/hifi_gan.py                |    29 +-
 .../models/ace_step/vae/music_dcae.py              |     8 +-
 diffsynth_engine/pipelines/ace_step.py             |   484 +-
 diffsynth_engine/utils/constants.py                |     5 +
 18 files changed, 19574 insertions(+), 29 deletions(-)
 create mode 100644 diffsynth_engine/conf/models/ace/dit.json
 create mode 100644 diffsynth_engine/conf/models/ace/t5.json
 create mode 100644 diffsynth_engine/conf/models/ace/vae/dcae.json
 create mode 100644 diffsynth_engine/conf/models/ace/vae/vocoder.json
 create mode 100644 diffsynth_engine/models/ace_step/ace_text_encoder.py
 create mode 100644 diffsynth_engine/models/ace_step/lyric_tokenizer/lang_segment.py
 create mode 100644 diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py
 create mode 100644 diffsynth_engine/models/ace_step/lyric_tokenizer/num2str.py
 create mode 100644 diffsynth_engine/models/ace_step/lyric_tokenizer/vocab.json
 create mode 100644 diffsynth_engine/models/ace_step/lyric_tokenizer/zh_num2words.py

diff --git a/diffsynth_engine/conf/models/ace/dit.json b/diffsynth_engine/conf/models/ace/dit.json
new file mode 100644
index 00000000..bc409f49
--- /dev/null
+++ b/diffsynth_engine/conf/models/ace/dit.json
@@ -0,0 +1,33 @@
+{
+    "attention_head_dim": 128,
+    "in_channels": 8,
+    "inner_dim": 2560,
+    "lyric_encoder_vocab_size": 6693,
+    "lyric_hidden_size": 1024,
+    "max_height": 16,
+    "max_position": 32768,
+    "max_width": 32768,
+    "mlp_ratio": 2.5,
+    "num_attention_heads": 20,
+    "num_layers": 24,
+    "out_channels": 8,
+    "patch_size": [
+        16,
+        1
+    ],
+    "rope_theta": 1000000.0,
+    "speaker_embedding_dim": 512,
+    "ssl_encoder_depths": [
+        8,
+        8
+    ],
+    "ssl_latent_dims": [
+        1024,
+        768
+    ],
+    "ssl_names": [
+        "mert",
+        "m-hubert"
+    ],
+    "text_embedding_dim": 768
+}
\ No newline at end of file
diff --git a/diffsynth_engine/conf/models/ace/t5.json b/diffsynth_engine/conf/models/ace/t5.json
new file mode 100644
index 00000000..0d6b2148
--- /dev/null
+++ b/diffsynth_engine/conf/models/ace/t5.json
@@ -0,0 +1,25 @@
+{
+    "classifier_dropout": 0.0,
+    "d_ff": 2048,
+    "d_kv": 64,
+    "d_model": 768,
+    "decoder_start_token_id": 0,
+    "dense_act_fn": "gelu_new",
+    "dropout_rate": 0.1,
+    "eos_token_id": 1,
+    "feed_forward_proj": "gated-gelu",
+    "initializer_factor": 1.0,
+    "is_encoder_decoder": true,
+    "is_gated_act": true,
+    "layer_norm_epsilon": 1e-06,
+    "num_decoder_layers": 12,
+    "num_heads": 12,
+    "num_layers": 12,
+    "output_past": true,
+    "pad_token_id": 0,
+    "relative_attention_max_distance": 128,
+    "relative_attention_num_buckets": 32,
+    "scalable_attention": true,
+    "tie_word_embeddings": false,
+    "vocab_size": 256384
+}
\ No newline at end of file
diff --git a/diffsynth_engine/conf/models/ace/vae/dcae.json b/diffsynth_engine/conf/models/ace/vae/dcae.json
new file mode 100644
index 00000000..a514a707
--- /dev/null
+++ b/diffsynth_engine/conf/models/ace/vae/dcae.json
@@ -0,0 +1,66 @@
+{
+    "attention_head_dim": 32,
+    "decoder_act_fns": "silu",
+    "decoder_block_out_channels": [
+        128,
+        256,
+        512,
+        1024
+    ],
+    "decoder_block_types": [
+        "ResBlock",
+        "ResBlock",
+        "ResBlock",
+        "EfficientViTBlock"
+    ],
+    "decoder_layers_per_block": [
+        3,
+        3,
+        3,
+        3
+    ],
+    "decoder_norm_types": "rms_norm",
+    "decoder_qkv_multiscales": [
+        [],
+        [],
+        [
+            5
+        ],
+        [
+            5
+        ]
+    ],
+    "downsample_block_type": "Conv",
+    "encoder_block_out_channels": [
+        128,
+        256,
+        512,
+        1024
+    ],
+    "encoder_block_types": [
+        "ResBlock",
+        "ResBlock",
+        "ResBlock",
+        "EfficientViTBlock"
+    ],
+    "encoder_layers_per_block": [
+        2,
+        2,
+        3,
+        3
+    ],
+    "encoder_qkv_multiscales": [
+        [],
+        [],
+        [
+            5
+        ],
+        [
+            5
+        ]
+    ],
+    "in_channels": 2,
+    "latent_channels": 8,
+    "scaling_factor": 0.41407,
+    "upsample_block_type": "interpolate"
+}
\ No newline at end of file
diff --git a/diffsynth_engine/conf/models/ace/vae/vocoder.json b/diffsynth_engine/conf/models/ace/vae/vocoder.json
new file mode 100644
index 00000000..771e3446
--- /dev/null
+++ b/diffsynth_engine/conf/models/ace/vae/vocoder.json
@@ -0,0 +1,77 @@
+{
+    "depths": [
+        3,
+        3,
+        9,
+        3
+    ],
+    "dims": [
+        128,
+        256,
+        384,
+        512
+    ],
+    "drop_path_rate": 0.0,
+    "f_max": 16000,
+    "f_min": 40,
+    "hop_length": 512,
+    "input_channels": 128,
+    "kernel_sizes": [
+        7
+    ],
+    "n_fft": 2048,
+    "n_mels": 128,
+    "num_mels": 512,
+    "post_conv_kernel_size": 13,
+    "pre_conv_kernel_size": 13,
+    "resblock_dilation_sizes": [
+        [
+            1,
+            3,
+            5
+        ],
+        [
+            1,
+            3,
+            5
+        ],
+        [
+            1,
+            3,
+            5
+        ],
+        [
+            1,
+            3,
+            5
+        ]
+    ],
+    "resblock_kernel_sizes": [
+        3,
+        7,
+        11,
+        13
+    ],
+    "sampling_rate": 44100,
+    "upsample_initial_channel": 1024,
+    "upsample_kernel_sizes": [
+        8,
+        8,
+        4,
+        4,
+        4,
+        4,
+        4
+    ],
+    "upsample_rates": [
+        4,
+        4,
+        2,
+        2,
+        2,
+        2,
+        2
+    ],
+    "use_template": false,
+    "win_length": 2048
+}
\ No newline at end of file
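These JSON files mirror the constructor signatures one-to-one, which is what makes the get_model_config pattern added to ace_dit.py below work: read the file, then splat it into the class. A sketch for the vocoder config; the literal path is spelled out here only for illustration, while the real lookup goes through the constants added to diffsynth_engine/utils/constants.py:

    import json

    with open("diffsynth_engine/conf/models/ace/vae/vocoder.json") as f:
        cfg = json.load(f)
    vocoder = ADaMoSHiFiGANV1(**cfg)  # every key matches an __init__ parameter above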
diff --git a/diffsynth_engine/configs/__init__.py b/diffsynth_engine/configs/__init__.py
index bb81790d..388f158f 100644
--- a/diffsynth_engine/configs/__init__.py
+++ b/diffsynth_engine/configs/__init__.py
@@ -17,6 +17,7 @@
     FluxStateDicts,
     WanStateDicts,
     WanS2VStateDicts,
+    ACEStateDicts,
     QwenImageStateDicts,
 )
 from .controlnet import ControlType, ControlNetParams
@@ -40,6 +41,7 @@
     "FluxStateDicts",
     "WanStateDicts",
     "WanS2VStateDicts",
+    "ACEStateDicts",
     "QwenImageStateDicts",
     "ControlType",
     "ControlNetParams",
diff --git a/diffsynth_engine/configs/pipeline.py b/diffsynth_engine/configs/pipeline.py
index 49a667bb..e4d86c19 100644
--- a/diffsynth_engine/configs/pipeline.py
+++ b/diffsynth_engine/configs/pipeline.py
@@ -267,13 +267,19 @@ class HunyuanPipelineConfig(BaseConfig):
 
 
 @dataclass
-class ACEStepPipelineConfig(BaseConfig):
+class ACEStepPipelineConfig(AttentionConfig, OptimizationConfig, BaseConfig):
     model_path: str | os.PathLike | List[str | os.PathLike]
-    model_dtype: torch.dtype = torch.float16
-    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
-    vae_dtype: torch.dtype = torch.float16
-    image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
-    image_encoder_dtype: torch.dtype = torch.float16
+    model_dtype: torch.dtype = torch.bfloat16
+    dcae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vocoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_dtype: torch.dtype = torch.bfloat16
+    t5_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    t5_dtype: torch.dtype = torch.bfloat16
+
+    # default params set by model type
+    shift: Optional[float] = field(default=None, init=False)  # RectifiedFlowScheduler shift factor
+    cfg_scale: Optional[float | Tuple[float, float]] = field(default=None, init=False)  # default CFG scale
+    num_inference_steps: Optional[int] = field(default=None, init=False)  # default inference steps
 
 
 @dataclass
@@ -320,6 +326,14 @@ class WanS2VStateDicts:
     audio_encoder: Dict[str, torch.Tensor]
 
 
+@dataclass
+class ACEStateDicts:
+    model: Dict[str, torch.Tensor]
+    t5: Dict[str, torch.Tensor]
+    dcae: Dict[str, torch.Tensor]
+    vocoder: Dict[str, torch.Tensor]
+
+
 @dataclass
 class QwenImageStateDicts:
     model: Dict[str, torch.Tensor]
diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py
index 0fb46811..bbd89ff9 100644
--- a/diffsynth_engine/models/ace_step/ace_dit.py
+++ b/diffsynth_engine/models/ace_step/ace_dit.py
@@ -1,7 +1,9 @@
+from typing import Any, Dict, List, Optional
+import json
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Any, Dict, List, Optional
 from einops import rearrange
 
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
@@ -10,6 +12,7 @@
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
 from diffsynth_engine.models.wan.wan_dit import rope_apply, modulate
 from diffsynth_engine.models.ace_step.ace_lyric_encoder import ConformerEncoder
+from diffsynth_engine.utils.constants import ACE_DIT_CONFIG_FILE
 
 
 class Qwen2RotaryEmbedding(nn.Module):
@@ -286,7 +289,7 @@ def forward(self, x, t):
 
 
 class ACEStepDiTStateDictConverter(StateDictConverter):
-    def convert(self, state_dict): # TODO
+    def convert(self, state_dict):  # TODO
         return state_dict
 
 
@@ -390,17 +393,25 @@ def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
     def encode(
         self,
         context_prompt: torch.Tensor,
-        context_speaker: torch.FloatTensor,
         context_lyric: torch.LongTensor,
         attn_mask_prompt: torch.LongTensor,
         attn_mask_lyric: torch.LongTensor,
     ):
-        context_speaker = self.speaker_embedder(context_speaker).unsqueeze(1)
+        b = context_prompt.shape[0]
+        device, dtype = context_prompt.device, context_prompt.dtype
+        context_speaker = torch.zeros(
+            b,
+            1,
+            self.speaker_embedder.in_features,
+            device=device,
+            dtype=dtype,
+        )
+        context_speaker = self.speaker_embedder(context_speaker)
         context_prompt = self.genre_embedder(context_prompt)
         context_lyric = self.forward_lyric_encoder(context_lyric=context_lyric, attn_mask_lyric=attn_mask_lyric)
 
         context = torch.cat([context_speaker, context_prompt, context_lyric], dim=1)
-        attn_mask_speaker = torch.ones(context_speaker.shape[0], 1, device=context_speaker.device)
+        attn_mask_speaker = torch.ones(b, 1, device=device)
         context_mask = torch.cat([attn_mask_speaker, attn_mask_prompt, attn_mask_lyric], dim=1)
 
         return context, context_mask
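The rewritten encode drops context_speaker from the public interface and instead pushes an all-zero vector through speaker_embedder, effectively a fixed "no speaker" prefix token. With the numbers from conf/models/ace/dit.json above (speaker_embedding_dim 512, inner_dim 2560) the shapes work out as follows; the sketch assumes speaker_embedder is an nn.Linear, which its use of .in_features suggests:

    import torch
    import torch.nn as nn

    b = 2
    speaker_embedder = nn.Linear(512, 2560)    # speaker_embedding_dim -> inner_dim
    context_speaker = torch.zeros(b, 1, 512)   # zeroed "speaker" input
    token = speaker_embedder(context_speaker)  # (b, 1, 2560): one prefix token, which
    assert token.shape == (2, 1, 2560)         # for zero input reduces to the layer's bias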
@@ -410,7 +421,7 @@ def decode(
         timestep: Optional[torch.Tensor],
         context: torch.Tensor,
         attn_mask: torch.Tensor,
-        attn_mask_ctx: torch.Tensor, 
+        attn_mask_ctx: torch.Tensor,
     ):
         t = self.time_embedder(timestep)
         t_mod = self.t_block(t)
@@ -440,7 +451,6 @@ def forward(
         x: torch.Tensor,
         timestep: torch.Tensor,
         context_prompt: torch.Tensor,
-        context_speaker: torch.FloatTensor,
         context_lyric: torch.LongTensor,
         attn_mask: torch.Tensor,
         attn_mask_prompt: torch.LongTensor,
@@ -448,9 +458,8 @@
     ):
         context, attn_mask_ctx = self.encode(
             context_prompt=context_prompt,
-            attn_mask_prompt=attn_mask_prompt,
-            context_speaker=context_speaker,
             context_lyric=context_lyric,
+            attn_mask_prompt=attn_mask_prompt,
             attn_mask_lyric=attn_mask_lyric,
         )
         output_length = x.shape[-1]
@@ -468,3 +477,30 @@
         elif output_length < output.shape[-1]:
             output = output[:, :, :, :output_length]
         return output
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        config: Dict[str, Any],
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+        assign: bool = True,
+    ):
+        model = cls(**config, device="meta", dtype=dtype, attn_kwargs=attn_kwargs)
+        model = model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=assign)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
+
+    @staticmethod
+    def get_model_config() -> dict:
+        config_file = ACE_DIT_CONFIG_FILE
+        with open(config_file, "r") as f:
+            config = json.load(f)
+        return config
+
+    def compile_repeated_blocks(self, *args, **kwargs):
+        for block in self.transformer_blocks:
+            block.compile(*args, **kwargs)
diff --git a/diffsynth_engine/models/ace_step/ace_text_encoder.py b/diffsynth_engine/models/ace_step/ace_text_encoder.py
new file mode 100644
index 00000000..3135479f
--- /dev/null
+++ b/diffsynth_engine/models/ace_step/ace_text_encoder.py
@@ -0,0 +1,30 @@
+from typing import Dict, Any
+import json
+
+import torch
+
+from diffsynth_engine.models.wan.wan_text_encoder import WanTextEncoder
+from diffsynth_engine.utils.constants import ACE_TEXT_ENCODER_CONFIG_FILE
+
+
+class ACETextEncoder(WanTextEncoder):
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        config: Dict[str, Any],
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float32,
+    ) -> "ACETextEncoder":
+        model = cls(**config, device="meta", dtype=dtype)
+        model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
+
+    @staticmethod
+    def get_model_config() -> dict:
+        config_file = ACE_TEXT_ENCODER_CONFIG_FILE
+        with open(config_file, "r") as f:
+            config = json.load(f)
+        return config
\ No newline at end of file
diff --git a/diffsynth_engine/models/ace_step/lyric_tokenizer/lang_segment.py
b/diffsynth_engine/models/ace_step/lyric_tokenizer/lang_segment.py new file mode 100644 index 00000000..999f14e6 --- /dev/null +++ b/diffsynth_engine/models/ace_step/lyric_tokenizer/lang_segment.py @@ -0,0 +1,1071 @@ +""" +This file bundles language identification functions. + +Modifications (fork): Copyright (c) 2021, Adrien Barbaresi. + +Original code: Copyright (c) 2011 Marco Lui . +Based on research by Marco Lui and Tim Baldwin. + +See LICENSE file for more info. +https://github.com/adbar/py3langid + +Projects: +https://github.com/juntaosun/LangSegment + +LICENSE: +py3langid - Language Identifier +BSD 3-Clause License + +Modifications (fork): Copyright (c) 2021, Adrien Barbaresi. + +Original code: Copyright (c) 2011 Marco Lui . +Based on research by Marco Lui and Tim Baldwin. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are +permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER ``AS IS'' AND ANY EXPRESS OR IMPLIED +WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import re +import numpy as np +from collections import Counter +from collections import defaultdict + +# import langid +# import py3langid as langid +# pip install py3langid==0.2.2 + +# 启用语言预测概率归一化,概率预测的分数。因此,实现重新规范化 产生 0-1 范围内的输出。 +# langid disables probability normalization by default. For command-line usages of , it can be enabled by passing the flag. +# For probability normalization in library use, the user must instantiate their own . An example of such usage is as follows: +from py3langid.langid import LanguageIdentifier, MODEL_FILE + +from .num2str import num2str + +# ----------------------------------- +# 更新日志:新版本分词更加精准。 +# Changelog: The new version of the word segmentation is more accurate. +# チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。 +# Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다. +# ----------------------------------- + + +# Word segmentation function: +# automatically identify and split the words (Chinese/English/Japanese/Korean) in the article or sentence according to different languages, +# making it more suitable for TTS processing. +# This code is designed for front-end text multi-lingual mixed annotation distinction, multi-language mixed training and inference of various TTS projects. 
+# This processing result is mainly for (Chinese = zh, Japanese = ja, English = en, Korean = ko), and can actually support up to 97 different language mixing processing.
+
+# ===========================================================================================================
+# 分かち書き機能:文章や文章の中の例えば(中国語/英語/日本語/韓国語)を、異なる言語で自動的に認識して分割し、TTS処理により適したものにします。
+# このコードは、さまざまなTTSプロジェクトのフロントエンドテキストの多言語混合注釈区別、多言語混合トレーニング、および推論のために特別に作成されています。
+# ===========================================================================================================
+# (1)自動分詞:「韓国語では何を読むのですかあなたの体育の先生は誰ですか?今回の発表会では、iPhone 15シリーズの4機種が登場しました」
+# (2)手动分词:“あなたの名前は佐々木ですか?ですか?”
+# この処理結果は主に(中国語=ja、日本語=ja、英語=en、韓国語=ko)を対象としており、実際には最大97の異なる言語の混合処理をサポートできます。
+# ===========================================================================================================
+
+# ===========================================================================================================
+# 단어 분할 기능: 기사 또는 문장에서 단어(중국어/영어/일본어/한국어)를 다른 언어에 따라 자동으로 식별하고 분할하여 TTS 처리에 더 적합합니다.
+# 이 코드는 프런트 엔드 텍스트 다국어 혼합 주석 분화, 다국어 혼합 교육 및 다양한 TTS 프로젝트의 추론을 위해 설계되었습니다.
+# ===========================================================================================================
+# (1) 자동 단어 분할: "한국어로 무엇을 읽습니까? 스포츠 씨? 이 컨퍼런스는 4개의 iPhone 15 시리즈 모델을 제공합니다."
+# (2) 수동 참여: "이름이 Saki입니까? ?"
+# 이 처리 결과는 주로 (중국어 = zh, 일본어 = ja, 영어 = en, 한국어 = ko)를 위한 것이며 실제로 혼합 처리를 위해 최대 97개의 언어를 지원합니다.
+# ===========================================================================================================
+
+# ===========================================================================================================
+# 分词功能:将文章或句子里的例如(中/英/日/韩),按不同语言自动识别并拆分,让它更适合TTS处理。
+# 本代码专为各种 TTS 项目的前端文本多语种混合标注区分,多语言混合训练和推理而编写。
+# ===========================================================================================================
+# (1)自动分词:“韩语中的오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型”
+# (2)手动分词:“你的名字叫<ja>佐々木?</ja>吗?”
+# 本处理结果主要针对(中文=zh , 日文=ja , 英文=en , 韩语=ko), 实际上可支持多达 97 种不同的语言混合处理。
+# ===========================================================================================================
+
+
+# 手动分词标签规范:<语言标签>文本内容
+# 수동 단어 분할 태그 사양: <언어 태그> 텍스트 내용
+# Manual word segmentation tag specification: <language tag>text content
+# 手動分詞タグ仕様:<言語タグ>テキスト内容
+# ===========================================================================================================
+# For manual word segmentation, labels need to appear in pairs, such as:
+# 如需手动分词,标签需要成对出现,例如:“<ja>佐々木</ja>” 或者 “<ja>佐々木<ja>”
+# 错误示范:“你的名字叫<ja>佐々木。” 此句子中出现的单个标签将被忽略,不会处理。
+# Error demonstration: "Your name is <ja>佐々木。" Single tags that appear in this sentence will be ignored and will not be processed.
+# ===========================================================================================================
+
+
+# ===========================================================================================================
+# 语音合成标记语言 SSML , 这里只支持它的标签(非 XML)Speech Synthesis Markup Language SSML, only its tags are supported here (not XML)
+# 想支持更多的 SSML 标签?欢迎 PR! Want to support more SSML tags? PRs are welcome!
+# 说明:除了中文以外,它也可改造成支持多语种 SSML ,不仅仅是中文。
+# Note: In addition to Chinese, it can also be modified to support multi-language SSML, not just Chinese.
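In practice the whole file boils down to a small API, defined further down: construct LangSegment, optionally narrow Langfilters, and call getTexts to get per-language spans. A minimal usage sketch; the printed splits are indicative only, since actual results depend on the bundled py3langid model:

    segmenter = LangSegment()
    segmenter.setfilters(["zh", "ja", "ko", "en"])
    for span in segmenter.getTexts("你的名字叫<ja>佐々木</ja>?Hello!"):
        print(span["lang"], span["text"])
    # indicatively: zh 你的名字叫 / ja 佐々木? / en Hello!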
+# =========================================================================================================== +# 中文实现:Chinese implementation: +# 【SSML】=中文大写数字读法(单字) +# 【SSML】=数字转成中文电话号码大写汉字(单字) +# 【SSML】=按金额发音。 +# 【SSML】=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。 +# =========================================================================================================== +class LangSSML: + + def __init__(self): + # 纯数字 + self._zh_numerals_number = { + "0": "零", + "1": "一", + "2": "二", + "3": "三", + "4": "四", + "5": "五", + "6": "六", + "7": "七", + "8": "八", + "9": "九", + } + + # 将2024/8/24, 2024-08, 08-24, 24 标准化“年月日” + # Standardize 2024/8/24, 2024-08, 08-24, 24 to "year-month-day" + def _format_chinese_data(self, date_str: str): + # 处理日期格式 + if date_str is None or date_str.strip() == "": + return "" + date_str = re.sub(r"[\/\._|年|月]", "-", date_str) + date_str = re.sub(r"日", r"", date_str) + date_arrs = date_str.split(" ") + if len(date_arrs) == 1 and ":" in date_arrs[0]: + time_str = date_arrs[0] + date_arrs = [] + else: + time_str = date_arrs[1] if len(date_arrs) >= 2 else "" + + def nonZero(num, cn, func=None): + if func is not None: + num = func(num) + return f"{num}{cn}" if num is not None and num != "" and num != "0" else "" + + f_number = self.to_chinese_number + f_currency = self.to_chinese_currency + # year, month, day + year_month_day = "" + if len(date_arrs) > 0: + year, month, day = "", "", "" + parts = date_arrs[0].split("-") + if len(parts) == 3: # 格式为 YYYY-MM-DD + year, month, day = parts + elif len(parts) == 2: # 格式为 MM-DD 或 YYYY-MM + if len(parts[0]) == 4: # 年-月 + year, month = parts + else: + month, day = parts # 月-日 + elif len(parts[0]) > 0: # 仅有月-日或年 + if len(parts[0]) == 4: + year = parts[0] + else: + day = parts[0] + year, month, day = ( + nonZero(year, "年", f_number), + nonZero(month, "月", f_currency), + nonZero(day, "日", f_currency), + ) + year_month_day = re.sub(r"([年|月|日])+", r"\1", f"{year}{month}{day}") + # hours, minutes, seconds + time_str = re.sub(r"[\/\.\-:_]", ":", time_str) + time_arrs = time_str.split(":") + hours, minutes, seconds = "", "", "" + if len(time_arrs) == 3: # H/M/S + hours, minutes, seconds = time_arrs + elif len(time_arrs) == 2: # H/M + hours, minutes = time_arrs + elif len(time_arrs[0]) > 0: + hours = f"{time_arrs[0]}点" # H + if len(time_arrs) > 1: + hours, minutes, seconds = ( + nonZero(hours, "点", f_currency), + nonZero(minutes, "分", f_currency), + nonZero(seconds, "秒", f_currency), + ) + hours_minutes_seconds = re.sub( + r"([点|分|秒])+", r"\1", f"{hours}{minutes}{seconds}" + ) + output_date = f"{year_month_day}{hours_minutes_seconds}" + return output_date + + # 【SSML】number=中文大写数字读法(单字) + # Chinese Numbers(single word) + def to_chinese_number(self, num: str): + pattern = r"(\d+)" + zh_numerals = self._zh_numerals_number + arrs = re.split(pattern, num) + output = "" + for item in arrs: + if re.match(pattern, item): + output += "".join( + zh_numerals[digit] if digit in zh_numerals else "" + for digit in str(item) + ) + else: + output += item + output = output.replace(".", "点") + return output + + # 【SSML】telephone=数字转成中文电话号码大写汉字(单字) + # Convert numbers to Chinese phone numbers in uppercase Chinese characters(single word) + def to_chinese_telephone(self, num: str): + output = self.to_chinese_number(num.replace("+86", "")) # zh +86 + output = output.replace("一", "幺") + return output + + # 【SSML】currency=按金额发音。 + # Digital processing from GPT_SoVITS num.py (thanks) + def to_chinese_currency(self, num: str): + pattern = r"(\d+)" + 
arrs = re.split(pattern, num) + output = "" + for item in arrs: + if re.match(pattern, item): + output += num2str(item) + else: + output += item + output = output.replace(".", "点") + return output + + # 【SSML】date=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。 + def to_chinese_date(self, num: str): + chinese_date = self._format_chinese_data(num) + return chinese_date + + +class LangSegment: + + def __init__(self): + + self.langid = LanguageIdentifier.from_pickled_model(MODEL_FILE, norm_probs=True) + + self._text_cache = None + self._text_lasts = None + self._text_langs = None + self._lang_count = None + self._lang_eos = None + + # 可自定义语言匹配标签:カスタマイズ可能な言語対応タグ:사용자 지정 가능한 언어 일치 태그: + # Customizable language matching tags: These are supported,이 표현들은 모두 지지합니다 + # 你好 , 佐々木 , OK , 오빠 这些写法均支持 + self.SYMBOLS_PATTERN = r"(<([a-zA-Z|-]*)>(.*?)<\/*[a-zA-Z|-]*>)" + + # 语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。 + # 언어 필터 그룹 기능을 사용하면 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다. + # 言語フィルターグループ機能では、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。 + # The language filter group function allows you to specify reserved languages. + # Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like. + # 排名越前,优先级越高,The higher the ranking, the higher the priority,ランキングが上位になるほど、優先度が高くなります。 + + # 系统默认过滤器。System default filter。(ISO 639-1 codes given) + # ---------------------------------------------------------------------------------------------------------------------------------- + # "zh"中文=Chinese ,"en"英语=English ,"ja"日语=Japanese ,"ko"韩语=Korean ,"fr"法语=French ,"vi"越南语=Vietnamese , "ru"俄语=Russian + # "th"泰语=Thai + # ---------------------------------------------------------------------------------------------------------------------------------- + self.DEFAULT_FILTERS = ["zh", "ja", "ko", "en"] + + # 用户可自定义过滤器。User-defined filters + self.Langfilters = self.DEFAULT_FILTERS[:] # 创建副本 + + # 合并文本 + self.isLangMerge = True + + # 试验性支持:您可自定义添加:"fr"法语 , "vi"越南语。Experimental: You can customize to add: "fr" French, "vi" Vietnamese. + # 请使用API启用:self.setfilters(["zh", "en", "ja", "ko", "fr", "vi" , "ru" , "th"]) # 您可自定义添加,如:"fr"法语 , "vi"越南语。 + + # 预览版功能,自动启用或禁用,无需设置 + # Preview feature, automatically enabled or disabled, no settings required + self.EnablePreview = False + + # 除此以外,它支持简写过滤器,只需按不同语种任意组合即可。 + # In addition to that, it supports abbreviation filters, allowing for any combination of different languages. + # 示例:您可以任意指定多种组合,进行过滤 + # Example: You can specify any combination to filter + + # 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n + # 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다. + # 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n + # Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n + # Only the common characters between Chinese and Japanese are processed with confidence and priority. 
\n + self.LangPriorityThreshold = 0.89 + + # Langfilters = ["zh"] # 按中文识别 + # Langfilters = ["en"] # 按英文识别 + # Langfilters = ["ja"] # 按日文识别 + # Langfilters = ["ko"] # 按韩文识别 + # Langfilters = ["zh_ja"] # 中日混合识别 + # Langfilters = ["zh_en"] # 中英混合识别 + # Langfilters = ["ja_en"] # 日英混合识别 + # Langfilters = ["zh_ko"] # 中韩混合识别 + # Langfilters = ["ja_ko"] # 日韩混合识别 + # Langfilters = ["en_ko"] # 英韩混合识别 + # Langfilters = ["zh_ja_en"] # 中日英混合识别 + # Langfilters = ["zh_ja_en_ko"] # 中日英韩混合识别 + + # 更多过滤组合,请您随意。。。For more filter combinations, please feel free to...... + # より多くのフィルターの組み合わせ、お気軽に。。。더 많은 필터 조합을 원하시면 자유롭게 해주세요. ..... + + # 可选保留:支持中文数字拼音格式,更方便前端实现拼音音素修改和推理,默认关闭 False 。 + # 开启后 True ,括号内的数字拼音格式均保留,并识别输出为:"zh"中文。 + self.keepPinyin = False + + # DEFINITION + self.PARSE_TAG = re.compile(r"(⑥\$*\d+[\d]{6,}⑥)") + + self.LangSSML = LangSSML() + + def _clears(self): + self._text_cache = None + self._text_lasts = None + self._text_langs = None + self._text_waits = None + self._lang_count = None + self._lang_eos = None + + def _is_english_word(self, word): + return bool(re.match(r"^[a-zA-Z]+$", word)) + + def _is_chinese(self, word): + for char in word: + if "\u4e00" <= char <= "\u9fff": + return True + return False + + def _is_japanese_kana(self, word): + pattern = re.compile(r"[\u3040-\u309F\u30A0-\u30FF]+") + matches = pattern.findall(word) + return len(matches) > 0 + + def _insert_english_uppercase(self, word): + modified_text = re.sub(r"(? 0 else None + if symbol is not None: + pass + elif preData is not None and preData["symbol"] is None: + if len(clear_text) == 0: + language = preData["lang"] + elif is_number: + language = preData["lang"] + _, pre_is_number = self._clear_text_number(preData["text"]) + if preData["lang"] == language: + self._statistics(preData["lang"], text) + text = preData["text"] + text + preData["text"] = text + return preData + elif pre_is_number: + text = f'{preData["text"]}{text}' + words.pop() + elif is_number: + priority_language = self._get_filters_string()[:2] + if priority_language in "ja-zh-en-ko-fr-vi": + language = priority_language + data = {"lang": language, "text": text, "score": score, "symbol": symbol} + filters = self.Langfilters + if ( + filters is None + or len(filters) == 0 + or "?" in language + or language in filters + or language in filters[0] + or filters[0] == "*" + or filters[0] in "alls-mixs-autos" + ): + words.append(data) + self._statistics(data["lang"], data["text"]) + return data + + def _addwords(self, words, language, text, score, symbol=None): + if text == "\n": + pass # Keep Line Breaks + elif text is None or len(text.strip()) == 0: + return True + if language is None: + language = "" + language = language.lower() + if language == "en": + text = self._insert_english_uppercase(text) + # text = re.sub(r'[(())]', ',' , text) # Keep it. 
+ text_waits = self._text_waits + ispre_waits = len(text_waits) > 0 + preResult = text_waits.pop() if ispre_waits else None + if preResult is None: + preResult = words[-1] if len(words) > 0 else None + if preResult and ("|" in preResult["lang"]): + pre_lang = preResult["lang"] + if language in pre_lang: + preResult["lang"] = language = language.split("|")[0] + else: + preResult["lang"] = pre_lang.split("|")[0] + if ispre_waits: + preResult = self._saveData( + words, + preResult["lang"], + preResult["text"], + preResult["score"], + preResult["symbol"], + ) + pre_lang = preResult["lang"] if preResult else None + if ("|" in language) and ( + pre_lang and pre_lang not in language and "…" not in language + ): + language = language.split("|")[0] + if "|" in language: + self._text_waits.append( + {"lang": language, "text": text, "score": score, "symbol": symbol} + ) + else: + self._saveData(words, language, text, score, symbol) + return False + + def _get_prev_data(self, words): + data = words[-1] if words and len(words) > 0 else None + if data: + return (data["lang"], data["text"]) + return (None, "") + + def _match_ending(self, input, index): + if input is None or len(input) == 0: + return False, None + input = re.sub(r"\s+", "", input) + if len(input) == 0 or abs(index) > len(input): + return False, None + ending_pattern = re.compile(r'([「」“”‘’"\'::。.!!?.?])') + return ending_pattern.match(input[index]), input[index] + + def _cleans_text(self, cleans_text): + cleans_text = re.sub(r"(.*?)([^\w]+)", r"\1 ", cleans_text) + cleans_text = re.sub(r"(.)\1+", r"\1", cleans_text) + return cleans_text.strip() + + def _mean_processing(self, text: str): + if text is None or (text.strip()) == "": + return None, 0.0 + arrs = self._split_camel_case(text).split(" ") + langs = [] + for t in arrs: + if len(t.strip()) <= 3: + continue + language, score = self.langid.classify(t) + langs.append({"lang": language}) + if len(langs) == 0: + return None, 0.0 + return Counter([item["lang"] for item in langs]).most_common(1)[0][0], 1.0 + + def _lang_classify(self, cleans_text): + language, score = self.langid.classify(cleans_text) + # fix: Huggingface is np.float32 + if ( + score is not None + and isinstance(score, np.generic) + and hasattr(score, "item") + ): + score = score.item() + score = round(score, 3) + return language, score + + def _get_filters_string(self): + filters = self.Langfilters + return "-".join(filters).lower().strip() if filters is not None else "" + + def _parse_language(self, words, segment): + LANG_JA = "ja" + LANG_ZH = "zh" + LANG_ZH_JA = f"{LANG_ZH}|{LANG_JA}" + LANG_JA_ZH = f"{LANG_JA}|{LANG_ZH}" + language = LANG_ZH + regex_pattern = re.compile(r"([^\w\s]+)") + lines = regex_pattern.split(segment) + lines_max = len(lines) + LANG_EOS = self._lang_eos + for index, text in enumerate(lines): + if len(text) == 0: + continue + EOS = index >= (lines_max - 1) + nextId = index + 1 + nextText = lines[nextId] if not EOS else "" + nextPunc = ( + len(re.sub(regex_pattern, "", re.sub(r"\n+", "", nextText)).strip()) + == 0 + ) + textPunc = ( + len(re.sub(regex_pattern, "", re.sub(r"\n+", "", text)).strip()) == 0 + ) + if not EOS and ( + textPunc or (len(nextText.strip()) >= 0 and nextPunc) + ): + lines[nextId] = f"{text}{nextText}" + continue + number_tags = re.compile(r"(⑥\d{6,}⑥)") + cleans_text = re.sub(number_tags, "", text) + cleans_text = re.sub(r"\d+", "", cleans_text) + cleans_text = self._cleans_text(cleans_text) + # fix:Langid's recognition of short sentences is inaccurate, and it is spliced 
longer. + if not EOS and len(cleans_text) <= 2: + lines[nextId] = f"{text}{nextText}" + continue + language, score = self._lang_classify(cleans_text) + prev_language, prev_text = self._get_prev_data(words) + if language != LANG_ZH and all( + "\u4e00" <= c <= "\u9fff" for c in re.sub(r"\s", "", cleans_text) + ): + language, score = LANG_ZH, 1 + if len(cleans_text) <= 5 and self._is_chinese(cleans_text): + filters_string = self._get_filters_string() + if score < self.LangPriorityThreshold and len(filters_string) > 0: + index_ja, index_zh = filters_string.find( + LANG_JA + ), filters_string.find(LANG_ZH) + if index_ja != -1 and index_ja < index_zh: + language = LANG_JA + elif index_zh != -1 and index_zh < index_ja: + language = LANG_ZH + if self._is_japanese_kana(cleans_text): + language = LANG_JA + elif len(cleans_text) > 2 and score > 0.90: + pass + elif EOS and LANG_EOS: + language = LANG_ZH if len(cleans_text) <= 1 else language + else: + LANG_UNKNOWN = ( + LANG_ZH_JA + if language == LANG_ZH + or (len(cleans_text) <= 2 and prev_language == LANG_ZH) + else LANG_JA_ZH + ) + match_end, match_char = self._match_ending(text, -1) + referen = ( + prev_language in LANG_UNKNOWN or LANG_UNKNOWN in prev_language + if prev_language + else False + ) + if match_char in "。.": + language = ( + prev_language if referen and len(words) > 0 else language + ) + else: + language = f"{LANG_UNKNOWN}|…" + text, *_ = re.subn(number_tags, self._restore_number, text) + self._addwords(words, language, text, score) + + # ---------------------------------------------------------- + # 【SSML】中文数字处理:Chinese Number Processing (SSML support) + # 这里默认都是中文,用于处理 SSML 中文标签。当然可以支持任意语言,例如: + # The default here is Chinese, which is used to process SSML Chinese tags. Of course, any language can be supported, for example: + # 中文电话号码:1234567 + # 中文数字号码:1234567 + def _process_symbol_SSML(self, words, data): + tag, match = data + language = SSML = match[1] + text = match[2] + score = 1.0 + if SSML == "telephone": + # 中文-电话号码 + language = "zh" + text = self.LangSSML.to_chinese_telephone(text) + elif SSML == "number": + # 中文-数字读法 + language = "zh" + text = self.LangSSML.to_chinese_number(text) + elif SSML == "currency": + # 中文-按金额发音 + language = "zh" + text = self.LangSSML.to_chinese_currency(text) + elif SSML == "date": + # 中文-按金额发音 + language = "zh" + text = self.LangSSML.to_chinese_date(text) + self._addwords(words, language, text, score, SSML) + + # ---------------------------------------------------------- + def _restore_number(self, matche): + value = matche.group(0) + text_cache = self._text_cache + if value in text_cache: + process, data = text_cache[value] + tag, match = data + value = match + return value + + def _pattern_symbols(self, item, text): + if text is None: + return text + tag, pattern, process = item + matches = pattern.findall(text) + if len(matches) == 1 and "".join(matches[0]) == text: + return text + for i, match in enumerate(matches): + key = f"⑥{tag}{i:06d}⑥" + text = re.sub(pattern, key, text, count=1) + self._text_cache[key] = (process, (tag, match)) + return text + + def _process_symbol(self, words, data): + tag, match = data + language = match[1] + text = match[2] + score = 1.0 + filters = self._get_filters_string() + if language not in filters: + self._process_symbol_SSML(words, data) + else: + self._addwords(words, language, text, score, True) + + def _process_english(self, words, data): + tag, match = data + text = match[0] + filters = self._get_filters_string() + priority_language = filters[:2] + # 
Preview feature, other language segmentation processing + enablePreview = self.EnablePreview + if enablePreview: + # Experimental: Other language support + regex_pattern = re.compile(r"(.*?[。.??!!]+[\n]{,1})") + lines = regex_pattern.split(text) + for index, text in enumerate(lines): + if len(text.strip()) == 0: + continue + cleans_text = self._cleans_text(text) + language, score = self._lang_classify(cleans_text) + if language not in filters: + language, score = self._mean_processing(cleans_text) + if language is None or score <= 0.0: + continue + elif language in filters: + pass # pass + elif score >= 0.95: + continue # High score, but not in the filter, excluded. + elif score <= 0.15 and filters[:2] == "fr": + language = priority_language + else: + language = "en" + self._addwords(words, language, text, score) + else: + # Default is English + language, score = "en", 1.0 + self._addwords(words, language, text, score) + + def _process_Russian(self, words, data): + tag, match = data + text = match[0] + language = "ru" + score = 1.0 + self._addwords(words, language, text, score) + + def _process_Thai(self, words, data): + tag, match = data + text = match[0] + language = "th" + score = 1.0 + self._addwords(words, language, text, score) + + def _process_korean(self, words, data): + tag, match = data + text = match[0] + language = "ko" + score = 1.0 + self._addwords(words, language, text, score) + + def _process_quotes(self, words, data): + tag, match = data + text = "".join(match) + childs = self.PARSE_TAG.findall(text) + if len(childs) > 0: + self._process_tags(words, text, False) + else: + cleans_text = self._cleans_text(match[1]) + if len(cleans_text) <= 5: + self._parse_language(words, text) + else: + language, score = self._lang_classify(cleans_text) + self._addwords(words, language, text, score) + + def _process_pinyin(self, words, data): + tag, match = data + text = match + language = "zh" + score = 1.0 + self._addwords(words, language, text, score) + + def _process_number(self, words, data): # "$0" process only + """ + Numbers alone cannot accurately identify language. + Because numbers are universal in all languages. + So it won't be executed here, just for testing. 
+ """ + tag, match = data + language = words[0]["lang"] if len(words) > 0 else "zh" + text = match + score = 0.0 + self._addwords(words, language, text, score) + + def _process_tags(self, words, text, root_tag): + text_cache = self._text_cache + segments = re.split(self.PARSE_TAG, text) + segments_len = len(segments) - 1 + for index, text in enumerate(segments): + if root_tag: + self._lang_eos = index >= segments_len + if self.PARSE_TAG.match(text): + process, data = text_cache[text] + if process: + process(words, data) + else: + self._parse_language(words, text) + return words + + def _merge_results(self, words): + new_word = [] + for index, cur_data in enumerate(words): + if "symbol" in cur_data: + del cur_data["symbol"] + if index == 0: + new_word.append(cur_data) + else: + pre_data = new_word[-1] + if cur_data["lang"] == pre_data["lang"]: + pre_data["text"] = f'{pre_data["text"]}{cur_data["text"]}' + else: + new_word.append(cur_data) + return new_word + + def _parse_symbols(self, text): + TAG_NUM = "00" # "00" => default channels , "$0" => testing channel + TAG_S1, TAG_S2, TAG_P1, TAG_P2, TAG_EN, TAG_KO, TAG_RU, TAG_TH = ( + "$1", + "$2", + "$3", + "$4", + "$5", + "$6", + "$7", + "$8", + ) + TAG_BASE = re.compile(r'(([【《((“‘"\']*[LANGUAGE]+[\W\s]*)+)') + # Get custom language filter + filters = self.Langfilters + filters = filters if filters is not None else "" + # ======================================================================================================= + # Experimental: Other language support.Thử nghiệm: Hỗ trợ ngôn ngữ khác.Expérimental : prise en charge d’autres langues. + # 相关语言字符如有缺失,熟悉相关语言的朋友,可以提交把缺失的发音符号补全。 + # If relevant language characters are missing, friends who are familiar with the relevant languages can submit a submission to complete the missing pronunciation symbols. + # S'il manque des caractères linguistiques pertinents, les amis qui connaissent les langues concernées peuvent soumettre une soumission pour compléter les symboles de prononciation manquants. + # Nếu thiếu ký tự ngôn ngữ liên quan, những người bạn quen thuộc với ngôn ngữ liên quan có thể gửi bài để hoàn thành các ký hiệu phát âm còn thiếu. + # ------------------------------------------------------------------------------------------------------- + # Preview feature, other language support + enablePreview = self.EnablePreview + if "fr" in filters or "vi" in filters: + enablePreview = True + self.EnablePreview = enablePreview + # 实验性:法语字符支持。Prise en charge des caractères français + RE_FR = "" if not enablePreview else "àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ" + # 实验性:越南语字符支持。Hỗ trợ ký tự tiếng Việt + RE_VI = ( + "" + if not enablePreview + else "đơưăáàảãạắằẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựôâêơưỷỹ" + ) + # ------------------------------------------------------------------------------------------------------- + # Basic options: + process_list = [ + ( + TAG_S1, + re.compile(self.SYMBOLS_PATTERN), + self._process_symbol, + ), # Symbol Tag + ( + TAG_KO, + re.compile(re.sub(r"LANGUAGE", "\uac00-\ud7a3", TAG_BASE.pattern)), + self._process_korean, + ), # Korean words + ( + TAG_TH, + re.compile(re.sub(r"LANGUAGE", "\u0e00-\u0e7f", TAG_BASE.pattern)), + self._process_Thai, + ), # Thai words support. + ( + TAG_RU, + re.compile(re.sub(r"LANGUAGE", "А-Яа-яЁё", TAG_BASE.pattern)), + self._process_Russian, + ), # Russian words support. + ( + TAG_NUM, + re.compile(r"(\W*\d+\W+\d*\W*\d*)"), + self._process_number, + ), # Number words, Universal in all languages, Ignore it. 
+ ( + TAG_EN, + re.compile( + re.sub(r"LANGUAGE", f"a-zA-Z{RE_FR}{RE_VI}", TAG_BASE.pattern) + ), + self._process_english, + ), # English words + Other language support. + ( + TAG_P1, + re.compile(r'(["\'])(.*?)(\1)'), + self._process_quotes, + ), # Regular quotes + ( + TAG_P2, + re.compile( + r"([\n]*[【《((“‘])([^【《((“‘’”))》】]{3,})([’”))》】][\W\s]*[\n]{,1})" + ), + self._process_quotes, + ), # Special quotes, There are left and right. + ] + # Extended options: Default False + if self.keepPinyin: + process_list.insert( + 1, + ( + TAG_S2, + re.compile(r"([\(({](?:\s*\w*\d\w*\s*)+[})\)])"), + self._process_pinyin, + ), # Chinese Pinyin Tag. + ) + # ------------------------------------------------------------------------------------------------------- + words = [] + lines = re.findall(r".*\n*", re.sub(self.PARSE_TAG, "", text)) + for index, text in enumerate(lines): + if len(text.strip()) == 0: + continue + self._lang_eos = False + self._text_cache = {} + for item in process_list: + text = self._pattern_symbols(item, text) + cur_word = self._process_tags([], text, True) + if len(cur_word) == 0: + continue + cur_data = cur_word[0] if len(cur_word) > 0 else None + pre_data = words[-1] if len(words) > 0 else None + if ( + cur_data + and pre_data + and cur_data["lang"] == pre_data["lang"] + and not cur_data["symbol"] + and pre_data["symbol"] + ): + cur_data["text"] = f'{pre_data["text"]}{cur_data["text"]}' + words.pop() + words += cur_word + if self.isLangMerge: + words = self._merge_results(words) + lang_count = self._lang_count + if lang_count and len(lang_count) > 0: + lang_count = dict( + sorted(lang_count.items(), key=lambda x: x[1], reverse=True) + ) + lang_count = list(lang_count.items()) + self._lang_count = lang_count + return words + + def setfilters(self, filters): + # 当过滤器更改时,清除缓存 + # 필터가 변경되면 캐시를 지웁니다. 
+ # フィルタが変更されると、キャッシュがクリアされます + # When the filter changes, clear the cache + if self.Langfilters != filters: + self._clears() + self.Langfilters = filters + + def getfilters(self): + return self.Langfilters + + def setPriorityThreshold(self, threshold: float): + self.LangPriorityThreshold = threshold + + def getPriorityThreshold(self): + return self.LangPriorityThreshold + + def getCounts(self): + lang_count = self._lang_count + if lang_count is not None: + return lang_count + text_langs = self._text_langs + if text_langs is None or len(text_langs) == 0: + return [("zh", 0)] + lang_counts = defaultdict(int) + for d in text_langs: + lang_counts[d["lang"]] += ( + int(len(d["text"]) * 2) if d["lang"] == "zh" else len(d["text"]) + ) + lang_counts = dict( + sorted(lang_counts.items(), key=lambda x: x[1], reverse=True) + ) + lang_counts = list(lang_counts.items()) + self._lang_count = lang_counts + return lang_counts + + def getTexts(self, text: str): + if text is None or len(text.strip()) == 0: + self._clears() + return [] + # lasts + text_langs = self._text_langs + if self._text_lasts == text and text_langs is not None: + return text_langs + # parse + self._text_waits = [] + self._lang_count = None + self._text_lasts = text + text = self._parse_symbols(text) + self._text_langs = text + return text + + def classify(self, text: str): + return self.getTexts(text) + + +default = [ + "af", + "am", + "an", + "ar", + "as", + "az", + "be", + "bg", + "bn", + "br", + "bs", + "ca", + "cs", + "cy", + "da", + "de", + "dz", + "el", + "en", + "eo", + "es", + "et", + "eu", + "fa", + "fi", + "fo", + "fr", + "ga", + "gl", + "gu", + "he", + "hi", + "hr", + "ht", + "hu", + "hy", + "id", + "is", + "it", + "ja", + "jv", + "ka", + "kk", + "km", + "kn", + "ko", + "ku", + "ky", + "la", + "lb", + "lo", + "lt", + "lv", + "mg", + "mk", + "ml", + "mn", + "mr", + "ms", + "mt", + "nb", + "ne", + "nl", + "nn", + "no", + "oc", + "or", + "pa", + "pl", + "ps", + "pt", + "qu", + "ro", + "ru", + "rw", + "se", + "si", + "sk", + "sl", + "sq", + "sr", + "sv", + "sw", + "ta", + "te", + "th", + "tl", + "tr", + "ug", + "uk", + "ur", + "vi", + "vo", + "wa", + "xh", + "zh", + "zu", +] \ No newline at end of file diff --git a/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py b/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py new file mode 100644 index 00000000..5a133f1b --- /dev/null +++ b/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py @@ -0,0 +1,690 @@ +import logging +import os +import re +import textwrap + +import pypinyin +from hangul_romanize import Transliter +from hangul_romanize.rule import academic +from num2words import num2words +from spacy.lang.ar import Arabic +from spacy.lang.en import English +from spacy.lang.es import Spanish +from spacy.lang.ja import Japanese +from spacy.lang.zh import Chinese +from tokenizers import Tokenizer +from .zh_num2words import TextNorm + + +logger = logging.getLogger(__name__) + + +SUPPORT_LANGUAGES = { + "en": 259, + "de": 260, + "fr": 262, + "es": 284, + "it": 285, + "pt": 286, + "pl": 294, + "tr": 295, + "ru": 267, + "cs": 293, + "nl": 297, + "ar": 5022, + "zh": 5023, + "ja": 5412, + "hu": 5753, + "ko": 6152, + "hi": 6680, +} + +structure_pattern = re.compile(r"\[.*?\]") + + +# copy from https://github.com/coqui-ai/TTS/blob/dbf1a08a0d4e47fdad6172e433eeb34bc6b13b4e/TTS/tts/layers/xtts/tokenizer.py +def get_spacy_lang(lang): + if lang == "zh": + return Chinese() + elif lang == "ja": + return Japanese() + elif lang == "ar": + return 
Arabic() + elif lang == "es": + return Spanish() + else: + # For most languages, Enlish does the job + return English() + + +def split_sentence(text, lang, text_split_length=250): + """Preprocess the input text""" + text_splits = [] + if text_split_length is not None and len(text) >= text_split_length: + text_splits.append("") + nlp = get_spacy_lang(lang) + nlp.add_pipe("sentencizer") + doc = nlp(text) + for sentence in doc.sents: + if len(text_splits[-1]) + len(str(sentence)) <= text_split_length: + # if the last sentence + the current sentence is less than the text_split_length + # then add the current sentence to the last sentence + text_splits[-1] += " " + str(sentence) + text_splits[-1] = text_splits[-1].lstrip() + elif len(str(sentence)) > text_split_length: + # if the current sentence is greater than the text_split_length + for line in textwrap.wrap( + str(sentence), + width=text_split_length, + drop_whitespace=True, + break_on_hyphens=False, + tabsize=1, + ): + text_splits.append(str(line)) + else: + text_splits.append(str(sentence)) + + if len(text_splits) > 1: + if text_splits[0] == "": + del text_splits[0] + else: + text_splits = [text.lstrip()] + + return text_splits + + +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = { + "en": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] + ], + "es": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("sra", "señora"), + ("sr", "señor"), + ("dr", "doctor"), + ("dra", "doctora"), + ("st", "santo"), + ("co", "compañía"), + ("jr", "junior"), + ("ltd", "limitada"), + ] + ], + "fr": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mme", "madame"), + ("mr", "monsieur"), + ("dr", "docteur"), + ("st", "saint"), + ("co", "compagnie"), + ("jr", "junior"), + ("ltd", "limitée"), + ] + ], + "de": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("fr", "frau"), + ("dr", "doktor"), + ("st", "sankt"), + ("co", "firma"), + ("jr", "junior"), + ] + ], + "pt": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("sra", "senhora"), + ("sr", "senhor"), + ("dr", "doutor"), + ("dra", "doutora"), + ("st", "santo"), + ("co", "companhia"), + ("jr", "júnior"), + ("ltd", "limitada"), + ] + ], + "it": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # ("sig.ra", "signora"), + ("sig", "signore"), + ("dr", "dottore"), + ("st", "santo"), + ("co", "compagnia"), + ("jr", "junior"), + ("ltd", "limitata"), + ] + ], + "pl": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("p", "pani"), + ("m", "pan"), + ("dr", "doktor"), + ("sw", "święty"), + ("jr", "junior"), + ] + ], + "ar": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # There are not many common abbreviations in Arabic as in English. + ] + ], + "zh": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts. + ] + ], + "cs": [ + (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("dr", "doktor"), # doctor + ("ing", "inženýr"), # engineer + ("p", "pan"), # Could also map to pani for woman but no easy way to do it + # Other abbreviations would be specialized and not as common. + ] + ], + "ru": [ + (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("г-жа", "госпожа"), # Mrs. + ("г-н", "господин"), # Mr. + ("д-р", "доктор"), # doctor + # Other abbreviations are less common or specialized. + ] + ], + "nl": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("dhr", "de heer"), # Mr. + ("mevr", "mevrouw"), # Mrs. + ("dr", "dokter"), # doctor + ("jhr", "jonkheer"), # young lord or nobleman + # Dutch uses more abbreviations, but these are the most common ones. + ] + ], + "tr": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("b", "bay"), # Mr. + ("byk", "büyük"), # büyük + ("dr", "doktor"), # doctor + # Add other Turkish abbreviations here if needed. + ] + ], + "hu": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("dr", "doktor"), # doctor + ("b", "bácsi"), # Mr. + ("nőv", "nővér"), # nurse + # Add other Hungarian abbreviations here if needed. + ] + ], + "ko": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # Korean doesn't typically use abbreviations in the same way as Latin-based scripts. + ] + ], +} + + +def expand_abbreviations_multilingual(text, lang="en"): + for regex, replacement in _abbreviations[lang]: + text = re.sub(regex, replacement, text) + return text + + +_symbols_multilingual = { + "en": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " and "), + ("@", " at "), + ("%", " percent "), + ("#", " hash "), + ("$", " dollar "), + ("£", " pound "), + ("°", " degree "), + ] + ], + "es": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " y "), + ("@", " arroba "), + ("%", " por ciento "), + ("#", " numeral "), + ("$", " dolar "), + ("£", " libra "), + ("°", " grados "), + ] + ], + "fr": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " et "), + ("@", " arobase "), + ("%", " pour cent "), + ("#", " dièse "), + ("$", " dollar "), + ("£", " livre "), + ("°", " degrés "), + ] + ], + "de": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " und "), + ("@", " at "), + ("%", " prozent "), + ("#", " raute "), + ("$", " dollar "), + ("£", " pfund "), + ("°", " grad "), + ] + ], + "pt": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " e "), + ("@", " arroba "), + ("%", " por cento "), + ("#", " cardinal "), + ("$", " dólar "), + ("£", " libra "), + ("°", " graus "), + ] + ], + "it": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " e "), + ("@", " chiocciola "), + ("%", " per cento "), + ("#", " cancelletto "), + ("$", " dollaro "), + ("£", " sterlina "), + ("°", " gradi "), + ] + ], + "pl": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " i "), + ("@", " małpa "), + ("%", " procent "), + ("#", " krzyżyk "), + ("$", " dolar "), + ("£", " funt "), + ("°", " stopnie "), + ] + ], + "ar": [ + # Arabic + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " و "), + ("@", " على "), + ("%", " في المئة "), + ("#", " رقم "), + ("$", " دولار "), + ("£", " جنيه "), + ("°", " درجة "), + ] + ], + "zh": [ + # Chinese + (re.compile(r"%s" % re.escape(x[0]), 
re.IGNORECASE), x[1])
+        for x in [
+            ("&", " 和 "),
+            ("@", " 在 "),
+            ("%", " 百分之 "),
+            ("#", " 号 "),
+            ("$", " 美元 "),
+            ("£", " 英镑 "),
+            ("°", " 度 "),
+        ]
+    ],
+    "cs": [
+        # Czech
+        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+        for x in [
+            ("&", " a "),
+            ("@", " na "),
+            ("%", " procento "),
+            ("#", " křížek "),
+            ("$", " dolar "),
+            ("£", " libra "),
+            ("°", " stupně "),
+        ]
+    ],
+    "ru": [
+        # Russian
+        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+        for x in [
+            ("&", " и "),
+            ("@", " собака "),
+            ("%", " процентов "),
+            ("#", " номер "),
+            ("$", " доллар "),
+            ("£", " фунт "),
+            ("°", " градус "),
+        ]
+    ],
+    "nl": [
+        # Dutch
+        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+        for x in [
+            ("&", " en "),
+            ("@", " bij "),
+            ("%", " procent "),
+            ("#", " hekje "),
+            ("$", " dollar "),
+            ("£", " pond "),
+            ("°", " graden "),
+        ]
+    ],
+    "tr": [
+        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+        for x in [
+            ("&", " ve "),
+            ("@", " at "),
+            ("%", " yüzde "),
+            ("#", " diyez "),
+            ("$", " dolar "),
+            ("£", " sterlin "),
+            ("°", " derece "),
+        ]
+    ],
+    "hu": [
+        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+        for x in [
+            ("&", " és "),
+            ("@", " kukac "),
+            ("%", " százalék "),
+            ("#", " kettőskereszt "),
+            ("$", " dollár "),
+            ("£", " font "),
+            ("°", " fok "),
+        ]
+    ],
+    "ko": [
+        # Korean
+        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+        for x in [
+            ("&", " 그리고 "),
+            ("@", " 에 "),
+            ("%", " 퍼센트 "),
+            ("#", " 번호 "),
+            ("$", " 달러 "),
+            ("£", " 파운드 "),
+            ("°", " 도 "),
+        ]
+    ],
+}
+
+
+def expand_symbols_multilingual(text, lang="en"):
+    for regex, replacement in _symbols_multilingual[lang]:
+        text = re.sub(regex, replacement, text)
+        text = text.replace("  ", " ")  # Ensure there are no double spaces
+    return text.strip()
+
+
+_ordinal_re = {
+    "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
+    "es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
+    "fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
+    "de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
+    "pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
+    "it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
+    "pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
+    "ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
+    "cs": re.compile(
+        r"([0-9]+)\.(?=\s|$)"
+    ),  # In Czech, a dot is often used after the number to indicate ordinals.
+    "ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
+    "nl": re.compile(r"([0-9]+)(de|ste|e)"),
+    "tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
+    "hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
+    "ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
+}
+_number_re = re.compile(r"[0-9]+")
+_currency_re = {
+    "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
+    "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
+    "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
+}
+
+_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
+_dot_number_re = re.compile(r"\b\d{1,3}(\.\d{3})*(\,\d+)?\b")
+_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
+
+
+def _remove_commas(m):
+    text = m.group(0)
+    if "," in text:
+        text = text.replace(",", "")
+    return text
+
+
+def _remove_dots(m):
+    text = m.group(0)
+    if "."
in text: + text = text.replace(".", "") + return text + + +def _expand_decimal_point(m, lang="en"): + amount = m.group(1).replace(",", ".") + return num2words(float(amount), lang=lang if lang != "cs" else "cz") + + +def _expand_currency(m, lang="en", currency="USD"): + amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", ".")))) + full_amount = num2words( + amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz" + ) + + and_equivalents = { + "en": ", ", + "es": " con ", + "fr": " et ", + "de": " und ", + "pt": " e ", + "it": " e ", + "pl": ", ", + "cs": ", ", + "ru": ", ", + "nl": ", ", + "ar": ", ", + "tr": ", ", + "hu": ", ", + "ko": ", ", + } + + if amount.is_integer(): + last_and = full_amount.rfind(and_equivalents[lang]) + if last_and != -1: + full_amount = full_amount[:last_and] + + return full_amount + + +def _expand_ordinal(m, lang="en"): + return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz") + + +def _expand_number(m, lang="en"): + return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz") + + +def expand_numbers_multilingual(text, lang="en"): + if lang == "zh": + text = TextNorm()(text) + else: + if lang in ["en", "ru"]: + text = re.sub(_comma_number_re, _remove_commas, text) + else: + text = re.sub(_dot_number_re, _remove_dots, text) + try: + text = re.sub( + _currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text + ) + text = re.sub( + _currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text + ) + text = re.sub( + _currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text + ) + except Exception as e: + logger.warning(f"Failed to expand currency: {e}") + if lang != "tr": + text = re.sub( + _decimal_number_re, lambda m: _expand_decimal_point(m, lang), text + ) + text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text) + text = re.sub(_number_re, lambda m: _expand_number(m, lang), text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def multilingual_cleaners(text, lang): + text = text.replace('"', "") + if lang == "tr": + text = text.replace("İ", "i") + text = text.replace("Ö", "ö") + text = text.replace("Ü", "ü") + text = lowercase(text) + try: + text = expand_numbers_multilingual(text, lang) + except Exception as e: + logger.warning(f"Failed to expand numbers: {e}") + try: + text = expand_abbreviations_multilingual(text, lang) + except Exception as e: + logger.warning(f"Failed to expand abbreviations: {e}") + try: + text = expand_symbols_multilingual(text, lang=lang) + except Exception as e: + logger.warning(f"Failed to expand symbols: {e}") + text = collapse_whitespace(text) + return text + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def chinese_transliterate(text): + return "".join( + [ + p[0] + for p in pypinyin.pinyin( + text, + style=pypinyin.Style.TONE3, + heteronym=False, + neutral_tone_with_five=True, + ) + ] + ) + + +def japanese_cleaners(text, katsu): + text = katsu.romaji(text) + text = lowercase(text) + return text + + +def korean_transliterate(text): + r = Transliter(academic) + return r.translit(text) + + +DEFAULT_VOCAB_FILE = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "vocab.json" +) + + +class VoiceBpeTokenizer: + def __init__(self, vocab_file=DEFAULT_VOCAB_FILE): + 
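+        # Editorial note (sketch, not part of the original patch): for "ja",
+        # preprocess_text() below dereferences self.katsu, but nothing in this
+        # class ever assigns it, so the Japanese path would raise AttributeError.
+        # Assuming the coqui-style romanizer is intended, a lazy init would be:
+        #     import cutlet  # hypothetical optional dependency here
+        #     self.katsu = cutlet.Cutlet()
+        # Typical use of this tokenizer (sketch):
+        #     tok = VoiceBpeTokenizer()
+        #     ids = tok.encode("Hello world!", lang="en")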
self.tokenizer = Tokenizer.from_file(vocab_file) + + def preprocess_text(self, txt, lang): + if lang in { + "ar", + "cs", + "de", + "en", + "es", + "fr", + "hu", + "it", + "nl", + "pl", + "pt", + "ru", + "tr", + "zh", + "ko", + }: + txt = multilingual_cleaners(txt, lang) + if lang == "zh": + txt = chinese_transliterate(txt) + if lang == "ko": + txt = korean_transliterate(txt) + elif lang == "ja": + txt = japanese_cleaners(txt, self.katsu) + elif lang == "hi": + # @manmay will implement this + txt = basic_cleaners(txt) + else: + raise NotImplementedError(f"Language '{lang}' is not supported.") + return txt + + def encode(self, txt, lang): + lang = lang.split("-")[0] # remove the region + txt = self.preprocess_text(txt, lang) + lang = "zh-cn" if lang == "zh" else lang + txt = f"[{lang}]{txt}" + txt = txt.replace(" ", "[SPACE]") + return self.tokenizer.encode(txt).ids diff --git a/diffsynth_engine/models/ace_step/lyric_tokenizer/num2str.py b/diffsynth_engine/models/ace_step/lyric_tokenizer/num2str.py new file mode 100644 index 00000000..efe96f40 --- /dev/null +++ b/diffsynth_engine/models/ace_step/lyric_tokenizer/num2str.py @@ -0,0 +1,325 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Digital processing from GPT_SoVITS num.py (thanks) +""" +Rules to verbalize numbers into Chinese characters. 
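+
+Editorial usage sketch (based on the helpers defined below, not part of the
+original patch):
+
+    num2str("109")    # -> "一百零九"
+    num2str("10002")  # -> "一万零二"
+    num2str("3.14")   # -> "三点一四"
+    RE_PERCENTAGE.sub(replace_percentage, "-5%")  # -> "负百分之五"
+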
+https://zh.wikipedia.org/wiki/中文数字#現代中文 +""" + +import re +from collections import OrderedDict +from typing import List + +DIGITS = {str(i): tran for i, tran in enumerate("零一二三四五六七八九")} +UNITS = OrderedDict( + { + 1: "十", + 2: "百", + 3: "千", + 4: "万", + 8: "亿", + } +) + +COM_QUANTIFIERS = "(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)" + +# 分数表达式 +RE_FRAC = re.compile(r"(-?)(\d+)/(\d+)") + + +def replace_frac(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + sign = match.group(1) + nominator = match.group(2) + denominator = match.group(3) + sign: str = "负" if sign else "" + nominator: str = num2str(nominator) + denominator: str = num2str(denominator) + result = f"{sign}{denominator}分之{nominator}" + return result + + +# 百分数表达式 +RE_PERCENTAGE = re.compile(r"(-?)(\d+(\.\d+)?)%") + + +def replace_percentage(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + sign = match.group(1) + percent = match.group(2) + sign: str = "负" if sign else "" + percent: str = num2str(percent) + result = f"{sign}百分之{percent}" + return result + + +# 整数表达式 +# 带负号的整数 -10 +RE_INTEGER = re.compile(r"(-)" r"(\d+)") + + +def replace_negative_num(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + sign = match.group(1) + number = match.group(2) + sign: str = "负" if sign else "" + number: str = num2str(number) + result = f"{sign}{number}" + return result + + +# 编号-无符号整形 +# 00078 +RE_DEFAULT_NUM = re.compile(r"\d{3}\d*") + + +def replace_default_num(match): + """ + Args: + match (re.Match) + Returns: + str + """ + number = match.group(0) + return verbalize_digit(number, alt_one=True) + + +# 加减乘除 +# RE_ASMD = re.compile( +# r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))') +RE_ASMD = re.compile( + r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))" +) + +asmd_map = {"+": "加", "-": "减", "×": "乘", "÷": "除", "=": "等于"} + + +def replace_asmd(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + result = match.group(1) + asmd_map[match.group(8)] + match.group(9) + return result + + +# 次方专项 +RE_POWER = re.compile(r"[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+") + +power_map = { + "⁰": "0", + "¹": "1", + "²": "2", + "³": "3", + "⁴": "4", + "⁵": "5", + "⁶": "6", + "⁷": "7", + "⁸": "8", + "⁹": "9", + "ˣ": "x", + "ʸ": "y", + "ⁿ": "n", +} + + +def replace_power(match) -> str: + """ + Args: + match (re.Match) + Returns: + str + """ + power_num = "" + for m in match.group(0): + power_num += power_map[m] + result = "的" + power_num + "次方" + return result + + +# 数字表达式 +# 纯小数 +RE_DECIMAL_NUM = re.compile(r"(-?)((\d+)(\.\d+))" r"|(\.(\d+))") +# 正整数 + 量词 +RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" 
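+    # Editorial sketch: this pattern matches an integer, an optional
+    # "approximate" marker (多/余/几/+), then a measure word from
+    # COM_QUANTIFIERS, so replace_positive_quantifier() below can verbalize
+    # e.g. "30多岁" -> "三十多岁".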
+ COM_QUANTIFIERS)
+RE_NUMBER = re.compile(r"(-?)((\d+)(\.\d+)?)" r"|(\.(\d+))")
+
+
+def replace_positive_quantifier(match) -> str:
+    """
+    Args:
+        match (re.Match)
+    Returns:
+        str
+    """
+    number = match.group(1)
+    match_2 = match.group(2)
+    if match_2 == "+":
+        match_2 = "多"
+    match_2: str = match_2 if match_2 else ""
+    quantifiers: str = match.group(3)
+    number: str = num2str(number)
+    result = f"{number}{match_2}{quantifiers}"
+    return result
+
+
+def replace_number(match) -> str:
+    """
+    Args:
+        match (re.Match)
+    Returns:
+        str
+    """
+    sign = match.group(1)
+    number = match.group(2)
+    pure_decimal = match.group(5)
+    if pure_decimal:
+        result = num2str(pure_decimal)
+    else:
+        sign: str = "负" if sign else ""
+        number: str = num2str(number)
+        result = f"{sign}{number}"
+    return result
+
+
+# Range expressions such as 12-23 or 12~23
+# match.group(1) and match.group(6) reuse the number pattern of RE_NUMBER
+
+RE_RANGE = re.compile(
+    r"""
+    (?<![\d\+\-\×÷=])      # no digit or operator immediately before the range
+    ((-?)((\d+)(\.\d+)?))  # first number of the range (optional sign, int or decimal)
+    [-~]                   # range separator
+    ((-?)((\d+)(\.\d+)?))  # second number of the range (optional sign, int or decimal)
+    (?![\d\+\-\×÷=])       # no digit or operator immediately after the range
+    """,
+    re.VERBOSE,
+)
+
+
+def replace_range(match) -> str:
+    """
+    Args:
+        match (re.Match)
+    Returns:
+        str
+    """
+    first, second = match.group(1), match.group(6)
+    first = RE_NUMBER.sub(replace_number, first)
+    second = RE_NUMBER.sub(replace_number, second)
+    result = f"{first}到{second}"
+    return result
+
+
+# Unit ranges written with '~' (e.g. 5%~10%), where '~' is read as 至 ("to")
+RE_TO_RANGE = re.compile(
+    r"((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)"
+)
+
+
+def replace_to_range(match) -> str:
+    """
+    Args:
+        match (re.Match)
+    Returns:
+        str
+    """
+    result = match.group(0).replace("~", "至")
+    return result
+
+
+def _get_value(value_string: str, use_zero: bool = True) -> List[str]:
+    stripped = value_string.lstrip("0")
+    if len(stripped) == 0:
+        return []
+    elif len(stripped) == 1:
+        if use_zero and len(stripped) < len(value_string):
+            return [DIGITS["0"], DIGITS[stripped]]
+        else:
+            return [DIGITS[stripped]]
+    else:
+        largest_unit = next(
+            power for power in reversed(UNITS.keys()) if power < len(stripped)
+        )
+        first_part = value_string[:-largest_unit]
+        second_part = value_string[-largest_unit:]
+        return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part)
+
+
+def verbalize_cardinal(value_string: str) -> str:
+    if not value_string:
+        return ""
+
+    # 000 -> '零' , 0 -> '零'
+    value_string = value_string.lstrip("0")
+    if len(value_string) == 0:
+        return DIGITS["0"]
+
+    result_symbols = _get_value(value_string)
+    # verbalized number starting with '一十*' is abbreviated as `十*`
+    if (
+        len(result_symbols) >= 2
+        and result_symbols[0] == DIGITS["1"]
+        and result_symbols[1] == UNITS[1]
+    ):
+        result_symbols = result_symbols[1:]
+    return "".join(result_symbols)
+
+
+def verbalize_digit(value_string: str, alt_one=False) -> str:
+    result_symbols = [DIGITS[digit] for digit in value_string]
+    result = "".join(result_symbols)
+    if alt_one:
+        result = result.replace("一", "幺")
+    return result
+
+
+def num2str(value_string: str) -> str:
+    integer_decimal = value_string.split(".")
+    if len(integer_decimal) == 1:
+        integer = integer_decimal[0]
+        decimal = ""
+    elif len(integer_decimal) == 2:
+        integer, decimal = integer_decimal
+    else:
+        raise ValueError(
+            f"The value string: '{value_string}' has more than one point in it."
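+            # Editorial sketch: verbalize_digit above reads digits one at a
+            # time, and alt_one=True swaps 一 for 幺 (telephone style), e.g.
+            #     verbalize_digit("110", alt_one=True)  # -> "幺幺零"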
+ ) + + result = verbalize_cardinal(integer) + + decimal = decimal.rstrip("0") + if decimal: + # '.22' is verbalized as '零点二二' + # '3.20' is verbalized as '三点二 + result = result if result else "零" + result += "点" + verbalize_digit(decimal) + return result diff --git a/diffsynth_engine/models/ace_step/lyric_tokenizer/vocab.json b/diffsynth_engine/models/ace_step/lyric_tokenizer/vocab.json new file mode 100644 index 00000000..519ed340 --- /dev/null +++ b/diffsynth_engine/models/ace_step/lyric_tokenizer/vocab.json @@ -0,0 +1,15535 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "special": true, + "content": "[STOP]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 1, + "special": true, + "content": "[UNK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 2, + "special": true, + "content": "[SPACE]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 259, + "special": true, + "content": "[en]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 260, + "special": true, + "content": "[de]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 261, + "special": true, + "content": "[START]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 262, + "special": true, + "content": "[fr]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 284, + "special": true, + "content": "[es]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 285, + "special": true, + "content": "[it]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 286, + "special": true, + "content": "[pt]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 294, + "special": true, + "content": "[pl]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 295, + "special": true, + "content": "[tr]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 267, + "special": true, + "content": "[ru]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 293, + "special": true, + "content": "[cs]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 297, + "special": true, + "content": "[nl]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 5022, + "special": true, + "content": "[ar]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 5023, + "special": true, + "content": "[zh-cn]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 5412, + "special": true, + "content": "[ja]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 5753, + "special": true, + "content": "[hu]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6152, + "special": true, + "content": "[ko]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + 
{ + "id": 6680, + "special": true, + "content": "[hi]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6681, + "special": true, + "content": "[start]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6682, + "special": true, + "content": "[intro]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6683, + "special": true, + "content": "[verse]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6684, + "special": true, + "content": "[chorus]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6685, + "special": true, + "content": "[bridge]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6686, + "special": true, + "content": "[outro]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6687, + "special": true, + "content": "[end]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6688, + "special": true, + "content": "[inst]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6689, + "special": true, + "content": "[solo]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6690, + "special": true, + "content": "[hook]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6691, + "special": true, + "content": "[pre-chorus]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6692, + "special": true, + "content": "[break]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Whitespace" + }, + "post_processor": null, + "decoder": null, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": "[UNK]", + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "vocab": { + "[STOP]": 0, + "[UNK]": 1, + "[SPACE]": 2, + "!": 3, + "'": 4, + "(": 5, + ")": 6, + ",": 7, + "-": 8, + ".": 9, + "/": 10, + ":": 11, + ";": 12, + "?": 13, + "a": 14, + "b": 15, + "c": 16, + "d": 17, + "e": 18, + "f": 19, + "g": 20, + "h": 21, + "i": 22, + "j": 23, + "k": 24, + "l": 25, + "m": 26, + "n": 27, + "o": 28, + "p": 29, + "q": 30, + "r": 31, + "s": 32, + "t": 33, + "u": 34, + "v": 35, + "w": 36, + "x": 37, + "y": 38, + "z": 39, + "th": 40, + "in": 41, + "the": 42, + "an": 43, + "er": 44, + "ou": 45, + "re": 46, + "on": 47, + "at": 48, + "ed": 49, + "en": 50, + "to": 51, + "ing": 52, + "and": 53, + "is": 54, + "as": 55, + "al": 56, + "or": 57, + "of": 58, + "ar": 59, + "it": 60, + "es": 61, + "he": 62, + "st": 63, + "le": 64, + "om": 65, + "se": 66, + "be": 67, + "ad": 68, + "ow": 69, + "ly": 70, + "ch": 71, + "wh": 72, + "that": 73, + "you": 74, + "li": 75, + "ve": 76, + "ac": 77, + "ti": 78, + "ld": 79, + "me": 80, + "was": 81, + "gh": 82, + "id": 83, + "ll": 84, + "wi": 85, + "ent": 86, + "for": 87, + "ay": 88, + "ro": 89, + "ver": 90, + "ic": 91, + "her": 92, + "ke": 93, + "his": 94, + "no": 95, + "ut": 96, + "un": 97, + "ir": 98, + "lo": 99, + "we": 100, + "ri": 101, + "ha": 102, + "with": 103, + "ght": 104, + "out": 105, + "im": 106, + "ion": 107, + "all": 108, + "ab": 109, 
+ "one": 110, + "ne": 111, + "ge": 112, + "ould": 113, + "ter": 114, + "mo": 115, + "had": 116, + "ce": 117, + "she": 118, + "go": 119, + "sh": 120, + "ur": 121, + "am": 122, + "so": 123, + "pe": 124, + "my": 125, + "de": 126, + "are": 127, + "but": 128, + "ome": 129, + "fr": 130, + "ther": 131, + "fe": 132, + "su": 133, + "do": 134, + "con": 135, + "te": 136, + "ain": 137, + "ere": 138, + "po": 139, + "if": 140, + "they": 141, + "us": 142, + "ag": 143, + "tr": 144, + "now": 145, + "oun": 146, + "this": 147, + "have": 148, + "not": 149, + "sa": 150, + "il": 151, + "up": 152, + "thing": 153, + "from": 154, + "ap": 155, + "him": 156, + "ack": 157, + "ation": 158, + "ant": 159, + "our": 160, + "op": 161, + "like": 162, + "ust": 163, + "ess": 164, + "bo": 165, + "ok": 166, + "ul": 167, + "ind": 168, + "ex": 169, + "com": 170, + "some": 171, + "there": 172, + "ers": 173, + "co": 174, + "res": 175, + "man": 176, + "ard": 177, + "pl": 178, + "wor": 179, + "way": 180, + "tion": 181, + "fo": 182, + "ca": 183, + "were": 184, + "by": 185, + "ate": 186, + "pro": 187, + "ted": 188, + "ound": 189, + "own": 190, + "would": 191, + "ts": 192, + "what": 193, + "qu": 194, + "ally": 195, + "ight": 196, + "ck": 197, + "gr": 198, + "when": 199, + "ven": 200, + "can": 201, + "ough": 202, + "ine": 203, + "end": 204, + "per": 205, + "ous": 206, + "od": 207, + "ide": 208, + "know": 209, + "ty": 210, + "very": 211, + "si": 212, + "ak": 213, + "who": 214, + "about": 215, + "ill": 216, + "them": 217, + "est": 218, + "red": 219, + "ye": 220, + "could": 221, + "ong": 222, + "your": 223, + "their": 224, + "em": 225, + "just": 226, + "other": 227, + "into": 228, + "any": 229, + "whi": 230, + "um": 231, + "tw": 232, + "ast": 233, + "der": 234, + "did": 235, + "ie": 236, + "been": 237, + "ace": 238, + "ink": 239, + "ity": 240, + "back": 241, + "ting": 242, + "br": 243, + "more": 244, + "ake": 245, + "pp": 246, + "then": 247, + "sp": 248, + "el": 249, + "use": 250, + "bl": 251, + "said": 252, + "over": 253, + "get": 254, + "ß": 255, + "ä": 256, + "ö": 257, + "ü": 258, + "[en]": 259, + "[de]": 260, + "[START]": 261, + "[fr]": 262, + "œ": 263, + "ï": 264, + "ê": 265, + "â": 266, + "[ru]": 267, + "ÿ": 268, + "è": 269, + "à": 270, + "ë": 271, + "ù": 272, + "î": 273, + "ç": 274, + "æ": 275, + "ô": 276, + "û": 277, + "á": 278, + "é": 279, + "í": 280, + "ó": 281, + "ú": 282, + "ñ": 283, + "[es]": 284, + "[it]": 285, + "[pt]": 286, + "ń": 287, + "ś": 288, + "ę": 289, + "ą": 290, + "ż": 291, + "ć": 292, + "[cs]": 293, + "[pl]": 294, + "[tr]": 295, + "ã": 296, + "[nl]": 297, + "ş": 298, + "ğ": 299, + "ı": 300, + "ò": 301, + "ì": 302, + "¿": 303, + "…": 304, + "i̇": 305, + "õ": 306, + "\"": 307, + "´": 308, + "ø": 309, + "č": 310, + "ō": 311, + "š": 312, + "ž": 313, + "̇": 314, + "ei": 315, + "ich": 316, + "ein": 317, + "au": 318, + "sch": 319, + "und": 320, + "die": 321, + "da": 322, + "den": 323, + "gen": 324, + "zu": 325, + "hr": 326, + "ten": 327, + "mi": 328, + "sie": 329, + "das": 330, + "eine": 331, + "icht": 332, + "ber": 333, + "ach": 334, + "auf": 335, + "lich": 336, + "nicht": 337, + "mm": 338, + "ben": 339, + "war": 340, + "mit": 341, + "sich": 342, + "ig": 343, + "aus": 344, + "ist": 345, + "wie": 346, + "och": 347, + "ung": 348, + "ann": 349, + "ür": 350, + "hn": 351, + "ihr": 352, + "sen": 353, + "tz": 354, + "dem": 355, + "eit": 356, + "hat": 357, + "wir": 358, + "von": 359, + "wei": 360, + "ier": 361, + "ra": 362, + "einen": 363, + "vor": 364, + "als": 365, + "wo": 366, + "rei": 367, + "ste": 368, + "lie": 369, + 
"auch": 370, + "du": 371, + "des": 372, + "ko": 373, + "über": 374, + "bei": 375, + "hen": 376, + "hm": 377, + "lei": 378, + "aber": 379, + "wen": 380, + "hl": 381, + "ger": 382, + "nach": 383, + "ft": 384, + "imm": 385, + "je": 386, + "schen": 387, + "wer": 388, + "ser": 389, + "än": 390, + "sein": 391, + "ol": 392, + "cht": 393, + "für": 394, + "kl": 395, + "ff": 396, + "einem": 397, + "nen": 398, + "ja": 399, + "noch": 400, + "hatte": 401, + "pf": 402, + "hin": 403, + "di": 404, + "chen": 405, + "rü": 406, + "iel": 407, + "sel": 408, + "dass": 409, + "ihn": 410, + "mir": 411, + "schl": 412, + "ön": 413, + "gan": 414, + "gt": 415, + "einer": 416, + "sten": 417, + "mich": 418, + "wenn": 419, + "ell": 420, + "gte": 421, + "mal": 422, + "gel": 423, + "ken": 424, + "nur": 425, + "mmen": 426, + "fü": 427, + "ern": 428, + "ör": 429, + "unter": 430, + "ander": 431, + "dur": 432, + "uch": 433, + "ta": 434, + "men": 435, + "mach": 436, + "doch": 437, + "durch": 438, + "os": 439, + "gl": 440, + "hal": 441, + "ihre": 442, + "wä": 443, + "immer": 444, + "ihm": 445, + "kann": 446, + "ort": 447, + "dann": 448, + "lan": 449, + "tzt": 450, + "oder": 451, + "hren": 452, + "et": 453, + "kön": 454, + "ick": 455, + "fa": 456, + "wieder": 457, + "daß": 458, + "mein": 459, + "fen": 460, + "ganz": 461, + "diese": 462, + "ster": 463, + "dar": 464, + "wa": 465, + "ges": 466, + "na": 467, + "fl": 468, + "igen": 469, + "sche": 470, + "ungen": 471, + "mehr": 472, + "ßen": 473, + "ot": 474, + "kon": 475, + "gew": 476, + "haben": 477, + "geh": 478, + "ät": 479, + "sind": 480, + "dr": 481, + "wel": 482, + "uns": 483, + "vo": 484, + "ma": 485, + "ute": 486, + "schon": 487, + "bes": 488, + "gesch": 489, + "bt": 490, + "che": 491, + "son": 492, + "ob": 493, + "la": 494, + "rück": 495, + "seine": 496, + "kr": 497, + "fre": 498, + "eil": 499, + "zum": 500, + "hier": 501, + "kt": 502, + "ige": 503, + "spr": 504, + "leben": 505, + "bst": 506, + "zeit": 507, + "gro": 508, + "denn": 509, + "ho": 510, + "scha": 511, + "bar": 512, + "alle": 513, + "gegen": 514, + "wür": 515, + "mü": 516, + "ze": 517, + "werden": 518, + "jetzt": 519, + "kommen": 520, + "nie": 521, + "sei": 522, + "heit": 523, + "soll": 524, + "glei": 525, + "meine": 526, + "woll": 527, + "ner": 528, + "habe": 529, + "wur": 530, + "lichen": 531, + "assen": 532, + "nte": 533, + "sehen": 534, + "wird": 535, + "bis": 536, + "gar": 537, + "ien": 538, + "mus": 539, + "uß": 540, + "är": 541, + "stell": 542, + "keit": 543, + "zwei": 544, + "selbst": 545, + "sta": 546, + "pa": 547, + "sagte": 548, + "tet": 549, + "kam": 550, + "ssen": 551, + "viel": 552, + "ug": 553, + "zen": 554, + "hei": 555, + "mann": 556, + "will": 557, + "geb": 558, + "waren": 559, + "ück": 560, + "äch": 561, + "mer": 562, + "ru": 563, + "hau": 564, + "eigen": 565, + "ang": 566, + "weg": 567, + "blick": 568, + "fra": 569, + "alles": 570, + "ka": 571, + "augen": 572, + "fin": 573, + "liche": 574, + "unser": 575, + "dern": 576, + "herr": 577, + "nun": 578, + "vie": 579, + "chte": 580, + "wohl": 581, + "fall": 582, + "ht": 583, + "ün": 584, + "etwas": 585, + "stand": 586, + "äu": 587, + "mö": 588, + "tel": 589, + "rie": 590, + "dich": 591, + "dies": 592, + "hand": 593, + "bin": 594, + "ffen": 595, + "nichts": 596, + "dan": 597, + "hne": 598, + "ihnen": 599, + "esen": 600, + "dieser": 601, + "frau": 602, + "art": 603, + "dir": 604, + "isch": 605, + "erst": 606, + "gleich": 607, + "komm": 608, + "hör": 609, + "ße": 610, + "dig": 611, + "sehr": 612, + "zei": 613, + "sam": 614, + "aum": 615, + "hät": 616, + 
"ingen": 617, + "gut": 618, + "mut": 619, + "cken": 620, + "konnte": 621, + "stimm": 622, + "zur": 623, + "itz": 624, + "weil": 625, + "würde": 626, + "fä": 627, + "können": 628, + "keine": 629, + "fer": 630, + "ischen": 631, + "voll": 632, + "eines": 633, + "setz": 634, + "zie": 635, + "del": 636, + "tete": 637, + "seiner": 638, + "ieren": 639, + "gest": 640, + "zurück": 641, + "wurde": 642, + "schn": 643, + "pr": 644, + "ließ": 645, + "tra": 646, + "mä": 647, + "gend": 648, + "fol": 649, + "ik": 650, + "schla": 651, + "schaft": 652, + "ater": 653, + "weiß": 654, + "seinen": 655, + "lassen": 656, + "lu": 657, + "unden": 658, + "teil": 659, + "neu": 660, + "iert": 661, + "menschen": 662, + "hmen": 663, + "str": 664, + "gi": 665, + "sah": 666, + "ihren": 667, + "eln": 668, + "weiter": 669, + "gehen": 670, + "iger": 671, + "macht": 672, + "tag": 673, + "also": 674, + "halten": 675, + "nis": 676, + "acht": 677, + "geben": 678, + "og": 679, + "nat": 680, + "mar": 681, + "det": 682, + "ohne": 683, + "haus": 684, + "tro": 685, + "ange": 686, + "lau": 687, + "spiel": 688, + "tre": 689, + "schr": 690, + "inn": 691, + "los": 692, + "machen": 693, + "hätte": 694, + "beg": 695, + "wirk": 696, + "alt": 697, + "glich": 698, + "tes": 699, + "richt": 700, + "freund": 701, + "ihrer": 702, + "fel": 703, + "bel": 704, + "sol": 705, + "einmal": 706, + "eben": 707, + "hol": 708, + "hän": 709, + "tern": 710, + "hö": 711, + "schw": 712, + "recht": 713, + "wahr": 714, + "seinem": 715, + "stehen": 716, + "hlen": 717, + "ins": 718, + "ging": 719, + "wollte": 720, + "wissen": 721, + "ungs": 722, + "ald": 723, + "ass": 724, + "jahr": 725, + "mor": 726, + "welt": 727, + "under": 728, + "zusa": 729, + "kopf": 730, + "lang": 731, + "hinter": 732, + "atz": 733, + "stra": 734, + "angen": 735, + "ank": 736, + "ade": 737, + "glau": 738, + "fach": 739, + "hatten": 740, + "fort": 741, + "eicht": 742, + "iff": 743, + "ler": 744, + "mei": 745, + "diesem": 746, + "kein": 747, + "frei": 748, + "führ": 749, + "vom": 750, + "β": 751, + "ai": 752, + "ait": 753, + "que": 754, + "les": 755, + "av": 756, + "ais": 757, + "oi": 758, + "eu": 759, + "lle": 760, + "par": 761, + "ans": 762, + "ment": 763, + "ét": 764, + "une": 765, + "pas": 766, + "qui": 767, + "elle": 768, + "dé": 769, + "pour": 770, + "dans": 771, + "ré": 772, + "tou": 773, + "vous": 774, + "vi": 775, + "ouv": 776, + "mon": 777, + "sur": 778, + "ci": 779, + "plu": 780, + "ère": 781, + "mais": 782, + "ois": 783, + "plus": 784, + "ée": 785, + "aient": 786, + "mp": 787, + "lui": 788, + "ave": 789, + "était": 790, + "ses": 791, + "tout": 792, + "oir": 793, + "avait": 794, + "és": 795, + "mes": 796, + "nous": 797, + "eux": 798, + "bi": 799, + "ons": 800, + "pu": 801, + "ces": 802, + "tu": 803, + "leur": 804, + "don": 805, + "eur": 806, + "ette": 807, + "aire": 808, + "avec": 809, + "dit": 810, + "té": 811, + "ille": 812, + "comme": 813, + "cr": 814, + "ux": 815, + "ès": 816, + "aux": 817, + "jour": 818, + "ils": 819, + "bien": 820, + "cou": 821, + "quel": 822, + "peu": 823, + "cette": 824, + "cu": 825, + "mê": 826, + "fait": 827, + "gu": 828, + "être": 829, + "ité": 830, + "ens": 831, + "ni": 832, + "lé": 833, + "dis": 834, + "ble": 835, + "né": 836, + "puis": 837, + "même": 838, + "ques": 839, + "fi": 840, + "age": 841, + "moi": 842, + "ence": 843, + "ont": 844, + "main": 845, + "ors": 846, + "aut": 847, + "ance": 848, + "mé": 849, + "sans": 850, + "sé": 851, + "lon": 852, + "hom": 853, + "car": 854, + "able": 855, + "cher": 856, + "deux": 857, + "enf": 858, + "où": 859, + 
"ph": 860, + "ure": 861, + "temp": 862, + "pos": 863, + "rent": 864, + "pé": 865, + "faire": 866, + "pi": 867, + "tres": 868, + "ça": 869, + "endre": 870, + "bon": 871, + "sou": 872, + "int": 873, + "pré": 874, + "sent": 875, + "tant": 876, + "cer": 877, + "là": 878, + "lais": 879, + "près": 880, + "bre": 881, + "cour": 882, + "pet": 883, + "comp": 884, + "lait": 885, + "trouv": 886, + "entre": 887, + "sont": 888, + "dev": 889, + "nu": 890, + "temps": 891, + "dou": 892, + "rait": 893, + "bou": 894, + "quand": 895, + "jours": 896, + "avoir": 897, + "été": 898, + "ale": 899, + "pre": 900, + "fois": 901, + "orte": 902, + "vé": 903, + "non": 904, + "tous": 905, + "jus": 906, + "coup": 907, + "homme": 908, + "ête": 909, + "aussi": 910, + "urs": 911, + "seu": 912, + "ord": 913, + "min": 914, + "gé": 915, + "core": 916, + "va": 917, + "vre": 918, + "encore": 919, + "sem": 920, + "ite": 921, + "autre": 922, + "pris": 923, + "peut": 924, + "ue": 925, + "ante": 926, + "gn": 927, + "rép": 928, + "hu": 929, + "sion": 930, + "votre": 931, + "dire": 932, + "ez": 933, + "fem": 934, + "leurs": 935, + "met": 936, + "cri": 937, + "mis": 938, + "tour": 939, + "rai": 940, + "jam": 941, + "regar": 942, + "rien": 943, + "vers": 944, + "suis": 945, + "pouv": 946, + "vis": 947, + "grand": 948, + "ants": 949, + "cor": 950, + "rer": 951, + "cé": 952, + "tent": 953, + "pres": 954, + "vou": 955, + "alors": 956, + "sieur": 957, + "aine": 958, + "quoi": 959, + "fon": 960, + "endant": 961, + "arri": 962, + "eure": 963, + "après": 964, + "donc": 965, + "itu": 966, + "lè": 967, + "sait": 968, + "toi": 969, + "cha": 970, + "ail": 971, + "asse": 972, + "imp": 973, + "voy": 974, + "conn": 975, + "pla": 976, + "petit": 977, + "avant": 978, + "nom": 979, + "tin": 980, + "dont": 981, + "sous": 982, + "emp": 983, + "person": 984, + "elles": 985, + "beau": 986, + "parti": 987, + "cho": 988, + "prit": 989, + "toujours": 990, + "rais": 991, + "jamais": 992, + "trav": 993, + "tions": 994, + "très": 995, + "voi": 996, + "ren": 997, + "yeux": 998, + "voir": 999, + "premi": 1000, + "gne": 1001, + "heure": 1002, + "rou": 1003, + "eff": 1004, + "notre": 1005, + "ments": 1006, + "ton": 1007, + "fais": 1008, + "cela": 1009, + "répon": 1010, + "cons": 1011, + "air": 1012, + "ôt": 1013, + "pendant": 1014, + "ici": 1015, + "toute": 1016, + "jet": 1017, + "port": 1018, + "étaient": 1019, + "pen": 1020, + "hé": 1021, + "autres": 1022, + "père": 1023, + "oc": 1024, + "quelques": 1025, + "ique": 1026, + "lis": 1027, + "femme": 1028, + "jou": 1029, + "teur": 1030, + "monde": 1031, + "nes": 1032, + "dre": 1033, + "aff": 1034, + "rap": 1035, + "part": 1036, + "lement": 1037, + "cla": 1038, + "fut": 1039, + "quelque": 1040, + "prendre": 1041, + "rê": 1042, + "aille": 1043, + "sais": 1044, + "ches": 1045, + "let": 1046, + "char": 1047, + "ères": 1048, + "ents": 1049, + "moins": 1050, + "eau": 1051, + "aî": 1052, + "jeu": 1053, + "heur": 1054, + "ées": 1055, + "tri": 1056, + "point": 1057, + "mom": 1058, + "vent": 1059, + "nouv": 1060, + "gran": 1061, + "trois": 1062, + "sant": 1063, + "toutes": 1064, + "contre": 1065, + "èrent": 1066, + "chez": 1067, + "avez": 1068, + "ût": 1069, + "att": 1070, + "pau": 1071, + "porte": 1072, + "ouver": 1073, + "lit": 1074, + "prés": 1075, + "chose": 1076, + "vit": 1077, + "monsieur": 1078, + "hab": 1079, + "tête": 1080, + "ju": 1081, + "tement": 1082, + "ction": 1083, + "vrai": 1084, + "lar": 1085, + "cet": 1086, + "regard": 1087, + "lant": 1088, + "som": 1089, + "moment": 1090, + "illes": 1091, + "ple": 1092, + 
"ps": 1093, + "mère": 1094, + "cl": 1095, + "sour": 1096, + "ys": 1097, + "trop": 1098, + "enne": 1099, + "jusqu": 1100, + "avaient": 1101, + "avais": 1102, + "jeune": 1103, + "depuis": 1104, + "personne": 1105, + "fit": 1106, + "cert": 1107, + "jo": 1108, + "oui": 1109, + "rest": 1110, + "semb": 1111, + "cap": 1112, + "mat": 1113, + "mu": 1114, + "long": 1115, + "fran": 1116, + "faut": 1117, + "iti": 1118, + "bli": 1119, + "chev": 1120, + "pri": 1121, + "ente": 1122, + "ainsi": 1123, + "cham": 1124, + "lors": 1125, + "cas": 1126, + "ili": 1127, + "bé": 1128, + "nos": 1129, + "sui": 1130, + "rit": 1131, + "cro": 1132, + "gue": 1133, + "ía": 1134, + "por": 1135, + "las": 1136, + "ón": 1137, + "una": 1138, + "aba": 1139, + "dos": 1140, + "era": 1141, + "mb": 1142, + "para": 1143, + "ás": 1144, + "mos": 1145, + "ando": 1146, + "como": 1147, + "más": 1148, + "ción": 1149, + "tan": 1150, + "dad": 1151, + "ado": 1152, + "fu": 1153, + "cia": 1154, + "mente": 1155, + "sus": 1156, + "tar": 1157, + "za": 1158, + "ba": 1159, + "pero": 1160, + "sin": 1161, + "lla": 1162, + "án": 1163, + "ia": 1164, + "ran": 1165, + "ga": 1166, + "yo": 1167, + "tos": 1168, + "cos": 1169, + "ya": 1170, + "ones": 1171, + "había": 1172, + "hi": 1173, + "esta": 1174, + "mas": 1175, + "tor": 1176, + "aban": 1177, + "dor": 1178, + "ían": 1179, + "tas": 1180, + "én": 1181, + "endo": 1182, + "aque": 1183, + "ero": 1184, + "io": 1185, + "qué": 1186, + "cab": 1187, + "tal": 1188, + "señ": 1189, + "ora": 1190, + "todo": 1191, + "sal": 1192, + "cuando": 1193, + "gun": 1194, + "bu": 1195, + "ras": 1196, + "esto": 1197, + "pare": 1198, + "él": 1199, + "tras": 1200, + "jos": 1201, + "mien": 1202, + "pue": 1203, + "cre": 1204, + "pon": 1205, + "día": 1206, + "tros": 1207, + "sab": 1208, + "sobre": 1209, + "ese": 1210, + "mbre": 1211, + "eron": 1212, + "añ": 1213, + "ido": 1214, + "porque": 1215, + "ella": 1216, + "cen": 1217, + "muy": 1218, + "cal": 1219, + "este": 1220, + "has": 1221, + "có": 1222, + "gra": 1223, + "ros": 1224, + "aquel": 1225, + "dijo": 1226, + "cía": 1227, + "zo": 1228, + "ciones": 1229, + "mbi": 1230, + "elo": 1231, + "tó": 1232, + "ina": 1233, + "todos": 1234, + "tien": 1235, + "estaba": 1236, + "deci": 1237, + "cio": 1238, + "ño": 1239, + "lor": 1240, + "nues": 1241, + "medi": 1242, + "len": 1243, + "vida": 1244, + "ali": 1245, + "pues": 1246, + "ales": 1247, + "vol": 1248, + "mí": 1249, + "rar": 1250, + "cion": 1251, + "hasta": 1252, + "señor": 1253, + "cono": 1254, + "ah": 1255, + "dios": 1256, + "esa": 1257, + "ún": 1258, + "var": 1259, + "san": 1260, + "gui": 1261, + "otros": 1262, + "tado": 1263, + "buen": 1264, + "ña": 1265, + "tiemp": 1266, + "hacer": 1267, + "jer": 1268, + "vu": 1269, + "ana": 1270, + "así": 1271, + "antes": 1272, + "vez": 1273, + "miento": 1274, + "jar": 1275, + "lab": 1276, + "casa": 1277, + "eso": 1278, + "ego": 1279, + "dió": 1280, + "está": 1281, + "encia": 1282, + "eli": 1283, + "ías": 1284, + "tiempo": 1285, + "zar": 1286, + "van": 1287, + "mun": 1288, + "erta": 1289, + "tambi": 1290, + "sí": 1291, + "aun": 1292, + "mismo": 1293, + "entes": 1294, + "mano": 1295, + "ele": 1296, + "nada": 1297, + "segu": 1298, + "mej": 1299, + "erra": 1300, + "tir": 1301, + "uno": 1302, + "donde": 1303, + "toda": 1304, + "desde": 1305, + "también": 1306, + "cuer": 1307, + "hombre": 1308, + "otro": 1309, + "lib": 1310, + "trar": 1311, + "cual": 1312, + "hay": 1313, + "cada": 1314, + "taba": 1315, + "mento": 1316, + "tenía": 1317, + "quer": 1318, + "eran": 1319, + "siemp": 1320, + "siempre": 1321, + 
"erto": 1322, + "quí": 1323, + "gos": 1324, + "pués": 1325, + "ellos": 1326, + "después": 1327, + "nue": 1328, + "llo": 1329, + "inter": 1330, + "cómo": 1331, + "ahora": 1332, + "uste": 1333, + "traba": 1334, + "lado": 1335, + "ino": 1336, + "poco": 1337, + "erte": 1338, + "mujer": 1339, + "quier": 1340, + "algun": 1341, + "fue": 1342, + "ojos": 1343, + "enton": 1344, + "vos": 1345, + "esper": 1346, + "much": 1347, + "otra": 1348, + "az": 1349, + "eza": 1350, + "aquí": 1351, + "cias": 1352, + "gua": 1353, + "mucho": 1354, + "decir": 1355, + "esti": 1356, + "idad": 1357, + "algo": 1358, + "ocu": 1359, + "entonces": 1360, + "dido": 1361, + "entos": 1362, + "gri": 1363, + "dado": 1364, + "ios": 1365, + "dose": 1366, + "usted": 1367, + "quien": 1368, + "ami": 1369, + "unto": 1370, + "mejor": 1371, + "bas": 1372, + "solo": 1373, + "pregun": 1374, + "tur": 1375, + "alg": 1376, + "todas": 1377, + "parte": 1378, + "emb": 1379, + "cto": 1380, + "mundo": 1381, + "tiene": 1382, + "tante": 1383, + "palab": 1384, + "tran": 1385, + "aquella": 1386, + "cios": 1387, + "aunque": 1388, + "cuen": 1389, + "tener": 1390, + "fun": 1391, + "respon": 1392, + "allí": 1393, + "xi": 1394, + "han": 1395, + "pens": 1396, + "contra": 1397, + "tura": 1398, + "val": 1399, + "dio": 1400, + "tanto": 1401, + "camin": 1402, + "mó": 1403, + "esp": 1404, + "ada": 1405, + "ío": 1406, + "hacia": 1407, + "dej": 1408, + "estar": 1409, + "ión": 1410, + "gas": 1411, + "vas": 1412, + "noche": 1413, + "ér": 1414, + "años": 1415, + "padre": 1416, + "gus": 1417, + "ár": 1418, + "sino": 1419, + "manos": 1420, + "cido": 1421, + "estu": 1422, + "hubi": 1423, + "vir": 1424, + "bri": 1425, + "raz": 1426, + "chi": 1427, + "puede": 1428, + "menos": 1429, + "habi": 1430, + "homb": 1431, + "neces": 1432, + "may": 1433, + "eros": 1434, + "ría": 1435, + "hecho": 1436, + "escu": 1437, + "lti": 1438, + "ándo": 1439, + "bus": 1440, + "cosas": 1441, + "tú": 1442, + "espa": 1443, + "reci": 1444, + "ctor": 1445, + "prim": 1446, + "dia": 1447, + "dese": 1448, + "mientras": 1449, + "hor": 1450, + "fuer": 1451, + "ida": 1452, + "posi": 1453, + "lante": 1454, + "ano": 1455, + "estas": 1456, + "pli": 1457, + "luego": 1458, + "sión": 1459, + "cin": 1460, + "tierra": 1461, + "guar": 1462, + "cado": 1463, + "encon": 1464, + "pren": 1465, + "mayor": 1466, + "fal": 1467, + "ð": 1468, + "ħ": 1469, + "ň": 1470, + "ə": 1471, + "θ": 1472, + "’": 1473, + "“": 1474, + "”": 1475, + "zi": 1476, + "gli": 1477, + "tto": 1478, + "ono": 1479, + "nel": 1480, + "tti": 1481, + "della": 1482, + "zione": 1483, + "tta": 1484, + "tà": 1485, + "uo": 1486, + "come": 1487, + "alla": 1488, + "oni": 1489, + "ggi": 1490, + "ssi": 1491, + "più": 1492, + "ini": 1493, + "bb": 1494, + "sto": 1495, + "sono": 1496, + "eri": 1497, + "sse": 1498, + "sc": 1499, + "sul": 1500, + "vano": 1501, + "sti": 1502, + "suo": 1503, + "cchi": 1504, + "zza": 1505, + "anche": 1506, + "tte": 1507, + "sci": 1508, + "col": 1509, + "sso": 1510, + "ssa": 1511, + "dei": 1512, + "aveva": 1513, + "zz": 1514, + "amo": 1515, + "gno": 1516, + "sua": 1517, + "ria": 1518, + "sì": 1519, + "ché": 1520, + "dal": 1521, + "ona": 1522, + "spe": 1523, + "gni": 1524, + "tt": 1525, + "delle": 1526, + "questo": 1527, + "nella": 1528, + "dere": 1529, + "anno": 1530, + "dell": 1531, + "uni": 1532, + "bbe": 1533, + "anti": 1534, + "ene": 1535, + "gio": 1536, + "uto": 1537, + "qual": 1538, + "glia": 1539, + "quando": 1540, + "tutto": 1541, + "glio": 1542, + "zioni": 1543, + "cam": 1544, + "esso": 1545, + "ss": 1546, + "mol": 1547, + 
"loro": 1548, + "perché": 1549, + "cosa": 1550, + "due": 1551, + "poi": 1552, + "sco": 1553, + "cco": 1554, + "gna": 1555, + "tem": 1556, + "prima": 1557, + "così": 1558, + "essere": 1559, + "ani": 1560, + "bra": 1561, + "rio": 1562, + "anco": 1563, + "cui": 1564, + "spi": 1565, + "via": 1566, + "gior": 1567, + "bile": 1568, + "ggio": 1569, + "mai": 1570, + "tare": 1571, + "indi": 1572, + "rebbe": 1573, + "senza": 1574, + "zio": 1575, + "tutti": 1576, + "stato": 1577, + "zia": 1578, + "dalla": 1579, + "mia": 1580, + "vita": 1581, + "quella": 1582, + "qua": 1583, + "dove": 1584, + "allo": 1585, + "sempre": 1586, + "zzo": 1587, + "sia": 1588, + "dopo": 1589, + "porta": 1590, + "ccia": 1591, + "erano": 1592, + "anni": 1593, + "chia": 1594, + "enza": 1595, + "propri": 1596, + "anda": 1597, + "cca": 1598, + "occhi": 1599, + "questa": 1600, + "ffi": 1601, + "ron": 1602, + "mio": 1603, + "ris": 1604, + "ogni": 1605, + "rin": 1606, + "far": 1607, + "menti": 1608, + "ancora": 1609, + "fatto": 1610, + "mani": 1611, + "senti": 1612, + "pra": 1613, + "tempo": 1614, + "essi": 1615, + "bbi": 1616, + "lare": 1617, + "pers": 1618, + "sor": 1619, + "anza": 1620, + "pie": 1621, + "verso": 1622, + "altro": 1623, + "tato": 1624, + "cato": 1625, + "ato": 1626, + "volta": 1627, + "cc": 1628, + "fare": 1629, + "ciò": 1630, + "bili": 1631, + "nuo": 1632, + "quello": 1633, + "colo": 1634, + "ppo": 1635, + "trova": 1636, + "ore": 1637, + "rono": 1638, + "molto": 1639, + "almente": 1640, + "sca": 1641, + "vole": 1642, + "tali": 1643, + "sulla": 1644, + "sce": 1645, + "meno": 1646, + "anto": 1647, + "pun": 1648, + "stu": 1649, + "capi": 1650, + "giu": 1651, + "mini": 1652, + "pia": 1653, + "lavo": 1654, + "vero": 1655, + "rsi": 1656, + "altri": 1657, + "scia": 1658, + "suoi": 1659, + "glie": 1660, + "sotto": 1661, + "bene": 1662, + "scri": 1663, + "tale": 1664, + "degli": 1665, + "alc": 1666, + "uomo": 1667, + "pel": 1668, + "pote": 1669, + "essa": 1670, + "scu": 1671, + "signo": 1672, + "stro": 1673, + "uti": 1674, + "sione": 1675, + "gre": 1676, + "fini": 1677, + "lun": 1678, + "esi": 1679, + "passa": 1680, + "rà": 1681, + "mentre": 1682, + "hanno": 1683, + "usci": 1684, + "gia": 1685, + "già": 1686, + "mina": 1687, + "tica": 1688, + "giorno": 1689, + "esse": 1690, + "modo": 1691, + "spa": 1692, + "proprio": 1693, + "ori": 1694, + "contro": 1695, + "stru": 1696, + "diven": 1697, + "disse": 1698, + "rato": 1699, + "noi": 1700, + "vere": 1701, + "può": 1702, + "dice": 1703, + "cci": 1704, + "secon": 1705, + "ccio": 1706, + "qualche": 1707, + "tutta": 1708, + "gg": 1709, + "mondo": 1710, + "forma": 1711, + "mma": 1712, + "pensa": 1713, + "deva": 1714, + "fosse": 1715, + "sopra": 1716, + "tamente": 1717, + "ness": 1718, + "quanto": 1719, + "raga": 1720, + "unque": 1721, + "care": 1722, + "stre": 1723, + "grande": 1724, + "picco": 1725, + "guarda": 1726, + "nell": 1727, + "possi": 1728, + "presen": 1729, + "rò": 1730, + "paro": 1731, + "tua": 1732, + "vin": 1733, + "ane": 1734, + "stesso": 1735, + "dav": 1736, + "nei": 1737, + "nelle": 1738, + "ghi": 1739, + "pio": 1740, + "lato": 1741, + "sid": 1742, + "fine": 1743, + "fuo": 1744, + "quasi": 1745, + "ulti": 1746, + "ito": 1747, + "sue": 1748, + "fil": 1749, + "allora": 1750, + "veni": 1751, + "tano": 1752, + "ello": 1753, + "ão": 1754, + "não": 1755, + "uma": 1756, + "ela": 1757, + "lh": 1758, + "ção": 1759, + "cê": 1760, + "inha": 1761, + "você": 1762, + "ec": 1763, + "dade": 1764, + "ao": 1765, + "ram": 1766, + "vel": 1767, + "ém": 1768, + "pode": 1769, + "estava": 
1770, + "isso": 1771, + "mui": 1772, + "faz": 1773, + "ões": 1774, + "pes": 1775, + "ix": 1776, + "sim": 1777, + "olh": 1778, + "isa": 1779, + "ên": 1780, + "tinha": 1781, + "meu": 1782, + "são": 1783, + "minha": 1784, + "muito": 1785, + "foi": 1786, + "bem": 1787, + "diz": 1788, + "parec": 1789, + "ço": 1790, + "pesso": 1791, + "pois": 1792, + "mesmo": 1793, + "ções": 1794, + "seus": 1795, + "até": 1796, + "ência": 1797, + "lhe": 1798, + "tiv": 1799, + "mã": 1800, + "só": 1801, + "tão": 1802, + "tudo": 1803, + "então": 1804, + "inda": 1805, + "bal": 1806, + "indo": 1807, + "ndo": 1808, + "já": 1809, + "vam": 1810, + "eito": 1811, + "depois": 1812, + "mel": 1813, + "lha": 1814, + "ainda": 1815, + "fazer": 1816, + "pou": 1817, + "pergun": 1818, + "deix": 1819, + "tamb": 1820, + "ala": 1821, + "pelo": 1822, + "também": 1823, + "fica": 1824, + "prec": 1825, + "eles": 1826, + "havia": 1827, + "lá": 1828, + "nas": 1829, + "gem": 1830, + "mem": 1831, + "ós": 1832, + "deu": 1833, + "eiro": 1834, + "..": 1835, + "assim": 1836, + "ior": 1837, + "har": 1838, + "aqui": 1839, + "cul": 1840, + "sar": 1841, + "outra": 1842, + "olhos": 1843, + "ima": 1844, + "mim": 1845, + "ago": 1846, + "pessoas": 1847, + "eram": 1848, + "eira": 1849, + "pela": 1850, + "coisa": 1851, + "mão": 1852, + "conh": 1853, + "agora": 1854, + "iam": 1855, + "há": 1856, + "suas": 1857, + "guém": 1858, + "cabe": 1859, + "nem": 1860, + "ível": 1861, + "consegu": 1862, + "trabal": 1863, + "lev": 1864, + "lem": 1865, + "vai": 1866, + "tei": 1867, + "pró": 1868, + "quem": 1869, + "onde": 1870, + "cabeça": 1871, + "nunca": 1872, + "mentos": 1873, + "hum": 1874, + "dele": 1875, + "verdade": 1876, + "tá": 1877, + "hos": 1878, + "algum": 1879, + "dizer": 1880, + "penas": 1881, + "nós": 1882, + "enquanto": 1883, + "outro": 1884, + "lho": 1885, + "melhor": 1886, + "primei": 1887, + "iu": 1888, + "apenas": 1889, + "estou": 1890, + "conte": 1891, + "homem": 1892, + "dois": 1893, + "ças": 1894, + "pouco": 1895, + "senhor": 1896, + "tando": 1897, + "espera": 1898, + "pai": 1899, + "rios": 1900, + "baix": 1901, + "ase": 1902, + "isas": 1903, + "hora": 1904, + "ficar": 1905, + "seja": 1906, + "ân": 1907, + "clar": 1908, + "inc": 1909, + "fos": 1910, + "ouvi": 1911, + "vem": 1912, + "tava": 1913, + "ário": 1914, + "sos": 1915, + "inho": 1916, + "rando": 1917, + "ês": 1918, + "coisas": 1919, + "aconte": 1920, + "lher": 1921, + "anos": 1922, + "talvez": 1923, + "estão": 1924, + "liv": 1925, + "outros": 1926, + "qualquer": 1927, + "gou": 1928, + "lí": 1929, + "tivesse": 1930, + "rado": 1931, + "precisa": 1932, + "mãe": 1933, + "dela": 1934, + "entra": 1935, + "maior": 1936, + "noite": 1937, + "tiva": 1938, + "pala": 1939, + "ração": 1940, + "deus": 1941, + "sas": 1942, + "inte": 1943, + "fei": 1944, + "palav": 1945, + "trás": 1946, + "cidade": 1947, + "lugar": 1948, + "vezes": 1949, + "encontra": 1950, + "tru": 1951, + "eci": 1952, + "ın": 1953, + "bir": 1954, + "yor": 1955, + "ek": 1956, + "dı": 1957, + "ey": 1958, + "tı": 1959, + "mı": 1960, + "iz": 1961, + "ır": 1962, + "gö": 1963, + "sı": 1964, + "bil": 1965, + "lı": 1966, + "üz": 1967, + "iç": 1968, + "iy": 1969, + "ım": 1970, + "uz": 1971, + "cak": 1972, + "iş": 1973, + "ını": 1974, + "iyor": 1975, + "baş": 1976, + "dü": 1977, + "değ": 1978, + "kar": 1979, + "ev": 1980, + "öy": 1981, + "bun": 1982, + "yap": 1983, + "sun": 1984, + "gör": 1985, + "yı": 1986, + "ki": 1987, + "ara": 1988, + "alı": 1989, + "onu": 1990, + "çı": 1991, + "şey": 1992, + "sın": 1993, + "kı": 1994, + "kad": 1995, + "ağ": 
1996, + "değil": 1997, + "ük": 1998, + "çok": 1999, + "şı": 2000, + "ül": 2001, + "için": 2002, + "eye": 2003, + "oldu": 2004, + "mış": 2005, + "kal": 2006, + "mek": 2007, + "öyle": 2008, + "yordu": 2009, + "yüz": 2010, + "miş": 2011, + "mak": 2012, + "ola": 2013, + "yan": 2014, + "cek": 2015, + "yorum": 2016, + "bak": 2017, + "üm": 2018, + "ları": 2019, + "oğ": 2020, + "kadar": 2021, + "arı": 2022, + "ında": 2023, + "gün": 2024, + "yok": 2025, + "yer": 2026, + "dım": 2027, + "daha": 2028, + "ına": 2029, + "dim": 2030, + "bilir": 2031, + "iki": 2032, + "siz": 2033, + "diğ": 2034, + "bü": 2035, + "düş": 2036, + "üç": 2037, + "unu": 2038, + "aman": 2039, + "fak": 2040, + "ede": 2041, + "sonra": 2042, + "hiç": 2043, + "aki": 2044, + "ğı": 2045, + "bul": 2046, + "maz": 2047, + "anla": 2048, + "bura": 2049, + "geç": 2050, + "maya": 2051, + "konu": 2052, + "din": 2053, + "tek": 2054, + "zaman": 2055, + "eler": 2056, + "öz": 2057, + "dır": 2058, + "gibi": 2059, + "şa": 2060, + "leri": 2061, + "kim": 2062, + "ku": 2063, + "fakat": 2064, + "yar": 2065, + "göz": 2066, + "cı": 2067, + "yorsun": 2068, + "bek": 2069, + "inde": 2070, + "pek": 2071, + "bunu": 2072, + "lik": 2073, + "iler": 2074, + "edi": 2075, + "öl": 2076, + "sür": 2077, + "sır": 2078, + "çık": 2079, + "sıl": 2080, + "alar": 2081, + "kes": 2082, + "yak": 2083, + "çek": 2084, + "yıl": 2085, + "ecek": 2086, + "ız": 2087, + "git": 2088, + "kap": 2089, + "ama": 2090, + "ıl": 2091, + "ların": 2092, + "biz": 2093, + "tır": 2094, + "oy": 2095, + "ancak": 2096, + "doğ": 2097, + "bana": 2098, + "şim": 2099, + "başla": 2100, + "lü": 2101, + "madı": 2102, + "beni": 2103, + "yük": 2104, + "lık": 2105, + "beş": 2106, + "nasıl": 2107, + "tık": 2108, + "tür": 2109, + "daki": 2110, + "ceğ": 2111, + "zı": 2112, + "iyi": 2113, + "dok": 2114, + "benim": 2115, + "cağ": 2116, + "yen": 2117, + "şu": 2118, + "mez": 2119, + "düşün": 2120, + "kendi": 2121, + "şimdi": 2122, + "yol": 2123, + "yu": 2124, + "iste": 2125, + "sek": 2126, + "mam": 2127, + "söyle": 2128, + "dik": 2129, + "kur": 2130, + "olduğ": 2131, + "sını": 2132, + "biliyor": 2133, + "kan": 2134, + "yal": 2135, + "meye": 2136, + "muş": 2137, + "kaç": 2138, + "iye": 2139, + "tü": 2140, + "ef": 2141, + "tım": 2142, + "evet": 2143, + "yet": 2144, + "burada": 2145, + "tim": 2146, + "biraz": 2147, + "kor": 2148, + "doğru": 2149, + "inin": 2150, + "kız": 2151, + "diye": 2152, + "dör": 2153, + "etti": 2154, + "onun": 2155, + "isti": 2156, + "ği": 2157, + "sana": 2158, + "üş": 2159, + "arka": 2160, + "hayır": 2161, + "karşı": 2162, + "ile": 2163, + "hak": 2164, + "ıyor": 2165, + "neden": 2166, + "sev": 2167, + "sız": 2168, + "çocu": 2169, + "çalı": 2170, + "olur": 2171, + "bır": 2172, + "gir": 2173, + "ise": 2174, + "ih": 2175, + "kır": 2176, + "dön": 2177, + "böyle": 2178, + "seni": 2179, + "!\"": 2180, + "dört": 2181, + "söy": 2182, + "oş": 2183, + "musun": 2184, + "laş": 2185, + "ip": 2186, + "kay": 2187, + "hem": 2188, + "büyük": 2189, + "aç": 2190, + "bırak": 2191, + "misin": 2192, + "söz": 2193, + "değiş": 2194, + "ünü": 2195, + "gül": 2196, + "kö": 2197, + "karı": 2198, + "tamam": 2199, + "olu": 2200, + "yeni": 2201, + "lam": 2202, + "mıştı": 2203, + "yaş": 2204, + "iniz": 2205, + "kadın": 2206, + "bunun": 2207, + "mey": 2208, + "altı": 2209, + "yi": 2210, + "inden": 2211, + "senin": 2212, + "yat": 2213, + "top": 2214, + "isi": 2215, + "dün": 2216, + "hiçbir": 2217, + "yon": 2218, + "dın": 2219, + "tün": 2220, + "başka": 2221, + "hep": 2222, + "irmi": 2223, + "devam": 2224, + "olacak": 2225, + 
"artık": 2226, + "durum": 2227, + "imiz": 2228, + "üzel": 2229, + "lerini": 2230, + "sağ": 2231, + "gerek": 2232, + "yirmi": 2233, + "şek": 2234, + "bağ": 2235, + "lara": 2236, + "yür": 2237, + "ması": 2238, + "katı": 2239, + "dedi": 2240, + "gü": 2241, + "sorun": 2242, + "üne": 2243, + "mız": 2244, + "yapı": 2245, + "mil": 2246, + "ğını": 2247, + "tara": 2248, + "vardı": 2249, + "konuş": 2250, + "arak": 2251, + "larak": 2252, + "çocuk": 2253, + "bütün": 2254, + "ley": 2255, + "dür": 2256, + "güzel": 2257, + "ayı": 2258, + "yapa": 2259, + "nı": 2260, + "ayr": 2261, + "öne": 2262, + "yordum": 2263, + "ban": 2264, + "i̇ş": 2265, + "dum": 2266, + "yorlar": 2267, + "larını": 2268, + "çıkar": 2269, + "zan": 2270, + "seç": 2271, + "liyor": 2272, + "tak": 2273, + "şık": 2274, + "tekrar": 2275, + "aş": 2276, + "eş": 2277, + "mişti": 2278, + "kin": 2279, + "imi": 2280, + "eğ": 2281, + "gidi": 2282, + "leş": 2283, + "başladı": 2284, + "gide": 2285, + "otur": 2286, + "dde": 2287, + "ından": 2288, + "üzer": 2289, + "ının": 2290, + "nız": 2291, + "uy": 2292, + "yedi": 2293, + "kat": 2294, + "olarak": 2295, + "ladı": 2296, + "yalnız": 2297, + "bah": 2298, + "iyet": 2299, + "sak": 2300, + "açık": 2301, + "sında": 2302, + "...": 2303, + "insan": 2304, + "aynı": 2305, + "eder": 2306, + "istan": 2307, + "uzun": 2308, + "geri": 2309, + "erek": 2310, + "olan": 2311, + "gerçek": 2312, + "alan": 2313, + "dış": 2314, + "alık": 2315, + "fark": 2316, + "üst": 2317, + "sade": 2318, + "kiş": 2319, + "ldı": 2320, + "zor": 2321, + "etir": 2322, + "herkes": 2323, + "ömer": 2324, + "unda": 2325, + "haf": 2326, + "buna": 2327, + "ydı": 2328, + "peki": 2329, + "adam": 2330, + "haz": 2331, + "sına": 2332, + "kapı": 2333, + "görüş": 2334, + "sadece": 2335, + "aldı": 2336, + "geldi": 2337, + "rz": 2338, + "sz": 2339, + "cz": 2340, + "ię": 2341, + "dz": 2342, + "ał": 2343, + "się": 2344, + "rze": 2345, + "że": 2346, + "wy": 2347, + "rzy": 2348, + "ła": 2349, + "ło": 2350, + "ny": 2351, + "dzie": 2352, + "dzi": 2353, + "czy": 2354, + "cie": 2355, + "prze": 2356, + "dy": 2357, + "kie": 2358, + "ry": 2359, + "ją": 2360, + "ów": 2361, + "przy": 2362, + "mie": 2363, + "szy": 2364, + "cze": 2365, + "bie": 2366, + "cy": 2367, + "nia": 2368, + "ści": 2369, + "sze": 2370, + "jest": 2371, + "ży": 2372, + "ną": 2373, + "któ": 2374, + "ała": 2375, + "mnie": 2376, + "ły": 2377, + "cza": 2378, + "jak": 2379, + "roz": 2380, + "ró": 2381, + "zna": 2382, + "łu": 2383, + "ść": 2384, + "wia": 2385, + "wszy": 2386, + "spo": 2387, + "gdy": 2388, + "wał": 2389, + "wię": 2390, + "łem": 2391, + "ję": 2392, + "sk": 2393, + "rę": 2394, + "dob": 2395, + "już": 2396, + "bę": 2397, + "ałem": 2398, + "sza": 2399, + "pod": 2400, + "dla": 2401, + "pan": 2402, + "nę": 2403, + "może": 2404, + "śli": 2405, + "ało": 2406, + "lko": 2407, + "nych": 2408, + "powie": 2409, + "cię": 2410, + "tylko": 2411, + "naj": 2412, + "tego": 2413, + "ski": 2414, + "nego": 2415, + "wszyst": 2416, + "szcze": 2417, + "jed": 2418, + "jej": 2419, + "two": 2420, + "ąd": 2421, + "śmy": 2422, + "czę": 2423, + "wać": 2424, + "jego": 2425, + "ża": 2426, + "sy": 2427, + "praw": 2428, + "tym": 2429, + "który": 2430, + "ały": 2431, + "trze": 2432, + "niej": 2433, + "nym": 2434, + "gło": 2435, + "jąc": 2436, + "mówi": 2437, + "ska": 2438, + "nej": 2439, + "słu": 2440, + "wła": 2441, + "będzie": 2442, + "dę": 2443, + "pó": 2444, + "bez": 2445, + "nic": 2446, + "pła": 2447, + "ście": 2448, + "są": 2449, + "trzy": 2450, + "kiem": 2451, + "był": 2452, + "mog": 2453, + "robi": 2454, + "tam": 
2455, + "mię": 2456, + "zy": 2457, + "pew": 2458, + "myś": 2459, + "przed": 2460, + "sko": 2461, + "które": 2462, + "lę": 2463, + "wsze": 2464, + "ąc": 2465, + "było": 2466, + "sobie": 2467, + "py": 2468, + "cią": 2469, + "jeszcze": 2470, + "tę": 2471, + "czas": 2472, + "szę": 2473, + "gł": 2474, + "kę": 2475, + "czu": 2476, + "przez": 2477, + "sło": 2478, + "wz": 2479, + "kto": 2480, + "ków": 2481, + "czo": 2482, + "liśmy": 2483, + "więc": 2484, + "rą": 2485, + "wó": 2486, + "rza": 2487, + "ności": 2488, + "wet": 2489, + "nął": 2490, + "śmie": 2491, + "nawet": 2492, + "musi": 2493, + "swo": 2494, + "tej": 2495, + "wą": 2496, + "wu": 2497, + "wią": 2498, + "niu": 2499, + "czą": 2500, + "dzo": 2501, + "skie": 2502, + "jeśli": 2503, + "czego": 2504, + "chy": 2505, + "dł": 2506, + "tych": 2507, + "bym": 2508, + "żo": 2509, + "eś": 2510, + "sią": 2511, + "kiedy": 2512, + "wró": 2513, + "dze": 2514, + "dro": 2515, + "rów": 2516, + "pani": 2517, + "kul": 2518, + "nad": 2519, + "chwi": 2520, + "nim": 2521, + "być": 2522, + "chodzi": 2523, + "nio": 2524, + "dobrze": 2525, + "teraz": 2526, + "wokul": 2527, + "coś": 2528, + "kł": 2529, + "pier": 2530, + "gdzie": 2531, + "dzy": 2532, + "pię": 2533, + "dź": 2534, + "ką": 2535, + "gó": 2536, + "zda": 2537, + "chce": 2538, + "stę": 2539, + "świa": 2540, + "wszystko": 2541, + "peł": 2542, + "wiem": 2543, + "wiel": 2544, + "każ": 2545, + "rzu": 2546, + "sły": 2547, + "jedna": 2548, + "myśl": 2549, + "mój": 2550, + "jestem": 2551, + "óż": 2552, + "miej": 2553, + "moż": 2554, + "kła": 2555, + "resz": 2556, + "dłu": 2557, + "stwo": 2558, + "nię": 2559, + "masz": 2560, + "żeby": 2561, + "niem": 2562, + "jakie": 2563, + "sty": 2564, + "nią": 2565, + "wej": 2566, + "oj": 2567, + "sła": 2568, + "ność": 2569, + "zło": 2570, + "szczę": 2571, + "lej": 2572, + "wego": 2573, + "cał": 2574, + "dział": 2575, + "kich": 2576, + "dza": 2577, + "dzię": 2578, + "oczy": 2579, + "zosta": 2580, + "czło": 2581, + "nam": 2582, + "kil": 2583, + "szu": 2584, + "wę": 2585, + "miał": 2586, + "strze": 2587, + "cej": 2588, + "ej": 2589, + "znaj": 2590, + "dać": 2591, + "miejs": 2592, + "kró": 2593, + "kry": 2594, + "bardzo": 2595, + "śnie": 2596, + "lą": 2597, + "gie": 2598, + "ciebie": 2599, + "dni": 2600, + "potrze": 2601, + "wokulski": 2602, + "uwa": 2603, + "umie": 2604, + "jednak": 2605, + "kra": 2606, + "wróci": 2607, + "człowie": 2608, + "czyć": 2609, + "była": 2610, + "żeli": 2611, + "mę": 2612, + "cę": 2613, + "zrobi": 2614, + "mogę": 2615, + "prowa": 2616, + "rem": 2617, + "niech": 2618, + "cznie": 2619, + "kro": 2620, + "tą": 2621, + "chci": 2622, + "bro": 2623, + "dzieć": 2624, + "szą": 2625, + "pad": 2626, + "trz": 2627, + "jem": 2628, + "tów": 2629, + "dru": 2630, + "taj": 2631, + "rzekł": 2632, + "niego": 2633, + "takie": 2634, + "wała": 2635, + "towa": 2636, + "kapła": 2637, + "widzi": 2638, + "podob": 2639, + "dzę": 2640, + "tał": 2641, + "stęp": 2642, + "bą": 2643, + "poko": 2644, + "wem": 2645, + "gę": 2646, + "aby": 2647, + "albo": 2648, + "spra": 2649, + "zno": 2650, + "smo": 2651, + "jesz": 2652, + "księ": 2653, + "jesteś": 2654, + "poz": 2655, + "nigdy": 2656, + "ksią": 2657, + "cóż": 2658, + "ws": 2659, + "pow": 2660, + "tka": 2661, + "świe": 2662, + "szka": 2663, + "samo": 2664, + "sł": 2665, + "rzę": 2666, + "nale": 2667, + "chcesz": 2668, + "nik": 2669, + "pę": 2670, + "chyba": 2671, + "ciąg": 2672, + "jący": 2673, + "woj": 2674, + "nasze": 2675, + "mniej": 2676, + "więcej": 2677, + "zwy": 2678, + "osta": 2679, + "waż": 2680, + "śmier": 2681, + "wier": 2682, 
+ "dzą": 2683, + "zaś": 2684, + "gdyby": 2685, + "jaki": 2686, + "wol": 2687, + "win": 2688, + "dą": 2689, + "ścia": 2690, + "rozma": 2691, + "wal": 2692, + "panie": 2693, + "star": 2694, + "kaz": 2695, + "jeżeli": 2696, + "wra": 2697, + "koń": 2698, + "siebie": 2699, + "znowu": 2700, + "czem": 2701, + "stwa": 2702, + "isto": 2703, + "pół": 2704, + "dał": 2705, + "kobie": 2706, + "ałam": 2707, + "wych": 2708, + "cesa": 2709, + "nich": 2710, + "zawsze": 2711, + "dzić": 2712, + "też": 2713, + "lepie": 2714, + "proszę": 2715, + "kre": 2716, + "twa": 2717, + "łą": 2718, + "chu": 2719, + "cą": 2720, + "prz": 2721, + "łe": 2722, + "szedł": 2723, + "odpowie": 2724, + "myśli": 2725, + "świą": 2726, + "ź": 2727, + "ł": 2728, + "&": 2729, + "=": 2730, + "ă": 2731, + "đ": 2732, + "ţ": 2733, + "–": 2734, + "‘": 2735, + "ij": 2736, + "aa": 2737, + "een": 2738, + "het": 2739, + "aar": 2740, + "oor": 2741, + "ijn": 2742, + "dat": 2743, + "oe": 2744, + "ijk": 2745, + "aan": 2746, + "voor": 2747, + "iet": 2748, + "zijn": 2749, + "niet": 2750, + "oo": 2751, + "moet": 2752, + "heb": 2753, + "uit": 2754, + "wij": 2755, + "aat": 2756, + "lijk": 2757, + "sl": 2758, + "daar": 2759, + "deze": 2760, + "worden": 2761, + "moeten": 2762, + "onder": 2763, + "hebben": 2764, + "ook": 2765, + "ct": 2766, + "nog": 2767, + "aal": 2768, + "eer": 2769, + "bij": 2770, + "mijn": 2771, + "kom": 2772, + "atie": 2773, + "eft": 2774, + "kel": 2775, + "rij": 2776, + "heid": 2777, + "af": 2778, + "stel": 2779, + "maar": 2780, + "wee": 2781, + "heeft": 2782, + "waar": 2783, + "eren": 2784, + "wat": 2785, + "wil": 2786, + "aag": 2787, + "bet": 2788, + "hij": 2789, + "kun": 2790, + "uw": 2791, + "dt": 2792, + "door": 2793, + "tij": 2794, + "ond": 2795, + "geen": 2796, + "gev": 2797, + "veel": 2798, + "naar": 2799, + "aten": 2800, + "kunnen": 2801, + "echt": 2802, + "goe": 2803, + "twee": 2804, + "delijk": 2805, + "uur": 2806, + "toe": 2807, + "meer": 2808, + "onze": 2809, + "tijd": 2810, + "hoe": 2811, + "tot": 2812, + "zou": 2813, + "aak": 2814, + "amen": 2815, + "woor": 2816, + "wordt": 2817, + "gelijk": 2818, + "gaan": 2819, + "ker": 2820, + "eld": 2821, + "hou": 2822, + "zel": 2823, + "tegen": 2824, + "komen": 2825, + "werk": 2826, + "goed": 2827, + "zal": 2828, + "zij": 2829, + "slag": 2830, + "zien": 2831, + "echter": 2832, + "itie": 2833, + "tie": 2834, + "elijk": 2835, + "ische": 2836, + "belan": 2837, + "haar": 2838, + "vr": 2839, + "grijk": 2840, + "doen": 2841, + "land": 2842, + "belangrijk": 2843, + "open": 2844, + "ctie": 2845, + "zelf": 2846, + "mij": 2847, + "iteit": 2848, + "stem": 2849, + "mee": 2850, + "aren": 2851, + "dien": 2852, + "gaat": 2853, + "prob": 2854, + "moe": 2855, + "ullen": 2856, + "zich": 2857, + "daarom": 2858, + "orm": 2859, + "staat": 2860, + "zit": 2861, + "dui": 2862, + "dus": 2863, + "ds": 2864, + "verslag": 2865, + "kelijk": 2866, + "proble": 2867, + "schap": 2868, + "gd": 2869, + "hun": 2870, + "erd": 2871, + "zet": 2872, + "staan": 2873, + "maal": 2874, + "inder": 2875, + "eid": 2876, + "kken": 2877, + "ged": 2878, + "zullen": 2879, + "mensen": 2880, + "jaar": 2881, + "regel": 2882, + "ieder": 2883, + "volgen": 2884, + "geven": 2885, + "even": 2886, + "blij": 2887, + "ië": 2888, + "uwe": 2889, + "maken": 2890, + "oek": 2891, + "nieuwe": 2892, + "baar": 2893, + "andere": 2894, + "ruik": 2895, + "agen": 2896, + "ouw": 2897, + "willen": 2898, + "aakt": 2899, + "hoo": 2900, + "anden": 2901, + "lig": 2902, + "samen": 2903, + "zeer": 2904, + "duidelijk": 2905, + "antwoor": 2906, + "heel": 2907, + 
"punt": 2908, + "houden": 2909, + "vraag": 2910, + "gele": 2911, + "eens": 2912, + "besch": 2913, + "omen": 2914, + "erg": 2915, + "doel": 2916, + "dag": 2917, + "uren": 2918, + "ings": 2919, + "oren": 2920, + "delen": 2921, + "steun": 2922, + "innen": 2923, + "pol": 2924, + "oon": 2925, + "sn": 2926, + "zonder": 2927, + "nodig": 2928, + "alleen": 2929, + "mid": 2930, + "ragen": 2931, + "iets": 2932, + "versch": 2933, + "gebruik": 2934, + "rouw": 2935, + "stellen": 2936, + "menten": 2937, + "eerste": 2938, + "laat": 2939, + "groot": 2940, + "ood": 2941, + "toch": 2942, + "laten": 2943, + "aard": 2944, + "sle": 2945, + "deel": 2946, + "plaat": 2947, + "ree": 2948, + "betre": 2949, + "lid": 2950, + "uiten": 2951, + "racht": 2952, + "beleid": 2953, + "stie": 2954, + "staten": 2955, + "ggen": 2956, + "reken": 2957, + "alen": 2958, + "ming": 2959, + "mogelijk": 2960, + "grote": 2961, + "altijd": 2962, + "enkel": 2963, + "wik": 2964, + "politie": 2965, + "elk": 2966, + "handel": 2967, + "kwe": 2968, + "maat": 2969, + "elen": 2970, + "vrij": 2971, + "jes": 2972, + "aam": 2973, + "huis": 2974, + "weer": 2975, + "lidstaten": 2976, + "king": 2977, + "kle": 2978, + "bed": 2979, + "geval": 2980, + "wikkel": 2981, + "kwestie": 2982, + "stee": 2983, + "hel": 2984, + "komst": 2985, + "iden": 2986, + "eerd": 2987, + "tweede": 2988, + "probleem": 2989, + "ussen": 2990, + "snel": 2991, + "tig": 2992, + "ult": 2993, + "nemen": 2994, + "commis": 2995, + "verschil": 2996, + "zoek": 2997, + "krij": 2998, + "graag": 2999, + "denk": 3000, + "landen": 3001, + "reden": 3002, + "besl": 3003, + "oeg": 3004, + "beter": 3005, + "heden": 3006, + "mag": 3007, + "boven": 3008, + "cont": 3009, + "fd": 3010, + "hele": 3011, + "vier": 3012, + "gez": 3013, + "kw": 3014, + "aas": 3015, + "ontwikkel": 3016, + "drie": 3017, + "vaak": 3018, + "plaats": 3019, + "gang": 3020, + "ijf": 3021, + "natuur": 3022, + "tussen": 3023, + "bat": 3024, + "komt": 3025, + "wacht": 3026, + "aad": 3027, + "achter": 3028, + "gebie": 3029, + "verk": 3030, + "ligt": 3031, + "nieuw": 3032, + "vand": 3033, + "ý": 3034, + "ď": 3035, + "ě": 3036, + "ř": 3037, + "ť": 3038, + "ů": 3039, + "„": 3040, + "ní": 3041, + "ně": 3042, + "ře": 3043, + "ná": 3044, + "vě": 3045, + "vá": 3046, + "rá": 3047, + "vy": 3048, + "mě": 3049, + "ři": 3050, + "ří": 3051, + "že": 3052, + "jí": 3053, + "vý": 3054, + "ji": 3055, + "dě": 3056, + "če": 3057, + "tě": 3058, + "ky": 3059, + "še": 3060, + "ké": 3061, + "ší": 3062, + "pře": 3063, + "ví": 3064, + "ný": 3065, + "ži": 3066, + "má": 3067, + "cí": 3068, + "zá": 3069, + "ské": 3070, + "dá": 3071, + "byl": 3072, + "tí": 3073, + "pří": 3074, + "při": 3075, + "či": 3076, + "vní": 3077, + "ča": 3078, + "dí": 3079, + "dní": 3080, + "ká": 3081, + "nou": 3082, + "vět": 3083, + "pě": 3084, + "kou": 3085, + "ých": 3086, + "bě": 3087, + "prá": 3088, + "jako": 3089, + "ží": 3090, + "zí": 3091, + "jsou": 3092, + "jsem": 3093, + "lní": 3094, + "cké": 3095, + "vat": 3096, + "před": 3097, + "hla": 3098, + "stá": 3099, + "čí": 3100, + "ši": 3101, + "kla": 3102, + "ště": 3103, + "lou": 3104, + "mů": 3105, + "chá": 3106, + "pů": 3107, + "také": 3108, + "dů": 3109, + "nost": 3110, + "tře": 3111, + "sku": 3112, + "vše": 3113, + "tní": 3114, + "byla": 3115, + "ční": 3116, + "jeho": 3117, + "bý": 3118, + "vání": 3119, + "ných": 3120, + "tři": 3121, + "vz": 3122, + "stře": 3123, + "dva": 3124, + "hle": 3125, + "čá": 3126, + "nosti": 3127, + "vš": 3128, + "hra": 3129, + "jen": 3130, + "slo": 3131, + "však": 3132, + "kdy": 3133, + "bylo": 3134, + 
"bude": 3135, + "jší": 3136, + "vých": 3137, + "ním": 3138, + "sm": 3139, + "koli": 3140, + "rů": 3141, + "může": 3142, + "není": 3143, + "hod": 3144, + "bí": 3145, + "tý": 3146, + "stě": 3147, + "uje": 3148, + "sá": 3149, + "pět": 3150, + "krá": 3151, + "tom": 3152, + "ství": 3153, + "vně": 3154, + "sed": 3155, + "své": 3156, + "pí": 3157, + "musí": 3158, + "už": 3159, + "tím": 3160, + "jící": 3161, + "jedno": 3162, + "čas": 3163, + "čty": 3164, + "ský": 3165, + "evro": 3166, + "toho": 3167, + "hy": 3168, + "kter": 3169, + "rní": 3170, + "stí": 3171, + "svě": 3172, + "pak": 3173, + "všech": 3174, + "ků": 3175, + "ng": 3176, + "ád": 3177, + "chází": 3178, + "být": 3179, + "první": 3180, + "mno": 3181, + "ského": 3182, + "pá": 3183, + "nebo": 3184, + "kem": 3185, + "sla": 3186, + "ného": 3187, + "zde": 3188, + "další": 3189, + "řa": 3190, + "čtyři": 3191, + "hrá": 3192, + "druh": 3193, + "lně": 3194, + "vla": 3195, + "ských": 3196, + "ško": 3197, + "půso": 3198, + "proto": 3199, + "vů": 3200, + "ská": 3201, + "šest": 3202, + "dně": 3203, + "ještě": 3204, + "mezi": 3205, + "několi": 3206, + "již": 3207, + "čně": 3208, + "slu": 3209, + "zná": 3210, + "sedm": 3211, + "vlá": 3212, + "osm": 3213, + "byly": 3214, + "vám": 3215, + "cký": 3216, + "tech": 3217, + "ději": 3218, + "velmi": 3219, + "leži": 3220, + "vala": 3221, + "lý": 3222, + "tvo": 3223, + "spole": 3224, + "stup": 3225, + "mož": 3226, + "evrop": 3227, + "stal": 3228, + "jde": 3229, + "rodi": 3230, + "její": 3231, + "poli": 3232, + "devět": 3233, + "sme": 3234, + "až": 3235, + "této": 3236, + "tento": 3237, + "kaž": 3238, + "nula": 3239, + "bych": 3240, + "moc": 3241, + "stou": 3242, + "kdo": 3243, + "zd": 3244, + "praco": 3245, + "tomu": 3246, + "ným": 3247, + "živo": 3248, + "zem": 3249, + "násle": 3250, + "sky": 3251, + "jich": 3252, + "měl": 3253, + "děla": 3254, + "jsme": 3255, + "nice": 3256, + "stej": 3257, + "stní": 3258, + "náro": 3259, + "nit": 3260, + "později": 3261, + "tako": 3262, + "nce": 3263, + "čer": 3264, + "ším": 3265, + "něco": 3266, + "vál": 3267, + "řej": 3268, + "krát": 3269, + "ální": 3270, + "asi": 3271, + "které": 3272, + "stav": 3273, + "mají": 3274, + "mys": 3275, + "době": 3276, + "sně": 3277, + "zku": 3278, + "tů": 3279, + "chod": 3280, + "spě": 3281, + "jejich": 3282, + "součas": 3283, + "vali": 3284, + "kte": 3285, + "prů": 3286, + "zení": 3287, + "pat": 3288, + "potře": 3289, + "dnes": 3290, + "zemí": 3291, + "znam": 3292, + "mám": 3293, + "tedy": 3294, + "hlavní": 3295, + "použí": 3296, + "bní": 3297, + "vede": 3298, + "lep": 3299, + "jek": 3300, + "prav": 3301, + "politi": 3302, + "dne": 3303, + "čení": 3304, + "než": 3305, + "děl": 3306, + "čo": 3307, + "cích": 3308, + "sté": 3309, + "dlou": 3310, + "několik": 3311, + "vyu": 3312, + "ckých": 3313, + "nové": 3314, + "čin": 3315, + "dělá": 3316, + "ký": 3317, + "obla": 3318, + "podle": 3319, + "důleži": 3320, + "poku": 3321, + "kone": 3322, + "dý": 3323, + "dvě": 3324, + "žád": 3325, + "nout": 3326, + "tku": 3327, + "tvr": 3328, + "ckého": 3329, + "rov": 3330, + "tele": 3331, + "psa": 3332, + "svět": 3333, + "tivní": 3334, + "dosta": 3335, + "šel": 3336, + "druhé": 3337, + "skou": 3338, + "žo": 3339, + "jedná": 3340, + "význam": 3341, + "problé": 3342, + "publi": 3343, + "ván": 3344, + "odpo": 3345, + "podpo": 3346, + "dle": 3347, + "jaké": 3348, + "šení": 3349, + "vím": 3350, + "během": 3351, + "nachází": 3352, + "slou": 3353, + "pouze": 3354, + "otá": 3355, + "plo": 3356, + "tové": 3357, + "větši": 3358, + "komi": 3359, + "vají": 3360, + "tyto": 
3361, + "zápa": 3362, + "změ": 3363, + "moh": 3364, + "více": 3365, + "společ": 3366, + "auto": 3367, + "proti": 3368, + "dět": 3369, + "cháze": 3370, + "žel": 3371, + "«": 3372, + "»": 3373, + "а": 3374, + "б": 3375, + "в": 3376, + "г": 3377, + "д": 3378, + "е": 3379, + "ж": 3380, + "з": 3381, + "и": 3382, + "й": 3383, + "к": 3384, + "л": 3385, + "м": 3386, + "н": 3387, + "о": 3388, + "п": 3389, + "р": 3390, + "с": 3391, + "т": 3392, + "у": 3393, + "ф": 3394, + "х": 3395, + "ц": 3396, + "ч": 3397, + "ш": 3398, + "щ": 3399, + "ъ": 3400, + "ы": 3401, + "ь": 3402, + "э": 3403, + "ю": 3404, + "я": 3405, + "ё": 3406, + "‑": 3407, + "−": 3408, + "ст": 3409, + "ен": 3410, + "но": 3411, + "на": 3412, + "пр": 3413, + "то": 3414, + "по": 3415, + "ра": 3416, + "го": 3417, + "ко": 3418, + "не": 3419, + "во": 3420, + "ва": 3421, + "ет": 3422, + "ер": 3423, + "ни": 3424, + "ел": 3425, + "ит": 3426, + "ны": 3427, + "за": 3428, + "ро": 3429, + "ени": 3430, + "ка": 3431, + "ли": 3432, + "ем": 3433, + "да": 3434, + "об": 3435, + "ла": 3436, + "до": 3437, + "ся": 3438, + "ть": 3439, + "от": 3440, + "ло": 3441, + "ль": 3442, + "ед": 3443, + "со": 3444, + "ми": 3445, + "ре": 3446, + "мо": 3447, + "ци": 3448, + "про": 3449, + "та": 3450, + "это": 3451, + "ки": 3452, + "ру": 3453, + "при": 3454, + "ти": 3455, + "се": 3456, + "ста": 3457, + "вы": 3458, + "мы": 3459, + "ви": 3460, + "бы": 3461, + "ма": 3462, + "ес": 3463, + "ля": 3464, + "сти": 3465, + "ле": 3466, + "что": 3467, + "ме": 3468, + "ри": 3469, + "ча": 3470, + "од": 3471, + "ей": 3472, + "ель": 3473, + "ения": 3474, + "га": 3475, + "ну": 3476, + "си": 3477, + "па": 3478, + "раз": 3479, + "бо": 3480, + "сто": 3481, + "су": 3482, + "са": 3483, + "ду": 3484, + "его": 3485, + "ест": 3486, + "ин": 3487, + "ить": 3488, + "из": 3489, + "же": 3490, + "му": 3491, + "пер": 3492, + "под": 3493, + "ение": 3494, + "сь": 3495, + "ку": 3496, + "пред": 3497, + "ного": 3498, + "ных": 3499, + "вер": 3500, + "те": 3501, + "ной": 3502, + "ции": 3503, + "де": 3504, + "ры": 3505, + "дел": 3506, + "лю": 3507, + "ве": 3508, + "он": 3509, + "мен": 3510, + "ги": 3511, + "ня": 3512, + "бу": 3513, + "пра": 3514, + "все": 3515, + "ется": 3516, + "сть": 3517, + "жа": 3518, + "дол": 3519, + "жи": 3520, + "бе": 3521, + "кон": 3522, + "сл": 3523, + "ши": 3524, + "ди": 3525, + "ств": 3526, + "ско": 3527, + "ные": 3528, + "чи": 3529, + "ют": 3530, + "дер": 3531, + "стра": 3532, + "ты": 3533, + "ход": 3534, + "щи": 3535, + "зо": 3536, + "зна": 3537, + "ности": 3538, + "чес": 3539, + "вля": 3540, + "вать": 3541, + "ор": 3542, + "пол": 3543, + "вет": 3544, + "так": 3545, + "ша": 3546, + "ту": 3547, + "сво": 3548, + "пре": 3549, + "она": 3550, + "итель": 3551, + "ный": 3552, + "сло": 3553, + "как": 3554, + "вл": 3555, + "ность": 3556, + "хо": 3557, + "мож": 3558, + "пе": 3559, + "для": 3560, + "ния": 3561, + "ное": 3562, + "рас": 3563, + "долж": 3564, + "дар": 3565, + "тель": 3566, + "ска": 3567, + "пу": 3568, + "ство": 3569, + "кото": 3570, + "раб": 3571, + "ее": 3572, + "род": 3573, + "эти": 3574, + "соб": 3575, + "ору": 3576, + "жен": 3577, + "ным": 3578, + "ити": 3579, + "ние": 3580, + "ком": 3581, + "дет": 3582, + "сту": 3583, + "гу": 3584, + "пи": 3585, + "меж": 3586, + "ению": 3587, + "тер": 3588, + "работ": 3589, + "воз": 3590, + "ция": 3591, + "кой": 3592, + "щест": 3593, + "гра": 3594, + "зи": 3595, + "ря": 3596, + "между": 3597, + "ства": 3598, + "вс": 3599, + "ело": 3600, + "ше": 3601, + "мер": 3602, + "ба": 3603, + "зы": 3604, + "лу": 3605, + "аль": 3606, + "дей": 3607, + 
"гла": 3608, + "народ": 3609, + "кти": 3610, + "предста": 3611, + "лся": 3612, + "явля": 3613, + "ски": 3614, + "нов": 3615, + "един": 3616, + "ров": 3617, + "ис": 3618, + "нима": 3619, + "рем": 3620, + "ходи": 3621, + "также": 3622, + "дру": 3623, + "ать": 3624, + "след": 3625, + "гово": 3626, + "ная": 3627, + "ющи": 3628, + "ень": 3629, + "которы": 3630, + "хот": 3631, + "ву": 3632, + "их": 3633, + "ему": 3634, + "чит": 3635, + "важ": 3636, + "орга": 3637, + "чески": 3638, + "ще": 3639, + "ке": 3640, + "ха": 3641, + "пос": 3642, + "том": 3643, + "боль": 3644, + "мне": 3645, + "пас": 3646, + "объ": 3647, + "прав": 3648, + "конф": 3649, + "слу": 3650, + "поддер": 3651, + "стви": 3652, + "наш": 3653, + "лько": 3654, + "стоя": 3655, + "ную": 3656, + "лем": 3657, + "енных": 3658, + "кра": 3659, + "ды": 3660, + "международ": 3661, + "гда": 3662, + "необ": 3663, + "госу": 3664, + "ству": 3665, + "ении": 3666, + "государ": 3667, + "кто": 3668, + "им": 3669, + "чест": 3670, + "рет": 3671, + "вопро": 3672, + "лен": 3673, + "ели": 3674, + "рова": 3675, + "ций": 3676, + "нам": 3677, + "этой": 3678, + "жения": 3679, + "необходи": 3680, + "меня": 3681, + "было": 3682, + "сили": 3683, + "фи": 3684, + "вя": 3685, + "шь": 3686, + "этого": 3687, + "они": 3688, + "органи": 3689, + "безо": 3690, + "проб": 3691, + "име": 3692, + "реш": 3693, + "би": 3694, + "безопас": 3695, + "ются": 3696, + "оста": 3697, + "енно": 3698, + "год": 3699, + "ела": 3700, + "представ": 3701, + "ться": 3702, + "слово": 3703, + "организа": 3704, + "должны": 3705, + "этом": 3706, + "бла": 3707, + "че": 3708, + "чу": 3709, + "благо": 3710, + "этому": 3711, + "врем": 3712, + "спе": 3713, + "ном": 3714, + "ений": 3715, + "спо": 3716, + "нас": 3717, + "нет": 3718, + "зу": 3719, + "вед": 3720, + "еще": 3721, + "сказа": 3722, + "сей": 3723, + "ерен": 3724, + "дан": 3725, + "сам": 3726, + "еля": 3727, + "ран": 3728, + "зыва": 3729, + "является": 3730, + "будет": 3731, + "ктив": 3732, + "тре": 3733, + "деле": 3734, + "мот": 3735, + "конферен": 3736, + "лась": 3737, + "час": 3738, + "сторо": 3739, + "кого": 3740, + "ез": 3741, + "ней": 3742, + "ос": 3743, + "лись": 3744, + "разору": 3745, + "пере": 3746, + "сси": 3747, + "ными": 3748, + "проц": 3749, + "голо": 3750, + "чело": 3751, + "боле": 3752, + "челове": 3753, + "сер": 3754, + "пл": 3755, + "чет": 3756, + "стран": 3757, + "пя": 3758, + "был": 3759, + "кла": 3760, + "тов": 3761, + "жд": 3762, + "дела": 3763, + "ера": 3764, + "уже": 3765, + "совет": 3766, + "ген": 3767, + "безопасности": 3768, + "ца": 3769, + "седа": 3770, + "поз": 3771, + "ответ": 3772, + "проблем": 3773, + "нако": 3774, + "тем": 3775, + "доста": 3776, + "пы": 3777, + "ща": 3778, + "вой": 3779, + "сущест": 3780, + "необходимо": 3781, + "быть": 3782, + "может": 3783, + "дем": 3784, + "чтобы": 3785, + "ек": 3786, + "чер": 3787, + "усили": 3788, + "рес": 3789, + "руд": 3790, + "единенных": 3791, + "доб": 3792, + "дости": 3793, + "ствен": 3794, + "ядер": 3795, + "годня": 3796, + "каза": 3797, + "сегодня": 3798, + "сейчас": 3799, + "только": 3800, + "вод": 3801, + "есь": 3802, + "много": 3803, + "буду": 3804, + "ев": 3805, + "есть": 3806, + "три": 3807, + "общест": 3808, + "явл": 3809, + "высту": 3810, + "ред": 3811, + "счит": 3812, + "сит": 3813, + "делега": 3814, + "лож": 3815, + "этот": 3816, + "фор": 3817, + "клю": 3818, + "возмож": 3819, + "вания": 3820, + "бли": 3821, + "или": 3822, + "вз": 3823, + "наций": 3824, + "ского": 3825, + "приня": 3826, + "пла": 3827, + "оч": 3828, + "иться": 3829, + "сте": 3830, + "наши": 
3831, + "которые": 3832, + "ар": 3833, + "имеет": 3834, + "сот": 3835, + "знач": 3836, + "перь": 3837, + "следу": 3838, + "ены": 3839, + "таки": 3840, + "объединенных": 3841, + "стро": 3842, + "теперь": 3843, + "бле": 3844, + "благодар": 3845, + "разв": 3846, + "ан": 3847, + "жива": 3848, + "очень": 3849, + "ят": 3850, + "без": 3851, + "обес": 3852, + "гро": 3853, + "лось": 3854, + "сы": 3855, + "организации": 3856, + "член": 3857, + "того": 3858, + "ональ": 3859, + "жда": 3860, + "всех": 3861, + "свя": 3862, + "более": 3863, + "сов": 3864, + "когда": 3865, + "вот": 3866, + "кре": 3867, + "кры": 3868, + "поэтому": 3869, + "воль": 3870, + "ой": 3871, + "генера": 3872, + "чем": 3873, + "лы": 3874, + "полити": 3875, + "вен": 3876, + "конференции": 3877, + "процес": 3878, + "бя": 3879, + "ите": 3880, + "отно": 3881, + "развити": 3882, + "аф": 3883, + "ющ": 3884, + "вно": 3885, + "мир": 3886, + "нии": 3887, + "кая": 3888, + "ас": 3889, + "ительно": 3890, + "вто": 3891, + "ением": 3892, + "генераль": 3893, + "прот": 3894, + "всем": 3895, + "самбле": 3896, + "ассамбле": 3897, + "ом": 3898, + "зд": 3899, + "смот": 3900, + "реги": 3901, + "чего": 3902, + "однако": 3903, + "усилия": 3904, + "действи": 3905, + "чно": 3906, + "уча": 3907, + "образ": 3908, + "вос": 3909, + "эта": 3910, + "перего": 3911, + "говор": 3912, + "вам": 3913, + "моло": 3914, + "время": 3915, + "дь": 3916, + "хотел": 3917, + "гру": 3918, + "заявл": 3919, + "предоста": 3920, + "поль": 3921, + "нее": 3922, + "резо": 3923, + "перегово": 3924, + "резолю": 3925, + "крет": 3926, + "поддерж": 3927, + "обеспе": 3928, + "него": 3929, + "представит": 3930, + "наде": 3931, + "кри": 3932, + "чь": 3933, + "проек": 3934, + "лет": 3935, + "други": 3936, + "_": 3937, + "،": 3938, + "؛": 3939, + "؟": 3940, + "ء": 3941, + "آ": 3942, + "أ": 3943, + "ؤ": 3944, + "إ": 3945, + "ئ": 3946, + "ا": 3947, + "ب": 3948, + "ة": 3949, + "ت": 3950, + "ث": 3951, + "ج": 3952, + "ح": 3953, + "خ": 3954, + "د": 3955, + "ذ": 3956, + "ر": 3957, + "ز": 3958, + "س": 3959, + "ش": 3960, + "ص": 3961, + "ض": 3962, + "ط": 3963, + "ظ": 3964, + "ع": 3965, + "غ": 3966, + "ـ": 3967, + "ف": 3968, + "ق": 3969, + "ك": 3970, + "ل": 3971, + "م": 3972, + "ن": 3973, + "ه": 3974, + "و": 3975, + "ى": 3976, + "ي": 3977, + "ً": 3978, + "ٌ": 3979, + "ٍ": 3980, + "َ": 3981, + "ُ": 3982, + "ِ": 3983, + "ّ": 3984, + "ْ": 3985, + "ٰ": 3986, + "چ": 3987, + "ڨ": 3988, + "ک": 3989, + "ھ": 3990, + "ی": 3991, + "ۖ": 3992, + "ۗ": 3993, + "ۘ": 3994, + "ۚ": 3995, + "ۛ": 3996, + "—": 3997, + "☭": 3998, + "ﺃ": 3999, + "ﻻ": 4000, + "ال": 4001, + "َا": 4002, + "وَ": 4003, + "َّ": 4004, + "ِي": 4005, + "أَ": 4006, + "لَ": 4007, + "نَ": 4008, + "الْ": 4009, + "هُ": 4010, + "ُو": 4011, + "ما": 4012, + "نْ": 4013, + "من": 4014, + "عَ": 4015, + "نا": 4016, + "لا": 4017, + "مَ": 4018, + "تَ": 4019, + "فَ": 4020, + "أن": 4021, + "لي": 4022, + "مِ": 4023, + "ان": 4024, + "في": 4025, + "رَ": 4026, + "يَ": 4027, + "هِ": 4028, + "مْ": 4029, + "قَ": 4030, + "بِ": 4031, + "لى": 4032, + "ين": 4033, + "إِ": 4034, + "لِ": 4035, + "وا": 4036, + "كَ": 4037, + "ها": 4038, + "ًا": 4039, + "مُ": 4040, + "ون": 4041, + "الم": 4042, + "بَ": 4043, + "يا": 4044, + "ذا": 4045, + "سا": 4046, + "الل": 4047, + "مي": 4048, + "يْ": 4049, + "را": 4050, + "ري": 4051, + "لك": 4052, + "مَا": 4053, + "نَّ": 4054, + "لم": 4055, + "إن": 4056, + "ست": 4057, + "وم": 4058, + "َّا": 4059, + "لَا": 4060, + "هم": 4061, + "ِّ": 4062, + "كُ": 4063, + "كان": 4064, + "سَ": 4065, + "با": 4066, + "دي": 4067, + "حَ": 4068, + "عْ": 4069, + "بي": 4070, + 
"الأ": 4071, + "ول": 4072, + "فِي": 4073, + "رِ": 4074, + "دا": 4075, + "مِنْ": 4076, + "ُونَ": 4077, + "وْ": 4078, + "هَا": 4079, + "ُّ": 4080, + "الس": 4081, + "الَ": 4082, + "ني": 4083, + "لْ": 4084, + "تُ": 4085, + "هل": 4086, + "رة": 4087, + "دَ": 4088, + "سْ": 4089, + "تِ": 4090, + "نَا": 4091, + "رْ": 4092, + "اللَّ": 4093, + "سامي": 4094, + "كن": 4095, + "كل": 4096, + "هَ": 4097, + "عَلَ": 4098, + "على": 4099, + "مع": 4100, + "إلى": 4101, + "قد": 4102, + "الر": 4103, + "ُوا": 4104, + "ير": 4105, + "عن": 4106, + "يُ": 4107, + "نِ": 4108, + "بْ": 4109, + "الح": 4110, + "هُمْ": 4111, + "قا": 4112, + "ذه": 4113, + "الت": 4114, + "ِينَ": 4115, + "جَ": 4116, + "هذا": 4117, + "عد": 4118, + "الع": 4119, + "دْ": 4120, + "قَالَ": 4121, + "رُ": 4122, + "يم": 4123, + "ية": 4124, + "نُ": 4125, + "خَ": 4126, + "رب": 4127, + "الك": 4128, + "وَا": 4129, + "أنا": 4130, + "ةِ": 4131, + "الن": 4132, + "حد": 4133, + "عِ": 4134, + "تا": 4135, + "هو": 4136, + "فا": 4137, + "عا": 4138, + "الش": 4139, + "لُ": 4140, + "يت": 4141, + "ذَا": 4142, + "يع": 4143, + "الذ": 4144, + "حْ": 4145, + "الص": 4146, + "إِنَّ": 4147, + "جا": 4148, + "علي": 4149, + "كَا": 4150, + "بُ": 4151, + "تع": 4152, + "وق": 4153, + "مل": 4154, + "لَّ": 4155, + "يد": 4156, + "أخ": 4157, + "رف": 4158, + "تي": 4159, + "الِ": 4160, + "ّا": 4161, + "ذلك": 4162, + "أَنْ": 4163, + "سِ": 4164, + "توم": 4165, + "مر": 4166, + "مَنْ": 4167, + "بل": 4168, + "الق": 4169, + "الله": 4170, + "ِيَ": 4171, + "كم": 4172, + "ذَ": 4173, + "عل": 4174, + "حب": 4175, + "سي": 4176, + "عُ": 4177, + "الج": 4178, + "الد": 4179, + "شَ": 4180, + "تك": 4181, + "فْ": 4182, + "صَ": 4183, + "لل": 4184, + "دِ": 4185, + "بر": 4186, + "فِ": 4187, + "ته": 4188, + "أع": 4189, + "تْ": 4190, + "قْ": 4191, + "الْأَ": 4192, + "ئِ": 4193, + "عَنْ": 4194, + "ور": 4195, + "حا": 4196, + "الَّ": 4197, + "مت": 4198, + "فر": 4199, + "دُ": 4200, + "هنا": 4201, + "وَأَ": 4202, + "تب": 4203, + "ةُ": 4204, + "أي": 4205, + "سب": 4206, + "ريد": 4207, + "وج": 4208, + "كُمْ": 4209, + "حِ": 4210, + "كْ": 4211, + "در": 4212, + "َاء": 4213, + "هذه": 4214, + "الط": 4215, + "الْمُ": 4216, + "دة": 4217, + "قل": 4218, + "غَ": 4219, + "يوم": 4220, + "الَّذ": 4221, + "كر": 4222, + "تر": 4223, + "كِ": 4224, + "كي": 4225, + "عَلَى": 4226, + "رَب": 4227, + "عة": 4228, + "قُ": 4229, + "جْ": 4230, + "فض": 4231, + "لة": 4232, + "هْ": 4233, + "رَا": 4234, + "وَلَ": 4235, + "الْمَ": 4236, + "أَنَّ": 4237, + "يَا": 4238, + "أُ": 4239, + "شي": 4240, + "اللَّهُ": 4241, + "لَى": 4242, + "قِ": 4243, + "أت": 4244, + "عَلَيْ": 4245, + "اللَّهِ": 4246, + "الب": 4247, + "ضَ": 4248, + "ةً": 4249, + "قي": 4250, + "ار": 4251, + "بد": 4252, + "خْ": 4253, + "سْتَ": 4254, + "طَ": 4255, + "قَدْ": 4256, + "ذهب": 4257, + "أم": 4258, + "ماذا": 4259, + "وَإِ": 4260, + "ةٌ": 4261, + "ونَ": 4262, + "ليلى": 4263, + "ولا": 4264, + "حُ": 4265, + "هي": 4266, + "صل": 4267, + "الخ": 4268, + "ود": 4269, + "ليس": 4270, + "لدي": 4271, + "قال": 4272, + "كَانَ": 4273, + "مَّ": 4274, + "حي": 4275, + "تم": 4276, + "لن": 4277, + "وَلَا": 4278, + "بع": 4279, + "يمكن": 4280, + "سُ": 4281, + "ةَ": 4282, + "حت": 4283, + "رًا": 4284, + "كا": 4285, + "شا": 4286, + "هِمْ": 4287, + "لَهُ": 4288, + "زَ": 4289, + "داً": 4290, + "مس": 4291, + "كث": 4292, + "الْعَ": 4293, + "جِ": 4294, + "صْ": 4295, + "فَا": 4296, + "له": 4297, + "وي": 4298, + "عَا": 4299, + "هُوَ": 4300, + "بِي": 4301, + "بَا": 4302, + "أس": 4303, + "ثَ": 4304, + "لِي": 4305, + "رض": 4306, + "الرَّ": 4307, + "لِكَ": 4308, + "تَّ": 4309, + "فُ": 4310, + "قة": 4311, + "فعل": 4312, + 
"مِن": 4313, + "الآ": 4314, + "ثُ": 4315, + "سم": 4316, + "مَّا": 4317, + "بِهِ": 4318, + "تق": 4319, + "خر": 4320, + "لقد": 4321, + "خل": 4322, + "شر": 4323, + "أنت": 4324, + "لَّا": 4325, + "سن": 4326, + "السَّ": 4327, + "الذي": 4328, + "سَا": 4329, + "وما": 4330, + "زل": 4331, + "وب": 4332, + "أْ": 4333, + "إذا": 4334, + "رِي": 4335, + "حة": 4336, + "نِي": 4337, + "الْحَ": 4338, + "وَقَالَ": 4339, + "به": 4340, + "ةٍ": 4341, + "سأ": 4342, + "رٌ": 4343, + "بال": 4344, + "مة": 4345, + "شْ": 4346, + "وت": 4347, + "عند": 4348, + "فس": 4349, + "بَعْ": 4350, + "هر": 4351, + "قط": 4352, + "أح": 4353, + "إنه": 4354, + "وع": 4355, + "فت": 4356, + "غا": 4357, + "هناك": 4358, + "بت": 4359, + "مِنَ": 4360, + "سر": 4361, + "ذَلِكَ": 4362, + "رس": 4363, + "حدث": 4364, + "غْ": 4365, + "ِّي": 4366, + "الإ": 4367, + "وَيَ": 4368, + "جل": 4369, + "است": 4370, + "قِي": 4371, + "عب": 4372, + "وس": 4373, + "يش": 4374, + "الَّذِينَ": 4375, + "تاب": 4376, + "دِي": 4377, + "جب": 4378, + "كون": 4379, + "بن": 4380, + "الث": 4381, + "لَيْ": 4382, + "بعد": 4383, + "وَالْ": 4384, + "فَأَ": 4385, + "عم": 4386, + "هُم": 4387, + "تن": 4388, + "ذْ": 4389, + "أص": 4390, + "أين": 4391, + "رَبِّ": 4392, + "الذين": 4393, + "إِن": 4394, + "بين": 4395, + "جُ": 4396, + "عَلَيْهِ": 4397, + "حَا": 4398, + "لو": 4399, + "ستط": 4400, + "ظر": 4401, + "لَمْ": 4402, + "ءِ": 4403, + "كُل": 4404, + "طل": 4405, + "تَا": 4406, + "ضُ": 4407, + "كنت": 4408, + "لًا": 4409, + "مٌ": 4410, + "قبل": 4411, + "ــ": 4412, + "ذِ": 4413, + "قَوْ": 4414, + "صِ": 4415, + "مًا": 4416, + "كانت": 4417, + "صا": 4418, + "يق": 4419, + "الف": 4420, + "النا": 4421, + "مٍ": 4422, + "إِنْ": 4423, + "النَّ": 4424, + "جد": 4425, + "وَمَا": 4426, + "تت": 4427, + "بح": 4428, + "مكان": 4429, + "كيف": 4430, + "ّة": 4431, + "الا": 4432, + "جَا": 4433, + "أو": 4434, + "ساعد": 4435, + "ضِ": 4436, + "إلا": 4437, + "راً": 4438, + "قَا": 4439, + "رأ": 4440, + "عت": 4441, + "أحد": 4442, + "هد": 4443, + "ضا": 4444, + "طر": 4445, + "أق": 4446, + "ماء": 4447, + "دَّ": 4448, + "البا": 4449, + "مُو": 4450, + "أَوْ": 4451, + "طا": 4452, + "قُو": 4453, + "خِ": 4454, + "تل": 4455, + "ستطيع": 4456, + "دَا": 4457, + "النَّا": 4458, + "إلَى": 4459, + "وَتَ": 4460, + "هَذَا": 4461, + "بة": 4462, + "عليك": 4463, + "جر": 4464, + "المن": 4465, + "زا": 4466, + "رٍ": 4467, + "دع": 4468, + "ًّا": 4469, + "سة": 4470, + "ثُمَّ": 4471, + "شيء": 4472, + "الغ": 4473, + "تح": 4474, + "رُونَ": 4475, + "اليوم": 4476, + "مِي": 4477, + "نُوا": 4478, + "أر": 4479, + "تُمْ": 4480, + "عر": 4481, + "يف": 4482, + "أب": 4483, + "دًا": 4484, + "صَا": 4485, + "التَّ": 4486, + "أريد": 4487, + "الز": 4488, + "يَوْ": 4489, + "إلي": 4490, + "جي": 4491, + "يَعْ": 4492, + "فضل": 4493, + "الإن": 4494, + "أنه": 4495, + "1": 4496, + "2": 4497, + "3": 4498, + "4": 4499, + "5": 4500, + "·": 4501, + "×": 4502, + "̃": 4503, + "̌": 4504, + "ε": 4505, + "λ": 4506, + "μ": 4507, + "•": 4508, + "‧": 4509, + "─": 4510, + "□": 4511, + "、": 4512, + "。": 4513, + "〈": 4514, + "〉": 4515, + "《": 4516, + "》": 4517, + "「": 4518, + "」": 4519, + "『": 4520, + "』": 4521, + "ア": 4522, + "オ": 4523, + "カ": 4524, + "チ": 4525, + "ド": 4526, + "ベ": 4527, + "ャ": 4528, + "ヤ": 4529, + "ン": 4530, + "・": 4531, + "ー": 4532, + "ㄟ": 4533, + "!": 4534, + "(": 4535, + ")": 4536, + ",": 4537, + "-": 4538, + "/": 4539, + ":": 4540, + ";": 4541, + "?": 4542, + "p": 4543, + "i4": 4544, + "zh": 4545, + "i2": 4546, + "ng1": 4547, + "u4": 4548, + "i1": 4549, + "ng2": 4550, + "u3": 4551, + "de5": 4552, + "e4": 4553, + "i3": 4554, + "ng4": 4555, + "an4": 4556, + 
"shi4": 4557, + "an2": 4558, + "u2": 4559, + "u1": 4560, + "ng3": 4561, + "a1": 4562, + "an1": 4563, + "e2": 4564, + "a4": 4565, + "ei4": 4566, + "ong1": 4567, + "ai4": 4568, + "ao4": 4569, + "ang1": 4570, + "an3": 4571, + "wei4": 4572, + "uo2": 4573, + "n1": 4574, + "en2": 4575, + "ao3": 4576, + "e1": 4577, + "qi": 4578, + "eng2": 4579, + "zho": 4580, + "ang3": 4581, + "ang4": 4582, + "ang2": 4583, + "uo4": 4584, + "ge4": 4585, + "yi1": 4586, + "guo2": 4587, + "a3": 4588, + "he2": 4589, + "e3": 4590, + "yi2": 4591, + "di4": 4592, + "zhong1": 4593, + "bu4": 4594, + "ai2": 4595, + "n2": 4596, + "zai4": 4597, + "shi2": 4598, + "eng1": 4599, + "ren2": 4600, + "ong2": 4601, + "xian4": 4602, + "xu": 4603, + "n4": 4604, + "li4": 4605, + "en4": 4606, + "yu2": 4607, + "ei2": 4608, + "yi2ge4": 4609, + "ou4": 4610, + "ei3": 4611, + "ui4": 4612, + "a2": 4613, + "you3": 4614, + "ao1": 4615, + "da4": 4616, + "cheng2": 4617, + "en1": 4618, + "eng4": 4619, + "yi4": 4620, + "si1": 4621, + "zhi4": 4622, + "jia1": 4623, + "yuan2": 4624, + "ta1": 4625, + "de5yi2ge4": 4626, + "ke1": 4627, + "shu3": 4628, + "xi1": 4629, + "ji2": 4630, + "ao2": 4631, + "ou3": 4632, + "ong4": 4633, + "xia4": 4634, + "ai1": 4635, + "gong1": 4636, + "zhi1": 4637, + "en3": 4638, + "wei2": 4639, + "xue2": 4640, + "qu1": 4641, + "zhou1": 4642, + "er3": 4643, + "ming2": 4644, + "zhong3": 4645, + "li3": 4646, + "wu4": 4647, + "yi3": 4648, + "uo1": 4649, + "e5": 4650, + "ji4": 4651, + "xing2": 4652, + "jian4": 4653, + "hua4": 4654, + "yu3": 4655, + "uo3": 4656, + "ji1": 4657, + "ai3": 4658, + "zuo4": 4659, + "hou4": 4660, + "hui4": 4661, + "ei1": 4662, + "nian2": 4663, + "qi2": 4664, + "dao4": 4665, + "sheng1": 4666, + "de2": 4667, + "dai4": 4668, + "uan2": 4669, + "zhe4": 4670, + "zheng4": 4671, + "ben3": 4672, + "shang4": 4673, + "zhu3": 4674, + "bei4": 4675, + "ye4": 4676, + "chu1": 4677, + "zhan4": 4678, + "le5": 4679, + "lai2": 4680, + "shi3": 4681, + "nan2": 4682, + "ren4": 4683, + "you2": 4684, + "ke4": 4685, + "ba1": 4686, + "fu4": 4687, + "dui4": 4688, + "ya4": 4689, + "mei3": 4690, + "zi4": 4691, + "xin1": 4692, + "jing1": 4693, + "zhu": 4694, + "n3": 4695, + "yong4": 4696, + "mu4": 4697, + "jiao4": 4698, + "ye3": 4699, + "jin4": 4700, + "bian4": 4701, + "lu4": 4702, + "qi1": 4703, + "she4": 4704, + "xiang1": 4705, + "ong3": 4706, + "shu4": 4707, + "dong4": 4708, + "suo3": 4709, + "guan1": 4710, + "san1": 4711, + "te4": 4712, + "duo1": 4713, + "fu2": 4714, + "min2": 4715, + "la1": 4716, + "zhi2": 4717, + "zhen4": 4718, + "ou1": 4719, + "wu3": 4720, + "ma3": 4721, + "i5": 4722, + "zi5": 4723, + "ju4": 4724, + "er4": 4725, + "yao4": 4726, + "xia4de5yi2ge4": 4727, + "si4": 4728, + "tu2": 4729, + "shan1": 4730, + "zui4": 4731, + "yin1": 4732, + "er2": 4733, + "tong2": 4734, + "dong1": 4735, + "yu4": 4736, + "yan2": 4737, + "qian2": 4738, + "shu3xia4de5yi2ge4": 4739, + "jun1": 4740, + "ke3": 4741, + "wen2": 4742, + "fa3": 4743, + "luo2": 4744, + "zhu4": 4745, + "xi4": 4746, + "kou3": 4747, + "bei3": 4748, + "jian1": 4749, + "fa1": 4750, + "dian4": 4751, + "jiang1": 4752, + "wei4yu2": 4753, + "xiang4": 4754, + "zhi3": 4755, + "eng3": 4756, + "fang1": 4757, + "lan2": 4758, + "shu": 4759, + "ri4": 4760, + "lian2": 4761, + "shou3": 4762, + "qiu2": 4763, + "jin1": 4764, + "huo4": 4765, + "shu3xia4de5yi2ge4zhong3": 4766, + "fen1": 4767, + "nei4": 4768, + "gai1": 4769, + "mei3guo2": 4770, + "un2": 4771, + "ge2": 4772, + "bao3": 4773, + "qing1": 4774, + "gao1": 4775, + "tai2": 4776, + "xiao3": 4777, + "jie2": 4778, + "tian1": 4779, + 
"chang2": 4780, + "quan2": 4781, + "lie4": 4782, + "hai3": 4783, + "fei1": 4784, + "ti3": 4785, + "jue2": 4786, + "ou2": 4787, + "ci3": 4788, + "zu2": 4789, + "ni2": 4790, + "biao3": 4791, + "zhong1guo2": 4792, + "du4": 4793, + "yue4": 4794, + "xing4": 4795, + "sheng4": 4796, + "che1": 4797, + "dan1": 4798, + "jie1": 4799, + "lin2": 4800, + "ping2": 4801, + "fu3": 4802, + "gu3": 4803, + "jie4": 4804, + "v3": 4805, + "sheng3": 4806, + "na4": 4807, + "yuan4": 4808, + "zhang3": 4809, + "guan3": 4810, + "dao3": 4811, + "zu3": 4812, + "ding4": 4813, + "dian3": 4814, + "ceng2": 4815, + "ren2kou3": 4816, + "tai4": 4817, + "tong1": 4818, + "guo4": 4819, + "neng2": 4820, + "chang3": 4821, + "hua2": 4822, + "liu2": 4823, + "ying1": 4824, + "xiao4": 4825, + "ci4": 4826, + "bian4hua4": 4827, + "liang3": 4828, + "gong4": 4829, + "zhong4": 4830, + "de5yi1": 4831, + "se4": 4832, + "kai1": 4833, + "wang2": 4834, + "jiu4": 4835, + "shi1": 4836, + "shou4": 4837, + "mei2": 4838, + "feng1": 4839, + "ze2": 4840, + "tu2shi4": 4841, + "ti2": 4842, + "qi4": 4843, + "jiu3": 4844, + "shen1": 4845, + "zhe3": 4846, + "ren2kou3bian4hua4": 4847, + "ren2kou3bian4hua4tu2shi4": 4848, + "di4qu1": 4849, + "yang2": 4850, + "men5": 4851, + "long2": 4852, + "bing4": 4853, + "chan3": 4854, + "zhu1": 4855, + "wei3": 4856, + "wai4": 4857, + "xing1": 4858, + "bo1": 4859, + "bi3": 4860, + "tang2": 4861, + "hua1": 4862, + "bo2": 4863, + "shui3": 4864, + "shu1": 4865, + "dou1": 4866, + "sai4": 4867, + "chao2": 4868, + "bi4": 4869, + "ling2": 4870, + "lei4": 4871, + "da4xue2": 4872, + "fen4": 4873, + "shu3de5": 4874, + "mu3": 4875, + "jiao1": 4876, + "dang1": 4877, + "cheng1": 4878, + "tong3": 4879, + "nv3": 4880, + "qi3": 4881, + "yan3": 4882, + "mian4": 4883, + "luo4": 4884, + "jing4": 4885, + "ge1": 4886, + "ru4": 4887, + "dan4": 4888, + "ri4ben3": 4889, + "pu3": 4890, + "yun4": 4891, + "huang2": 4892, + "wo3": 4893, + "lv": 4894, + "hai2": 4895, + "shi4yi1": 4896, + "xie1": 4897, + "ying3": 4898, + "wu2": 4899, + "shen2": 4900, + "wang3": 4901, + "guang3": 4902, + "liu4": 4903, + "su4": 4904, + "shi4zhen4": 4905, + "can1": 4906, + "cao3": 4907, + "xia2": 4908, + "ka3": 4909, + "da2": 4910, + "hu4": 4911, + "ban4": 4912, + "dang3": 4913, + "hu2": 4914, + "zong3": 4915, + "deng3": 4916, + "de5yi2ge4shi4zhen4": 4917, + "chuan2": 4918, + "mo4": 4919, + "zhang1": 4920, + "ban1": 4921, + "mo2": 4922, + "cha2": 4923, + "ce4": 4924, + "zhu3yao4": 4925, + "tou2": 4926, + "ju2": 4927, + "shi4wei4yu2": 4928, + "sa4": 4929, + "un1": 4930, + "ke3yi3": 4931, + "du1": 4932, + "han4": 4933, + "liang4": 4934, + "sha1": 4935, + "jia3": 4936, + "zi1": 4937, + "lv4": 4938, + "fu1": 4939, + "xian1": 4940, + "xu4": 4941, + "guang1": 4942, + "meng2": 4943, + "bao4": 4944, + "you4": 4945, + "rong2": 4946, + "zhi1yi1": 4947, + "wei1": 4948, + "mao2": 4949, + "guo2jia1": 4950, + "cong2": 4951, + "gou4": 4952, + "tie3": 4953, + "zhen1": 4954, + "du2": 4955, + "bian1": 4956, + "ci2": 4957, + "qu3": 4958, + "fan4": 4959, + "xiang3": 4960, + "men2": 4961, + "ju1": 4962, + "hong2": 4963, + "zi3": 4964, + "ta1men5": 4965, + "ji3": 4966, + "zong1": 4967, + "zhou1de5yi2ge4shi4zhen4": 4968, + "tuan2": 4969, + "jing3": 4970, + "gong1si1": 4971, + "xie4": 4972, + "li2": 4973, + "li4shi3": 4974, + "bao1": 4975, + "gang3": 4976, + "gui1": 4977, + "zheng1": 4978, + "zhi2wu4": 4979, + "ta1de5": 4980, + "pin3": 4981, + "zhuan1": 4982, + "chong2": 4983, + "shi3yong4": 4984, + "wa3": 4985, + "shuo1": 4986, + "chuan1": 4987, + "lei2": 4988, + "wan1": 4989, + "huo2": 4990, + 
"su1": 4991, + "zao3": 4992, + "gai3": 4993, + "qu4": 4994, + "gu4": 4995, + "xi2": 4996, + "hang2": 4997, + "ying4": 4998, + "cun1": 4999, + "gen1": 5000, + "ying2": 5001, + "ting2": 5002, + "cheng2shi4": 5003, + "jiang3": 5004, + "ling3": 5005, + "lun2": 5006, + "bu4fen4": 5007, + "deng1": 5008, + "xuan3": 5009, + "dong4wu4": 5010, + "de2guo2": 5011, + "xian3": 5012, + "fan3": 5013, + "zhe5": 5014, + "han2": 5015, + "hao4": 5016, + "mi4": 5017, + "ran2": 5018, + "qin1": 5019, + "tiao2": 5020, + "zhan3": 5021, + "[ar]": 5022, + "[zh-cn]": 5023, + "¡": 5024, + "é": 5025, + "shi": 5026, + "tsu": 5027, + "teki": 5028, + "nai": 5029, + "aru": 5030, + "uu": 5031, + "kai": 5032, + "shite": 5033, + "mono": 5034, + "koto": 5035, + "kara": 5036, + "shita": 5037, + "suru": 5038, + "masu": 5039, + "tai": 5040, + "ware": 5041, + "shin": 5042, + "oku": 5043, + "yuu": 5044, + "iru": 5045, + "jiko": 5046, + "desu": 5047, + "rare": 5048, + "shou": 5049, + "sha": 5050, + "sekai": 5051, + "kyou": 5052, + "mashita": 5053, + "nara": 5054, + "kei": 5055, + "ita": 5056, + "ari": 5057, + "itsu": 5058, + "kono": 5059, + "naka": 5060, + "chou": 5061, + "sore": 5062, + "naru": 5063, + "gaku": 5064, + "reba": 5065, + "hito": 5066, + "sai": 5067, + "nan": 5068, + "dai": 5069, + "tsuku": 5070, + "shiki": 5071, + "sare": 5072, + "naku": 5073, + "jun": 5074, + "kaku": 5075, + "zai": 5076, + "wata": 5077, + "shuu": 5078, + "ii": 5079, + "kare": 5080, + "shii": 5081, + "made": 5082, + "sho": 5083, + "kereba": 5084, + "shika": 5085, + "ichi": 5086, + "deki": 5087, + "nin": 5088, + "wareware": 5089, + "nakereba": 5090, + "oite": 5091, + "yaku": 5092, + "mujun": 5093, + "yoku": 5094, + "butsu": 5095, + "omo": 5096, + "gae": 5097, + "naranai": 5098, + "tachi": 5099, + "chuu": 5100, + "kangae": 5101, + "toki": 5102, + "koro": 5103, + "mujunteki": 5104, + "naga": 5105, + "jin": 5106, + "shima": 5107, + "iku": 5108, + "imasu": 5109, + "hon": 5110, + "kae": 5111, + "kore": 5112, + "kita": 5113, + "datta": 5114, + "jitsu": 5115, + "mae": 5116, + "toku": 5117, + "douitsu": 5118, + "ritsu": 5119, + "kyuu": 5120, + "hyou": 5121, + "rareta": 5122, + "keisei": 5123, + "kkan": 5124, + "rareru": 5125, + "mou": 5126, + "doko": 5127, + "ryou": 5128, + "dake": 5129, + "nakatta": 5130, + "soko": 5131, + "tabe": 5132, + "hana": 5133, + "fuku": 5134, + "yasu": 5135, + "wataku": 5136, + "yama": 5137, + "kyo": 5138, + "genzai": 5139, + "boku": 5140, + "ata": 5141, + "kawa": 5142, + "masen": 5143, + "juu": 5144, + "natte": 5145, + "watakushi": 5146, + "yotte": 5147, + "hai": 5148, + "jishin": 5149, + "rete": 5150, + "oka": 5151, + "kagaku": 5152, + "natta": 5153, + "karu": 5154, + "nari": 5155, + "mata": 5156, + "kuru": 5157, + "gai": 5158, + "kari": 5159, + "shakai": 5160, + "koui": 5161, + "yori": 5162, + "setsu": 5163, + "reru": 5164, + "tokoro": 5165, + "jutsu": 5166, + "saku": 5167, + "ttai": 5168, + "ningen": 5169, + "tame": 5170, + "kankyou": 5171, + "ooku": 5172, + "watashi": 5173, + "tsukuru": 5174, + "sugi": 5175, + "jibun": 5176, + "shitsu": 5177, + "keru": 5178, + "kishi": 5179, + "shikashi": 5180, + "moto": 5181, + "mari": 5182, + "itte": 5183, + "deshita": 5184, + "nde": 5185, + "arimasu": 5186, + "koe": 5187, + "zettai": 5188, + "kkanteki": 5189, + "rekishi": 5190, + "dekiru": 5191, + "tsuka": 5192, + "itta": 5193, + "kobutsu": 5194, + "miru": 5195, + "shoku": 5196, + "shimasu": 5197, + "gijutsu": 5198, + "gyou": 5199, + "joushiki": 5200, + "atta": 5201, + "hodo": 5202, + "koko": 5203, + "tsukurareta": 5204, + "zoku": 5205, + 
"hitei": 5206, + "koku": 5207, + "rekishiteki": 5208, + "kete": 5209, + "kako": 5210, + "nagara": 5211, + "kakaru": 5212, + "shutai": 5213, + "haji": 5214, + "taku": 5215, + "douitsuteki": 5216, + "mete": 5217, + "tsuu": 5218, + "sarete": 5219, + "genjitsu": 5220, + "bai": 5221, + "nawa": 5222, + "jikan": 5223, + "waru": 5224, + "rt": 5225, + "atsu": 5226, + "soku": 5227, + "kouiteki": 5228, + "kata": 5229, + "tetsu": 5230, + "gawa": 5231, + "kedo": 5232, + "reta": 5233, + "sayou": 5234, + "tteru": 5235, + "tori": 5236, + "kimi": 5237, + "mura": 5238, + "sareru": 5239, + "machi": 5240, + "kya": 5241, + "osa": 5242, + "konna": 5243, + "aku": 5244, + "sareta": 5245, + "ipp": 5246, + "shiku": 5247, + "uchi": 5248, + "hitotsu": 5249, + "hatara": 5250, + "tachiba": 5251, + "shiro": 5252, + "katachi": 5253, + "tomo": 5254, + "ete": 5255, + "meru": 5256, + "nichi": 5257, + "dare": 5258, + "katta": 5259, + "eru": 5260, + "suki": 5261, + "ooki": 5262, + "maru": 5263, + "moku": 5264, + "oko": 5265, + "kangaerareru": 5266, + "oto": 5267, + "tanni": 5268, + "tada": 5269, + "taiteki": 5270, + "motte": 5271, + "kinou": 5272, + "shinai": 5273, + "kki": 5274, + "tari": 5275, + "ranai": 5276, + "kkou": 5277, + "mirai": 5278, + "ppon": 5279, + "goto": 5280, + "hitsu": 5281, + "teru": 5282, + "mochi": 5283, + "katsu": 5284, + "nyuu": 5285, + "zuka": 5286, + "tsuite": 5287, + "nomi": 5288, + "sugu": 5289, + "kuda": 5290, + "tetsugaku": 5291, + "ika": 5292, + "ronri": 5293, + "oki": 5294, + "nippon": 5295, + "shimashita": 5296, + "chishiki": 5297, + "chokkanteki": 5298, + "suko": 5299, + "kuu": 5300, + "arou": 5301, + "katte": 5302, + "kuri": 5303, + "inai": 5304, + "hyougen": 5305, + "ishiki": 5306, + "doku": 5307, + "atte": 5308, + "atara": 5309, + "wari": 5310, + "kao": 5311, + "seisan": 5312, + "hanashi": 5313, + "kake": 5314, + "naji": 5315, + "sunawa": 5316, + "sunawachi": 5317, + "ugo": 5318, + "suu": 5319, + "bara": 5320, + "hiro": 5321, + "iwa": 5322, + "betsu": 5323, + "yoi": 5324, + "seru": 5325, + "shiteru": 5326, + "rarete": 5327, + "toshi": 5328, + "seki": 5329, + "tairitsu": 5330, + "wakara": 5331, + "tokyo": 5332, + "kka": 5333, + "kyoku": 5334, + "iro": 5335, + "mite": 5336, + "saki": 5337, + "kanji": 5338, + "mita": 5339, + "sube": 5340, + "ryoku": 5341, + "matta": 5342, + "kudasai": 5343, + "omoi": 5344, + "wareru": 5345, + "hitsuyou": 5346, + "kashi": 5347, + "renai": 5348, + "kankei": 5349, + "gatte": 5350, + "ochi": 5351, + "motsu": 5352, + "sonzai": 5353, + "taishite": 5354, + "ame": 5355, + "seimei": 5356, + "kano": 5357, + "giri": 5358, + "kangaeru": 5359, + "yue": 5360, + "asa": 5361, + "onaji": 5362, + "yoru": 5363, + "niku": 5364, + "osaka": 5365, + "sukoshi": 5366, + "tama": 5367, + "kanojo": 5368, + "kite": 5369, + "mondai": 5370, + "amari": 5371, + "eki": 5372, + "kojin": 5373, + "haya": 5374, + "dete": 5375, + "atarashii": 5376, + "awa": 5377, + "gakkou": 5378, + "tsuzu": 5379, + "shukan": 5380, + "imashita": 5381, + "atae": 5382, + "darou": 5383, + "hataraku": 5384, + "gata": 5385, + "dachi": 5386, + "matsu": 5387, + "arimasen": 5388, + "seibutsu": 5389, + "mitsu": 5390, + "heya": 5391, + "yasui": 5392, + "deni": 5393, + "noko": 5394, + "haha": 5395, + "domo": 5396, + "kami": 5397, + "sudeni": 5398, + "nao": 5399, + "raku": 5400, + "ike": 5401, + "meta": 5402, + "kodomo": 5403, + "soshite": 5404, + "game": 5405, + "bakari": 5406, + "tote": 5407, + "hatsu": 5408, + "mise": 5409, + "mokuteki": 5410, + "dakara": 5411, + "[ja]": 5412, + "ő": 5413, + "ű": 5414, + "そ": 5415, + "な": 
5416, + "ん": 5417, + "포": 5418, + "�": 5419, + "gy": 5420, + "eg": 5421, + "cs": 5422, + "ál": 5423, + "egy": 5424, + "át": 5425, + "ott": 5426, + "ett": 5427, + "meg": 5428, + "hogy": 5429, + "ég": 5430, + "ól": 5431, + "nek": 5432, + "volt": 5433, + "ág": 5434, + "nk": 5435, + "ék": 5436, + "ít": 5437, + "ák": 5438, + "ud": 5439, + "szer": 5440, + "mind": 5441, + "oz": 5442, + "ép": 5443, + "ért": 5444, + "mond": 5445, + "szt": 5446, + "nak": 5447, + "ől": 5448, + "csak": 5449, + "oly": 5450, + "áll": 5451, + "ány": 5452, + "mint": 5453, + "már": 5454, + "ött": 5455, + "nagy": 5456, + "ész": 5457, + "azt": 5458, + "elő": 5459, + "tud": 5460, + "ény": 5461, + "áz": 5462, + "még": 5463, + "köz": 5464, + "ely": 5465, + "ség": 5466, + "hoz": 5467, + "uk": 5468, + "kez": 5469, + "ám": 5470, + "aj": 5471, + "unk": 5472, + "vagy": 5473, + "szem": 5474, + "ember": 5475, + "fog": 5476, + "mert": 5477, + "ös": 5478, + "ság": 5479, + "leg": 5480, + "ünk": 5481, + "hát": 5482, + "ony": 5483, + "ezt": 5484, + "minden": 5485, + "ült": 5486, + "jó": 5487, + "kis": 5488, + "áj": 5489, + "úgy": 5490, + "most": 5491, + "ír": 5492, + "itt": 5493, + "elt": 5494, + "mondta": 5495, + "kell": 5496, + "ált": 5497, + "érd": 5498, + "tö": 5499, + "vár": 5500, + "lát": 5501, + "ők": 5502, + "vet": 5503, + "után": 5504, + "két": 5505, + "nap": 5506, + "ív": 5507, + "ály": 5508, + "vég": 5509, + "ök": 5510, + "dul": 5511, + "néz": 5512, + "ában": 5513, + "kül": 5514, + "akkor": 5515, + "szél": 5516, + "új": 5517, + "olyan": 5518, + "ked": 5519, + "hely": 5520, + "tör": 5521, + "ból": 5522, + "elm": 5523, + "ára": 5524, + "ló": 5525, + "volna": 5526, + "lehet": 5527, + "ebb": 5528, + "sok": 5529, + "olt": 5530, + "eket": 5531, + "bor": 5532, + "fej": 5533, + "gond": 5534, + "akar": 5535, + "fél": 5536, + "úl": 5537, + "otta": 5538, + "valami": 5539, + "jel": 5540, + "éd": 5541, + "arc": 5542, + "hall": 5543, + "föl": 5544, + "ába": 5545, + "olg": 5546, + "kir": 5547, + "old": 5548, + "kérd": 5549, + "jár": 5550, + "úr": 5551, + "zs": 5552, + "élet": 5553, + "ját": 5554, + "ov": 5555, + "éz": 5556, + "vil": 5557, + "őr": 5558, + "ög": 5559, + "lesz": 5560, + "koz": 5561, + "ább": 5562, + "király": 5563, + "eng": 5564, + "igaz": 5565, + "haj": 5566, + "kod": 5567, + "ról": 5568, + "több": 5569, + "szó": 5570, + "ében": 5571, + "öt": 5572, + "nyi": 5573, + "szól": 5574, + "gondol": 5575, + "egész": 5576, + "így": 5577, + "ős": 5578, + "obb": 5579, + "osan": 5580, + "ből": 5581, + "abb": 5582, + "őt": 5583, + "nál": 5584, + "kép": 5585, + "aztán": 5586, + "tart": 5587, + "beszél": 5588, + "előtt": 5589, + "aszt": 5590, + "maj": 5591, + "kör": 5592, + "hang": 5593, + "íz": 5594, + "incs": 5595, + "év": 5596, + "ód": 5597, + "ók": 5598, + "hozz": 5599, + "okat": 5600, + "nagyon": 5601, + "ház": 5602, + "ped": 5603, + "ezte": 5604, + "etlen": 5605, + "neki": 5606, + "majd": 5607, + "szony": 5608, + "ának": 5609, + "felé": 5610, + "egyszer": 5611, + "adt": 5612, + "gyer": 5613, + "amikor": 5614, + "foly": 5615, + "szak": 5616, + "őd": 5617, + "hú": 5618, + "ász": 5619, + "amely": 5620, + "ére": 5621, + "ilyen": 5622, + "oda": 5623, + "ják": 5624, + "tár": 5625, + "ával": 5626, + "lak": 5627, + "gyan": 5628, + "ély": 5629, + "út": 5630, + "kezd": 5631, + "mell": 5632, + "mikor": 5633, + "hez": 5634, + "való": 5635, + "szeret": 5636, + "rend": 5637, + "vissza": 5638, + "fő": 5639, + "asszony": 5640, + "ről": 5641, + "pedig": 5642, + "szép": 5643, + "ták": 5644, + "öv": 5645, + "világ": 5646, + "maga": 5647, + "szik": 5648, + 
"éj": 5649, + "ént": 5650, + "jött": 5651, + "szí": 5652, + "gat": 5653, + "ettem": 5654, + "hány": 5655, + "ást": 5656, + "ahol": 5657, + "őket": 5658, + "hár": 5659, + "nő": 5660, + "csi": 5661, + "talál": 5662, + "elte": 5663, + "látt": 5664, + "tört": 5665, + "hagy": 5666, + "esz": 5667, + "nél": 5668, + "kut": 5669, + "lány": 5670, + "amit": 5671, + "ső": 5672, + "ellen": 5673, + "magát": 5674, + "ugyan": 5675, + "külön": 5676, + "asz": 5677, + "mindig": 5678, + "lép": 5679, + "talán": 5680, + "szor": 5681, + "illan": 5682, + "nincs": 5683, + "vagyok": 5684, + "telen": 5685, + "ismer": 5686, + "isten": 5687, + "ított": 5688, + "jobb": 5689, + "ves": 5690, + "dult": 5691, + "juk": 5692, + "szen": 5693, + "öm": 5694, + "lett": 5695, + "egyik": 5696, + "bár": 5697, + "szi": 5698, + "szív": 5699, + "azon": 5700, + "eszt": 5701, + "föld": 5702, + "kuty": 5703, + "pillan": 5704, + "fér": 5705, + "től": 5706, + "tű": 5707, + "ébe": 5708, + "tött": 5709, + "barát": 5710, + "íg": 5711, + "ahogy": 5712, + "eh": 5713, + "ep": 5714, + "jelent": 5715, + "tat": 5716, + "szeg": 5717, + "mintha": 5718, + "egyen": 5719, + "szab": 5720, + "bizony": 5721, + "jon": 5722, + "öreg": 5723, + "dolg": 5724, + "csap": 5725, + "tiszt": 5726, + "állt": 5727, + "ancs": 5728, + "idő": 5729, + "ügy": 5730, + "miért": 5731, + "ót": 5732, + "csin": 5733, + "ének": 5734, + "vér": 5735, + "jól": 5736, + "alatt": 5737, + "mely": 5738, + "semmi": 5739, + "nyug": 5740, + "vág": 5741, + "követ": 5742, + "össze": 5743, + "mad": 5744, + "acs": 5745, + "fiú": 5746, + "másik": 5747, + "jön": 5748, + "szám": 5749, + "rész": 5750, + "kér": 5751, + "ével": 5752, + "[hu]": 5753, + "%": 5754, + "0": 5755, + "6": 5756, + "7": 5757, + "8": 5758, + "9": 5759, + "A": 5760, + "B": 5761, + "C": 5762, + "D": 5763, + "E": 5764, + "F": 5765, + "G": 5766, + "H": 5767, + "I": 5768, + "J": 5769, + "K": 5770, + "L": 5771, + "M": 5772, + "N": 5773, + "O": 5774, + "P": 5775, + "Q": 5776, + "R": 5777, + "S": 5778, + "T": 5779, + "U": 5780, + "V": 5781, + "W": 5782, + "X": 5783, + "Y": 5784, + "Z": 5785, + "Ł": 5786, + "α": 5787, + "ς": 5788, + "♥": 5789, + "か": 5790, + "ズ": 5791, + "因": 5792, + "国": 5793, + "怎": 5794, + "抱": 5795, + "推": 5796, + "有": 5797, + "樣": 5798, + "為": 5799, + "群": 5800, + "麼": 5801, + "eo": 5802, + "eul": 5803, + "eun": 5804, + "eon": 5805, + "ae": 5806, + "yeon": 5807, + "yeo": 5808, + "ui": 5809, + "hae": 5810, + "geo": 5811, + "neun": 5812, + "ssda": 5813, + "seo": 5814, + "eong": 5815, + "kk": 5816, + "jeo": 5817, + "deul": 5818, + "eum": 5819, + "yeong": 5820, + "geos": 5821, + "hag": 5822, + "aneun": 5823, + "iss": 5824, + "dae": 5825, + "eob": 5826, + "eol": 5827, + "geu": 5828, + "jeong": 5829, + "sae": 5830, + "doe": 5831, + "geul": 5832, + "eulo": 5833, + "bn": 5834, + "sang": 5835, + "bnida": 5836, + "haneun": 5837, + "jeog": 5838, + "saeng": 5839, + "ineun": 5840, + "anh": 5841, + "salam": 5842, + "eom": 5843, + "nae": 5844, + "gwa": 5845, + "yeol": 5846, + "eseo": 5847, + "myeon": 5848, + "ttae": 5849, + "hw": 5850, + "eobs": 5851, + "jang": 5852, + "gw": 5853, + "ileul": 5854, + "yeog": 5855, + "jeon": 5856, + "sig": 5857, + "jag": 5858, + "hago": 5859, + "deun": 5860, + "seong": 5861, + "gag": 5862, + "ham": 5863, + "dang": 5864, + "leul": 5865, + "sil": 5866, + "dong": 5867, + "handa": 5868, + "eossda": 5869, + "aeg": 5870, + "seon": 5871, + "haessda": 5872, + "issda": 5873, + "ege": 5874, + "mul": 5875, + "jung": 5876, + "jig": 5877, + "issneun": 5878, + "geun": 5879, + "seubnida": 5880, + "won": 5881, + 
"daneun": 5882, + "eoh": 5883, + "deo": 5884, + "gam": 5885, + "jal": 5886, + "haeng": 5887, + "yang": 5888, + "bang": 5889, + "jae": 5890, + "saenggag": 5891, + "hage": 5892, + "sog": 5893, + "eoss": 5894, + "jasin": 5895, + "jil": 5896, + "eog": 5897, + "gyeong": 5898, + "gong": 5899, + "deon": 5900, + "haess": 5901, + "eung": 5902, + "joh": 5903, + "nal": 5904, + "myeong": 5905, + "eona": 5906, + "igo": 5907, + "gyeol": 5908, + "yag": 5909, + "gwan": 5910, + "uli": 5911, + "yong": 5912, + "lyeo": 5913, + "jog": 5914, + "eohge": 5915, + "bog": 5916, + "tong": 5917, + "manh": 5918, + "jeol": 5919, + "geol": 5920, + "aga": 5921, + "naneun": 5922, + "uneun": 5923, + "cheol": 5924, + "dol": 5925, + "bad": 5926, + "hamyeon": 5927, + "yeossda": 5928, + "ibnida": 5929, + "gye": 5930, + "eos": 5931, + "hwal": 5932, + "salamdeul": 5933, + "jiman": 5934, + "dangsin": 5935, + "jib": 5936, + "ttaemun": 5937, + "ib": 5938, + "eneun": 5939, + "eug": 5940, + "jeom": 5941, + "geuleon": 5942, + "hwa": 5943, + "assda": 5944, + "beob": 5945, + "bae": 5946, + "yeoss": 5947, + "chin": 5948, + "chaeg": 5949, + "geon": 5950, + "naega": 5951, + "iga": 5952, + "sigan": 5953, + "gil": 5954, + "hyeon": 5955, + "lyeog": 5956, + "gug": 5957, + "pyeon": 5958, + "wae": 5959, + "jul": 5960, + "seul": 5961, + "deung": 5962, + "hajiman": 5963, + "eumyeon": 5964, + "pil": 5965, + "nyeon": 5966, + "tae": 5967, + "pyo": 5968, + "jineun": 5969, + "beon": 5970, + "hada": 5971, + "seol": 5972, + "sip": 5973, + "daleun": 5974, + "salm": 5975, + "gyo": 5976, + "cheon": 5977, + "hagi": 5978, + "cheoleom": 5979, + "gal": 5980, + "ila": 5981, + "kkaji": 5982, + "anhneun": 5983, + "habnida": 5984, + "tteon": 5985, + "haeseo": 5986, + "doenda": 5987, + "ttal": 5988, + "ilo": 5989, + "seub": 5990, + "byeon": 5991, + "myeo": 5992, + "beol": 5993, + "jeung": 5994, + "chim": 5995, + "hwang": 5996, + "euneun": 5997, + "jong": 5998, + "boda": 5999, + "nol": 6000, + "neom": 6001, + "buteo": 6002, + "jigeum": 6003, + "eobsda": 6004, + "daelo": 6005, + "yul": 6006, + "pyeong": 6007, + "seoneun": 6008, + "salang": 6009, + "seut": 6010, + "heom": 6011, + "hyang": 6012, + "gwang": 6013, + "eobsneun": 6014, + "hwag": 6015, + "gess": 6016, + "jagi": 6017, + "ileon": 6018, + "wihae": 6019, + "daehan": 6020, + "gaji": 6021, + "meog": 6022, + "jyeo": 6023, + "chaj": 6024, + "byeong": 6025, + "eod": 6026, + "gyeo": 6027, + "eoji": 6028, + "gul": 6029, + "modeun": 6030, + "insaeng": 6031, + "geulae": 6032, + "sasil": 6033, + "sib": 6034, + "chal": 6035, + "ilago": 6036, + "geum": 6037, + "doeneun": 6038, + "bol": 6039, + "gajang": 6040, + "geuligo": 6041, + "hyeong": 6042, + "haengbog": 6043, + "chul": 6044, + "chae": 6045, + "mang": 6046, + "dam": 6047, + "choe": 6048, + "sijag": 6049, + "cheong": 6050, + "ilaneun": 6051, + "ulineun": 6052, + "aen": 6053, + "kke": 6054, + "munje": 6055, + "teu": 6056, + "geuneun": 6057, + "bge": 6058, + "cheo": 6059, + "baeg": 6060, + "jug": 6061, + "sangdae": 6062, + "geugeos": 6063, + "dog": 6064, + "eus": 6065, + "jab": 6066, + "hyeo": 6067, + "tteohge": 6068, + "chil": 6069, + "swi": 6070, + "jileul": 6071, + "chang": 6072, + "ganeun": 6073, + "iji": 6074, + "dago": 6075, + "yohan": 6076, + "teug": 6077, + "ppun": 6078, + "aleul": 6079, + "haengdong": 6080, + "sesang": 6081, + "edo": 6082, + "mandeul": 6083, + "amyeon": 6084, + "kkae": 6085, + "bag": 6086, + "ideul": 6087, + "pum": 6088, + "meol": 6089, + "neul": 6090, + "hamkke": 6091, + "chung": 6092, + "dab": 6093, + "yug": 6094, + "sag": 6095, + "gwangye": 
6096, + "ileohge": 6097, + "balo": 6098, + "neunde": 6099, + "hamyeo": 6100, + "geuleoh": 6101, + "anila": 6102, + "bangbeob": 6103, + "dasi": 6104, + "byeol": 6105, + "gyeon": 6106, + "gamjeong": 6107, + "oneul": 6108, + "janeun": 6109, + "yeom": 6110, + "lago": 6111, + "igi": 6112, + "hwan": 6113, + "teul": 6114, + "eoseo": 6115, + "sik": 6116, + "jaga": 6117, + "geuleom": 6118, + "geuleona": 6119, + "jeongdo": 6120, + "gyeog": 6121, + "geuleohge": 6122, + "geudeul": 6123, + "eut": 6124, + "imyeon": 6125, + "jjae": 6126, + "keun": 6127, + "isang": 6128, + "malhaessda": 6129, + "euge": 6130, + "nop": 6131, + "ingan": 6132, + "bomyeon": 6133, + "taeg": 6134, + "dwi": 6135, + "saneun": 6136, + "wan": 6137, + "anhgo": 6138, + "nugu": 6139, + "sung": 6140, + "damyeon": 6141, + "adeul": 6142, + "peul": 6143, + "ttala": 6144, + "geosdo": 6145, + "aji": 6146, + "meon": 6147, + "eumyeo": 6148, + "dolog": 6149, + "neung": 6150, + "modu": 6151, + "[ko]": 6152, + "\u0014": 6153, + "\u0016": 6154, + "$": 6155, + "*": 6156, + "|": 6157, + "°": 6158, + "º": 6159, + "ँ": 6160, + "ं": 6161, + "ः": 6162, + "अ": 6163, + "आ": 6164, + "इ": 6165, + "ई": 6166, + "उ": 6167, + "ऊ": 6168, + "ऋ": 6169, + "ऎ": 6170, + "ए": 6171, + "ऐ": 6172, + "ऑ": 6173, + "ऒ": 6174, + "ओ": 6175, + "औ": 6176, + "क": 6177, + "ख": 6178, + "ग": 6179, + "घ": 6180, + "ङ": 6181, + "च": 6182, + "छ": 6183, + "ज": 6184, + "झ": 6185, + "ञ": 6186, + "ट": 6187, + "ठ": 6188, + "ड": 6189, + "ढ": 6190, + "ण": 6191, + "त": 6192, + "थ": 6193, + "द": 6194, + "ध": 6195, + "न": 6196, + "ऩ": 6197, + "प": 6198, + "फ": 6199, + "ब": 6200, + "भ": 6201, + "म": 6202, + "य": 6203, + "र": 6204, + "ऱ": 6205, + "ल": 6206, + "ळ": 6207, + "व": 6208, + "श": 6209, + "ष": 6210, + "स": 6211, + "ह": 6212, + "़": 6213, + "ा": 6214, + "ि": 6215, + "ी": 6216, + "ु": 6217, + "ू": 6218, + "ृ": 6219, + "ॄ": 6220, + "ॅ": 6221, + "ॆ": 6222, + "े": 6223, + "ै": 6224, + "ॉ": 6225, + "ॊ": 6226, + "ो": 6227, + "ौ": 6228, + "्": 6229, + "ॐ": 6230, + "ॖ": 6231, + "क़": 6232, + "ख़": 6233, + "ग़": 6234, + "ज़": 6235, + "ड़": 6236, + "ढ़": 6237, + "फ़": 6238, + "य़": 6239, + "ॠ": 6240, + "।": 6241, + "॥": 6242, + "०": 6243, + "१": 6244, + "२": 6245, + "३": 6246, + "४": 6247, + "५": 6248, + "६": 6249, + "७": 6250, + "८": 6251, + "९": 6252, + "॰": 6253, + "ॲ": 6254, + "​": 6255, + "‌": 6256, + "‍": 6257, + "‎": 6258, + "₹": 6259, + "के": 6260, + "है": 6261, + "ें": 6262, + "्र": 6263, + "ार": 6264, + "ने": 6265, + "या": 6266, + "में": 6267, + "से": 6268, + "की": 6269, + "का": 6270, + "ों": 6271, + "ता": 6272, + "कर": 6273, + "स्": 6274, + "कि": 6275, + "को": 6276, + "र्": 6277, + "ना": 6278, + "क्": 6279, + "ही": 6280, + "और": 6281, + "पर": 6282, + "ते": 6283, + "हो": 6284, + "प्र": 6285, + "ान": 6286, + "्य": 6287, + "ला": 6288, + "वा": 6289, + "ले": 6290, + "सा": 6291, + "हैं": 6292, + "लि": 6293, + "जा": 6294, + "हा": 6295, + "भी": 6296, + "वि": 6297, + "इस": 6298, + "ती": 6299, + "न्": 6300, + "रा": 6301, + "मा": 6302, + "दे": 6303, + "दि": 6304, + "बा": 6305, + "ति": 6306, + "था": 6307, + "नि": 6308, + "कार": 6309, + "एक": 6310, + "हीं": 6311, + "हु": 6312, + "ंग": 6313, + "ैं": 6314, + "नी": 6315, + "सी": 6316, + "अप": 6317, + "त्": 6318, + "नहीं": 6319, + "री": 6320, + "मे": 6321, + "मु": 6322, + "ित": 6323, + "तो": 6324, + "पा": 6325, + "ली": 6326, + "लिए": 6327, + "गा": 6328, + "ल्": 6329, + "रह": 6330, + "रे": 6331, + "क्ष": 6332, + "मैं": 6333, + "सम": 6334, + "उस": 6335, + "जि": 6336, + "त्र": 6337, + "मि": 6338, + "चा": 6339, + "ोग": 6340, + "सं": 6341, + "द्": 6342, + 
"सि": 6343, + "आप": 6344, + "तु": 6345, + "दा": 6346, + "कु": 6347, + "यों": 6348, + "वे": 6349, + "जी": 6350, + "्या": 6351, + "उन": 6352, + "िक": 6353, + "ये": 6354, + "भा": 6355, + "्ट": 6356, + "हम": 6357, + "स्ट": 6358, + "शा": 6359, + "ड़": 6360, + "ंद": 6361, + "खा": 6362, + "म्": 6363, + "श्": 6364, + "यह": 6365, + "सक": 6366, + "पू": 6367, + "किया": 6368, + "अपने": 6369, + "रू": 6370, + "सु": 6371, + "मी": 6372, + "हि": 6373, + "जो": 6374, + "थे": 6375, + "रि": 6376, + "दी": 6377, + "थी": 6378, + "गी": 6379, + "लोग": 6380, + "गया": 6381, + "तर": 6382, + "न्ह": 6383, + "च्": 6384, + "वार": 6385, + "बी": 6386, + "प्": 6387, + "दो": 6388, + "टी": 6389, + "शि": 6390, + "करने": 6391, + "गे": 6392, + "ैसे": 6393, + "इन": 6394, + "ंड": 6395, + "साथ": 6396, + "पु": 6397, + "बे": 6398, + "बार": 6399, + "वी": 6400, + "अन": 6401, + "हर": 6402, + "उन्ह": 6403, + "होता": 6404, + "जब": 6405, + "कुछ": 6406, + "मान": 6407, + "क्र": 6408, + "बि": 6409, + "पह": 6410, + "फि": 6411, + "सर": 6412, + "ारी": 6413, + "रो": 6414, + "दू": 6415, + "कहा": 6416, + "तक": 6417, + "शन": 6418, + "ब्": 6419, + "स्थ": 6420, + "वह": 6421, + "बाद": 6422, + "ओं": 6423, + "गु": 6424, + "ज्": 6425, + "्रे": 6426, + "गर": 6427, + "रहे": 6428, + "वर्": 6429, + "हू": 6430, + "ार्": 6431, + "पी": 6432, + "बहु": 6433, + "मुझ": 6434, + "्रा": 6435, + "दिया": 6436, + "सब": 6437, + "करते": 6438, + "अपनी": 6439, + "बहुत": 6440, + "कह": 6441, + "टे": 6442, + "हुए": 6443, + "किसी": 6444, + "रहा": 6445, + "ष्ट": 6446, + "ज़": 6447, + "बना": 6448, + "सो": 6449, + "डि": 6450, + "कोई": 6451, + "व्य": 6452, + "बात": 6453, + "रु": 6454, + "वो": 6455, + "मुझे": 6456, + "द्ध": 6457, + "चार": 6458, + "मेरे": 6459, + "वर": 6460, + "्री": 6461, + "जाता": 6462, + "नों": 6463, + "प्रा": 6464, + "देख": 6465, + "टा": 6466, + "क्या": 6467, + "अध": 6468, + "लग": 6469, + "लो": 6470, + "पि": 6471, + "यु": 6472, + "चे": 6473, + "जिस": 6474, + "ंत": 6475, + "ानी": 6476, + "पै": 6477, + "जन": 6478, + "ारे": 6479, + "ची": 6480, + "मिल": 6481, + "दु": 6482, + "देश": 6483, + "च्छ": 6484, + "ष्": 6485, + "सू": 6486, + "खे": 6487, + "चु": 6488, + "िया": 6489, + "लगा": 6490, + "बु": 6491, + "उनके": 6492, + "ज्ञ": 6493, + "क्षा": 6494, + "तरह": 6495, + "्यादा": 6496, + "वाले": 6497, + "पूर्": 6498, + "मैंने": 6499, + "काम": 6500, + "रूप": 6501, + "होती": 6502, + "उप": 6503, + "जान": 6504, + "प्रकार": 6505, + "भार": 6506, + "मन": 6507, + "हुआ": 6508, + "टर": 6509, + "हूँ": 6510, + "परि": 6511, + "पास": 6512, + "अनु": 6513, + "राज": 6514, + "लोगों": 6515, + "अब": 6516, + "समझ": 6517, + "डी": 6518, + "मौ": 6519, + "शु": 6520, + "चि": 6521, + "पे": 6522, + "कृ": 6523, + "सकते": 6524, + "मह": 6525, + "योग": 6526, + "दर्": 6527, + "उसे": 6528, + "ंध": 6529, + "डा": 6530, + "जाए": 6531, + "बो": 6532, + "ूल": 6533, + "मो": 6534, + "ोंने": 6535, + "ंस": 6536, + "तुम": 6537, + "पहले": 6538, + "बता": 6539, + "तथा": 6540, + "यो": 6541, + "गई": 6542, + "उत्": 6543, + "सकता": 6544, + "कम": 6545, + "ज्यादा": 6546, + "रख": 6547, + "समय": 6548, + "ारा": 6549, + "अगर": 6550, + "स्त": 6551, + "चल": 6552, + "फिर": 6553, + "वारा": 6554, + "करना": 6555, + "शी": 6556, + "गए": 6557, + "बन": 6558, + "ौर": 6559, + "होने": 6560, + "चाह": 6561, + "खु": 6562, + "हाँ": 6563, + "उन्हें": 6564, + "उन्होंने": 6565, + "छो": 6566, + "म्ह": 6567, + "प्रति": 6568, + "निक": 6569, + "वन": 6570, + "्यू": 6571, + "रही": 6572, + "तुम्ह": 6573, + "जैसे": 6574, + "ियों": 6575, + "क्यों": 6576, + "लों": 6577, + "फ़": 6578, + "ंत्र": 6579, + "होते": 6580, + "क्ति": 6581, + "त्य": 6582, + "कर्": 6583, + 
"कई": 6584, + "वं": 6585, + "किन": 6586, + "पो": 6587, + "कारण": 6588, + "ड़ी": 6589, + "भि": 6590, + "इसके": 6591, + "बर": 6592, + "उसके": 6593, + "द्वारा": 6594, + "शे": 6595, + "कॉ": 6596, + "दिन": 6597, + "न्न": 6598, + "ड़ा": 6599, + "स्व": 6600, + "निर्": 6601, + "मुख": 6602, + "लिया": 6603, + "टि": 6604, + "ज्ञान": 6605, + "क्त": 6606, + "द्र": 6607, + "ग्": 6608, + "क्स": 6609, + "मै": 6610, + "गो": 6611, + "जे": 6612, + "ट्र": 6613, + "मार": 6614, + "त्व": 6615, + "धार": 6616, + "भाव": 6617, + "करता": 6618, + "खि": 6619, + "कं": 6620, + "चाहि": 6621, + "यर": 6622, + "प्त": 6623, + "कों": 6624, + "ंच": 6625, + "जु": 6626, + "मत": 6627, + "अच्छ": 6628, + "हुई": 6629, + "कभी": 6630, + "लेकिन": 6631, + "भू": 6632, + "अपना": 6633, + "दूस": 6634, + "चाहिए": 6635, + "यू": 6636, + "घर": 6637, + "सबसे": 6638, + "मेरी": 6639, + "नाम": 6640, + "ढ़": 6641, + "ंट": 6642, + "ेंगे": 6643, + "बै": 6644, + "फा": 6645, + "एवं": 6646, + "यी": 6647, + "ग्र": 6648, + "क्षे": 6649, + "आज": 6650, + "आपको": 6651, + "भाग": 6652, + "ठा": 6653, + "कै": 6654, + "भारत": 6655, + "उनकी": 6656, + "पहु": 6657, + "सभी": 6658, + "धा": 6659, + "णा": 6660, + "सान": 6661, + "होगा": 6662, + "तब": 6663, + "संग": 6664, + "पर्": 6665, + "अव": 6666, + "तना": 6667, + "गि": 6668, + "यन": 6669, + "स्था": 6670, + "चित": 6671, + "ट्": 6672, + "छा": 6673, + "जाने": 6674, + "क्षेत्र": 6675, + "वाली": 6676, + "पूर्ण": 6677, + "समा": 6678, + "कारी": 6679, + "[hi]": 6680 + }, + "merges": [ + "t h", + "i n", + "th e", + "a n", + "e r", + "o u", + "r e", + "o n", + "a t", + "e d", + "e n", + "t o", + "in g", + "an d", + "i s", + "a s", + "a l", + "o r", + "o f", + "a r", + "i t", + "e s", + "h e", + "s t", + "l e", + "o m", + "s e", + "b e", + "a d", + "o w", + "l y", + "c h", + "w h", + "th at", + "y ou", + "l i", + "v e", + "a c", + "t i", + "l d", + "m e", + "w as", + "g h", + "i d", + "l l", + "w i", + "en t", + "f or", + "a y", + "r o", + "v er", + "i c", + "h er", + "k e", + "h is", + "n o", + "u t", + "u n", + "i r", + "l o", + "w e", + "r i", + "h a", + "wi th", + "gh t", + "ou t", + "i m", + "i on", + "al l", + "a b", + "on e", + "n e", + "g e", + "ou ld", + "t er", + "m o", + "h ad", + "c e", + "s he", + "g o", + "s h", + "u r", + "a m", + "s o", + "p e", + "m y", + "d e", + "a re", + "b ut", + "om e", + "f r", + "the r", + "f e", + "s u", + "d o", + "c on", + "t e", + "a in", + "er e", + "p o", + "i f", + "the y", + "u s", + "a g", + "t r", + "n ow", + "ou n", + "th is", + "ha ve", + "no t", + "s a", + "i l", + "u p", + "th ing", + "fr om", + "a p", + "h im", + "ac k", + "at ion", + "an t", + "ou r", + "o p", + "li ke", + "u st", + "es s", + "b o", + "o k", + "u l", + "in d", + "e x", + "c om", + "s ome", + "the re", + "er s", + "c o", + "re s", + "m an", + "ar d", + "p l", + "w or", + "w ay", + "ti on", + "f o", + "c a", + "w ere", + "b y", + "at e", + "p ro", + "t ed", + "oun d", + "ow n", + "w ould", + "t s", + "wh at", + "q u", + "al ly", + "i ght", + "c k", + "g r", + "wh en", + "v en", + "c an", + "ou gh", + "in e", + "en d", + "p er", + "ou s", + "o d", + "id e", + "k now", + "t y", + "ver y", + "s i", + "a k", + "wh o", + "ab out", + "i ll", + "the m", + "es t", + "re d", + "y e", + "c ould", + "on g", + "you r", + "the ir", + "e m", + "j ust", + "o ther", + "in to", + "an y", + "wh i", + "u m", + "t w", + "as t", + "d er", + "d id", + "i e", + "be en", + "ac e", + "in k", + "it y", + "b ack", + "t ing", + "b r", + "mo re", + "a ke", + "p p", + "the n", + "s p", + "e l", + "u se", + "b l", + "sa id", + "o ver", + 
"ge t", + "e n", + "e r", + "c h", + "e i", + "i e", + "u n", + "i ch", + "ei n", + "s t", + "a n", + "t e", + "g e", + "a u", + "i n", + "s ch", + "d er", + "un d", + "d ie", + "d a", + "e s", + "a l", + "d en", + "a r", + "g en", + "z u", + "d e", + "h r", + "o n", + "t en", + "e l", + "o r", + "m i", + "s ie", + "da s", + "a t", + "b e", + "ein e", + "ich t", + "b er", + "l e", + "a ch", + "v er", + "s e", + "au f", + "w i", + "s o", + "t er", + "l ich", + "c k", + "u r", + "n icht", + "m m", + "b en", + "a s", + "w ar", + "r e", + "mi t", + "s ich", + "i g", + "l l", + "au s", + "i st", + "w ie", + "o ch", + "un g", + "an n", + "ü r", + "h n", + "i hr", + "s a", + "s en", + "t z", + "de m", + "ei t", + "u m", + "h at", + "wi r", + "v on", + "h a", + "s p", + "w ei", + "i er", + "r o", + "h er", + "r a", + "ein en", + "n e", + "v or", + "al s", + "an d", + "al l", + "w as", + "w o", + "r ei", + "st e", + "l ie", + "au ch", + "d u", + "d es", + "k o", + "ü ber", + "a m", + "b ei", + "h en", + "h m", + "l ei", + "a ber", + "w en", + "h l", + "g er", + "i m", + "u t", + "n ach", + "h e", + "i s", + "b r", + "f t", + "en t", + "i mm", + "j e", + "sch en", + "w er", + "s er", + "a b", + "ä n", + "m e", + "s ein", + "i t", + "o l", + "ch t", + "f ür", + "k l", + "f f", + "eine m", + "n en", + "w e", + "j a", + "u s", + "n och", + "hat te", + "t r", + "p f", + "h in", + "d i", + "ch en", + "b l", + "m an", + "r ü", + "ie l", + "s el", + "das s", + "i hn", + "mi r", + "sch l", + "ö n", + "g an", + "g t", + "ein er", + "st en", + "m ich", + "wen n", + "el l", + "g te", + "in d", + "m al", + "ge l", + "k en", + "n ur", + "mm en", + "f ü", + "er n", + "ö r", + "un ter", + "f r", + "an der", + "g r", + "i l", + "d ur", + "u ch", + "f e", + "t a", + "m en", + "m ach", + "d och", + "t i", + "dur ch", + "o s", + "g l", + "h al", + "ihr e", + "w ä", + "imm er", + "i hm", + "k ann", + "or t", + "d ann", + "l an", + "tz t", + "o der", + "hr en", + "e t", + "k ön", + "i ck", + "f a", + "in g", + "i r", + "wie der", + "da ß", + "m ein", + "f en", + "gan z", + "die se", + "st er", + "da r", + "w a", + "ge s", + "n a", + "f l", + "i gen", + "sch e", + "un gen", + "me hr", + "ß en", + "o t", + "k on", + "ge w", + "ha ben", + "ge h", + "ä t", + "s ind", + "d r", + "w el", + "un s", + "v o", + "m a", + "u te", + "sch on", + "b es", + "ge sch", + "b t", + "ch e", + "s on", + "o b", + "l a", + "p p", + "rü ck", + "s eine", + "k r", + "f re", + "ei l", + "zu m", + "u l", + "h ier", + "k t", + "i ge", + "sp r", + "k e", + "le ben", + "b st", + "z eit", + "i on", + "g ro", + "den n", + "h o", + "sch a", + "b ar", + "al le", + "ge gen", + "w ür", + "m ü", + "z e", + "wer den", + "je tzt", + "ko mmen", + "n ie", + "s ei", + "h eit", + "so ll", + "g lei", + "m eine", + "wo ll", + "n er", + "ha be", + "w ur", + "lich en", + "p er", + "as sen", + "n te", + "se hen", + "wir d", + "b is", + "g ar", + "i en", + "m us", + "u ß", + "ä r", + "st ell", + "k eit", + "z wei", + "sel bst", + "st a", + "p a", + "sa gte", + "te t", + "k am", + "s sen", + "v iel", + "u g", + "z en", + "h ei", + "m ann", + "wi ll", + "ge b", + "war en", + "ü ck", + "ä ch", + "m er", + "r u", + "w or", + "h au", + "ei gen", + "an g", + "we g", + "bl ick", + "f ra", + "all es", + "k a", + "au gen", + "f in", + "lich e", + "t o", + "un ser", + "der n", + "her r", + "n un", + "v ie", + "ch te", + "wo hl", + "f all", + "h t", + "ü n", + "et was", + "st and", + "en d", + "ä u", + "e m", + "m ö", + "te l", + "r ie", + "d ich", + "die s", + "h and", + "b in", 
+ "ff en", + "nicht s", + "d an", + "p l", + "hn e", + "ihn en", + "es en", + "die ser", + "fr au", + "an t", + "ar t", + "di r", + "i sch", + "er st", + "glei ch", + "ko mm", + "h ör", + "ß e", + "d ig", + "se hr", + "z ei", + "sa m", + "au m", + "h ät", + "in gen", + "g ut", + "b o", + "m ut", + "ck en", + "kon nte", + "st imm", + "p ro", + "zu r", + "i tz", + "wei l", + "wür de", + "f ä", + "kön nen", + "k eine", + "f er", + "i schen", + "vo ll", + "ein es", + "se tz", + "z ie", + "de l", + "te te", + "sein er", + "ier en", + "ge st", + "zu rück", + "wur de", + "sch n", + "p r", + "lie ß", + "t ra", + "m ä", + "gen d", + "f ol", + "i k", + "schl a", + "scha ft", + "at er", + "wei ß", + "s einen", + "l assen", + "l u", + "und en", + "t eil", + "ne u", + "ier t", + "men schen", + "hm en", + "st r", + "g i", + "sa h", + "ihr en", + "el n", + "wei ter", + "ge hen", + "ig er", + "mach t", + "ta g", + "al so", + "hal ten", + "n is", + "ach t", + "ge ben", + "f or", + "o g", + "n at", + "m ar", + "de t", + "o hne", + "h aus", + "t ro", + "an ge", + "l au", + "sp iel", + "t re", + "sch r", + "in n", + "s u", + "l os", + "mach en", + "hät te", + "be g", + "wir k", + "al t", + "g lich", + "te s", + "r icht", + "fre und", + "m o", + "ihr er", + "f el", + "b el", + "so l", + "ein mal", + "e ben", + "h ol", + "h än", + "q u", + "ter n", + "h ö", + "sch w", + "re cht", + "wa hr", + "s einem", + "ste hen", + "hl en", + "in s", + "g ing", + "woll te", + "wi ssen", + "ung s", + "al d", + "as s", + "ja hr", + "m or", + "wel t", + "un der", + "zu sa", + "at ion", + "ko pf", + "lan g", + "hin ter", + "at z", + "st ra", + "an gen", + "an k", + "a de", + "gl au", + "f ach", + "hat ten", + "l o", + "f ort", + "ei cht", + "i ff", + "l er", + "m ei", + "diese m", + "k ein", + "f rei", + "fü hr", + "vo m", + "e s", + "e n", + "a i", + "o u", + "o n", + "l e", + "d e", + "r e", + "q u", + "a n", + "e r", + "en t", + "e t", + "l a", + "n e", + "i l", + "a r", + "i s", + "ai t", + "t e", + "a u", + "i n", + "qu e", + "i t", + "u r", + "s e", + "l es", + "c h", + "c e", + "m e", + "o r", + "ou r", + "a s", + "p r", + "a v", + "o m", + "ai s", + "u n", + "an t", + "ou s", + "t r", + "t i", + "l u", + "o i", + "e u", + "l le", + "s i", + "p ar", + "d es", + "an s", + "m ent", + "é t", + "es t", + "j e", + "u ne", + "a l", + "p as", + "t re", + "qu i", + "d u", + "r i", + "c on", + "s on", + "c om", + "e lle", + "d é", + "p our", + "d ans", + "l i", + "s a", + "r é", + "t ou", + "v ous", + "d i", + "v i", + "a g", + "a m", + "a t", + "ou v", + "a p", + "ti on", + "m on", + "s ur", + "c i", + "o s", + "p lu", + "s u", + "en d", + "a b", + "è re", + "ai n", + "m ais", + "o is", + "r es", + "plu s", + "é e", + "ai ent", + "m p", + "ch e", + "lu i", + "av e", + "ét ait", + "m a", + "s es", + "tou t", + "i r", + "v o", + "a c", + "s er", + "an d", + "f f", + "oi r", + "g r", + "av ait", + "é s", + "m es", + "n ous", + "eu x", + "b i", + "t er", + "c o", + "on s", + "p u", + "c es", + "g e", + "t u", + "le ur", + "pr o", + "d on", + "e ur", + "et te", + "ai re", + "ave c", + "d it", + "t é", + "i e", + "u s", + "il le", + "p er", + "com me", + "c r", + "or t", + "m i", + "e x", + "u x", + "v er", + "m o", + "è s", + "v e", + "au x", + "r a", + "j our", + "il s", + "bi en", + "c ou", + "p e", + "que l", + "p eu", + "c ette", + "t es", + "p o", + "in s", + "c u", + "m ê", + "s o", + "f ait", + "g u", + "m ar", + "ê tre", + "l o", + "it é", + "f r", + "a tion", + "en s", + "b r", + "n i", + "l é", + "d is", + "b le", + "m an", 
+ "n é", + "pu is", + "mê me", + "qu es", + "f i", + "e l", + "ag e", + "g ar", + "m oi", + "en ce", + "on t", + "m ain", + "or s", + "au t", + "an ce", + "v en", + "m é", + "s ans", + "e m", + "s é", + "l on", + "h om", + "r o", + "u t", + "c ar", + "ab le", + "i m", + "de r", + "ch er", + "n o", + "vi e", + "au s", + "b e", + "de ux", + "en f", + "o ù", + "t en", + "p h", + "u re", + "te mp", + "p os", + "r ent", + "p é", + "f aire", + "p i", + "tr es", + "ç a", + "an g", + "end re", + "f or", + "p a", + "b on", + "s ou", + "in t", + "pr é", + "s ent", + "t ant", + "n er", + "c er", + "l à", + "l ais", + "pr ès", + "b re", + "c our", + "p et", + "i on", + "i ne", + "com p", + "l ait", + "tr ouv", + "t a", + "ent re", + "son t", + "de v", + "n u", + "temp s", + "d ou", + "r ait", + "b ou", + "qu and", + "jour s", + "l an", + "er s", + "av oir", + "ét é", + "a le", + "p re", + "f ois", + "or te", + "v é", + "m er", + "n on", + "t ous", + "j us", + "cou p", + "t s", + "hom me", + "ê te", + "a d", + "aus si", + "ur s", + "se u", + "or d", + "o b", + "m in", + "g é", + "co re", + "v a", + "v re", + "en core", + "se m", + "i te", + "au tre", + "pr is", + "peu t", + "u e", + "an te", + "m al", + "g n", + "ré p", + "h u", + "si on", + "vo tre", + "di re", + "e z", + "f em", + "leur s", + "m et", + "f in", + "c ri", + "m is", + "t our", + "r ai", + "j am", + "re gar", + "ri en", + "ver s", + "su is", + "p ouv", + "o p", + "v is", + "gr and", + "ant s", + "c or", + "re r", + "ar d", + "c é", + "t ent", + "pr es", + "v ou", + "f a", + "al ors", + "si eur", + "ai ne", + "le r", + "qu oi", + "f on", + "end ant", + "ar ri", + "eu re", + "a près", + "don c", + "it u", + "l è", + "s ait", + "t oi", + "ch a", + "ai l", + "as se", + "i mp", + "vo y", + "con n", + "p la", + "pet it", + "av ant", + "n om", + "t in", + "don t", + "d a", + "s ous", + "e mp", + "per son", + "el les", + "be au", + "par ti", + "ch o", + "pr it", + "tou jours", + "m en", + "r ais", + "jam ais", + "tr av", + "tion s", + "tr ès", + "v oi", + "r en", + "y eux", + "f er", + "v oir", + "pre mi", + "c a", + "g ne", + "h eure", + "r ou", + "e ff", + "no tre", + "ment s", + "t on", + "f ais", + "ce la", + "i er", + "rép on", + "con s", + "ai r", + "ô t", + "p endant", + "i ci", + "tou te", + "j et", + "p ort", + "ét aient", + "p en", + "h é", + "au tres", + "p ère", + "o c", + "quel ques", + "i que", + "l is", + "fem me", + "j ou", + "te ur", + "mon de", + "u se", + "n es", + "d re", + "a ff", + "r ap", + "par t", + "le ment", + "c la", + "f ut", + "quel que", + "pr endre", + "r ê", + "ai lle", + "s ais", + "ch es", + "le t", + "ch ar", + "è res", + "ent s", + "b er", + "g er", + "mo ins", + "e au", + "a î", + "j eu", + "h eur", + "é es", + "tr i", + "po int", + "m om", + "v ent", + "n ouv", + "gr an", + "tr ois", + "s ant", + "tout es", + "con tre", + "è rent", + "che z", + "ave z", + "û t", + "a lle", + "at t", + "p au", + "p orte", + "ouv er", + "b ar", + "l it", + "f ort", + "o t", + "as s", + "pr és", + "cho se", + "v it", + "mon sieur", + "h ab", + "t ête", + "j u", + "te ment", + "c tion", + "v rai", + "la r", + "c et", + "regar d", + "l ant", + "de m", + "s om", + "mom ent", + "il les", + "p le", + "p s", + "b es", + "m ère", + "c l", + "s our", + "y s", + "tr op", + "en ne", + "jus qu", + "av aient", + "av ais", + "jeu ne", + "de puis", + "person ne", + "f it", + "cer t", + "j o", + "g es", + "ou i", + "r est", + "sem b", + "c ap", + "m at", + "m u", + "lon g", + "fr an", + "f aut", + "it i", + "b li", + "che v", + "pr i", + 
"ent e", + "ain si", + "ch am", + "l ors", + "c as", + "d o", + "il i", + "b é", + "n os", + "an ge", + "su i", + "r it", + "cr o", + "gu e", + "d e", + "e n", + "e s", + "o s", + "l a", + "e r", + "q u", + "a r", + "a n", + "o n", + "qu e", + "a s", + "o r", + "e l", + "d o", + "a l", + "c i", + "u n", + "r e", + "a b", + "i n", + "t e", + "t o", + "s e", + "d i", + "t r", + "d a", + "c on", + "t a", + "s u", + "m i", + "c o", + "t i", + "l e", + "l os", + "n o", + "l o", + "í a", + "c u", + "c a", + "s i", + "v i", + "m e", + "p or", + "m o", + "p ar", + "r a", + "r i", + "la s", + "c h", + "r o", + "m a", + "p er", + "ó n", + "m en", + "de s", + "un a", + "m p", + "s o", + "ab a", + "p u", + "d os", + "t u", + "g u", + "er a", + "de l", + "h a", + "m u", + "l i", + "en t", + "m b", + "h ab", + "es t", + "g o", + "p a", + "r es", + "par a", + "p o", + "á s", + "m os", + "tr a", + "t en", + "an do", + "p i", + "qu i", + "b i", + "m an", + "co mo", + "v e", + "m ás", + "j o", + "ci ón", + "i s", + "t an", + "v o", + "da d", + "c e", + "a do", + "v er", + "f u", + "ci a", + "c er", + "p e", + "c as", + "c ar", + "men te", + "n i", + "su s", + "t ar", + "n a", + "f i", + "t er", + "z a", + "p ro", + "tr o", + "s a", + "l u", + "b a", + "per o", + "s er", + "c es", + "d as", + "d u", + "s in", + "e mp", + "m ar", + "l la", + "e x", + "á n", + "c or", + "i a", + "v a", + "r an", + "ch o", + "g a", + "y o", + "t os", + "c os", + "mi s", + "l es", + "t es", + "v en", + "h o", + "y a", + "en te", + "on es", + "hab ía", + "n u", + "u s", + "p as", + "h i", + "n os", + "es ta", + "la n", + "m as", + "t or", + "l le", + "h e", + "s on", + "b re", + "p re", + "ab an", + "d or", + "í an", + "i r", + "t as", + "é n", + "r u", + "en do", + "a que", + "er o", + "i o", + "qu é", + "m in", + "c ab", + "j a", + "de r", + "t al", + "é s", + "se ñ", + "or a", + "to do", + "la r", + "d on", + "g ar", + "s al", + "p r", + "cu ando", + "j e", + "h u", + "g un", + "b u", + "g i", + "d ar", + "n e", + "r as", + "de n", + "es to", + "par e", + "p en", + "é l", + "tr as", + "c an", + "b o", + "j os", + "mi en", + "pu e", + "c re", + "co mp", + "p on", + "d ía", + "tr os", + "s ab", + "so bre", + "es e", + "mb re", + "er on", + "a ñ", + "m or", + "f or", + "i do", + "por que", + "el la", + "p ri", + "g ran", + "f a", + "c en", + "di s", + "c ri", + "mu y", + "ch a", + "c al", + "es te", + "h as", + "c ó", + "g ra", + "r os", + "p os", + "o b", + "al l", + "aque l", + "j u", + "p res", + "m er", + "di jo", + "c ía", + "ent re", + "z o", + "ci ones", + "bi en", + "mb i", + "el o", + "t ó", + "in a", + "to dos", + "g en", + "ti en", + "est aba", + "de ci", + "ci o", + "h er", + "ñ o", + "l or", + "nu es", + "me di", + "l en", + "vi da", + "f e", + "al i", + "m on", + "c la", + "d re", + "pu es", + "al es", + "vo l", + "m í", + "r ar", + "b le", + "ci on", + "has ta", + "señ or", + "con o", + "a h", + "di os", + "s en", + "es a", + "ú n", + "v ar", + "s an", + "gu i", + "a c", + "o tros", + "ta do", + "bu en", + "ñ a", + "ti emp", + "ha cer", + "j er", + "f er", + "v u", + "f in", + "an a", + "as í", + "an tes", + "t in", + "ve z", + "mien to", + "j ar", + "la b", + "ch e", + "cas a", + "d r", + "es o", + "e go", + "di ó", + "an te", + "est á", + "m al", + "en cia", + "el i", + "í as", + "tiemp o", + "z ar", + "v an", + "m un", + "er ta", + "ta mbi", + "s í", + "b ar", + "a un", + "al e", + "mis mo", + "ent es", + "vi s", + "man o", + "el e", + "na da", + "se gu", + "me j", + "er ra", + "ab le", + "b e", + "ti r", + "un 
o", + "don de", + "to da", + "des de", + "r en", + "tambi én", + "cu er", + "per son", + "ho mbre", + "o tro", + "li b", + "tr ar", + "cu al", + "ha y", + "a u", + "ca da", + "t aba", + "i mp", + "men to", + "ten ía", + "qu er", + "er an", + "si emp", + "siemp re", + "er to", + "qu í", + "g os", + "pu és", + "el los", + "des pués", + "nu e", + "g an", + "l lo", + "in ter", + "có mo", + "tr i", + "ah ora", + "us te", + "tr aba", + "la do", + "in o", + "po co", + "er te", + "mu jer", + "i m", + "qui er", + "al gun", + "fu e", + "o jos", + "ent on", + "v os", + "es per", + "mu ch", + "o tra", + "a z", + "a d", + "in g", + "e za", + "a quí", + "ci as", + "gu a", + "mu cho", + "deci r", + "es ti", + "i dad", + "al go", + "e z", + "o cu", + "enton ces", + "di do", + "ent os", + "g ri", + "da do", + "i os", + "so l", + "dos e", + "uste d", + "qui en", + "a mi", + "un to", + "f r", + "mi r", + "mej or", + "b as", + "so lo", + "pre gun", + "tu r", + "al g", + "p la", + "to das", + "par te", + "e mb", + "c to", + "mun do", + "tien e", + "tan te", + "pa lab", + "tr an", + "aque lla", + "ci os", + "aun que", + "a y", + "cu en", + "ten er", + "f un", + "res pon", + "all í", + "x i", + "h an", + "pen s", + "con tra", + "tu ra", + "v al", + "di o", + "tr es", + "t re", + "tan to", + "ca min", + "m ó", + "es p", + "a da", + "í o", + "in s", + "ha cia", + "de j", + "est ar", + "i ón", + "g as", + "b er", + "v as", + "no che", + "é r", + "añ os", + "pa dre", + "gu s", + "á r", + "sin o", + "man os", + "ci do", + "es tu", + "a de", + "hu bi", + "vi r", + "b ri", + "ra z", + "ch i", + "pue de", + "men os", + "hab i", + "ho mb", + "ne ces", + "ma y", + "er os", + "r ía", + "he cho", + "es cu", + "l ti", + "án do", + "b us", + "cos as", + "t ú", + "es pa", + "re ci", + "c tor", + "pri m", + "di a", + "de se", + "mien tras", + "h or", + "fu er", + "i da", + "pos i", + "lan te", + "t on", + "an o", + "est as", + "p li", + "ch ar", + "lu ego", + "si ón", + "ci n", + "ti erra", + "m es", + "gu ar", + "ca do", + "en con", + "pr en", + "may or", + "f al", + "e r", + "o n", + "a n", + "t o", + "d i", + "r e", + "l a", + "i n", + "e n", + "a l", + "t a", + "c h", + "e l", + "r i", + "c o", + "t i", + "t e", + "s i", + "r a", + "u n", + "l e", + "l i", + "ch e", + "r o", + "c i", + "c a", + "s e", + "q u", + "m a", + "p o", + "s o", + "i l", + "d o", + "e s", + "v a", + "p er", + "l o", + "c on", + "d el", + "p a", + "m o", + "s a", + "p i", + "d a", + "m i", + "g i", + "s u", + "d e", + "v i", + "z i", + "m e", + "g li", + "n o", + "m en", + "v o", + "t u", + "n on", + "v e", + "t to", + "s t", + "on e", + "an o", + "ch i", + "er a", + "er e", + "f a", + "c e", + "z a", + "un a", + "b i", + "p re", + "s ta", + "o r", + "a r", + "f i", + "on o", + "t ra", + "n a", + "n el", + "n e", + "p ro", + "t ro", + "al e", + "v er", + "n i", + "c u", + "t ti", + "men te", + "del la", + "t er", + "zi one", + "g u", + "p e", + "t ta", + "an do", + "t à", + "al i", + "u o", + "qu el", + "co m", + "s en", + "co me", + "b a", + "al la", + "p ri", + "d u", + "qu es", + "l u", + "on i", + "g gi", + "pa r", + "s si", + "v en", + "in a", + "g a", + "pi ù", + "ci a", + "i m", + "co r", + "m an", + "in o", + "in i", + "t en", + "r an", + "b b", + "g o", + "s to", + "t re", + "a ve", + "a v", + "s ono", + "er i", + "a c", + "s se", + "er o", + "h a", + "s c", + "su l", + "f or", + "v ano", + "po r", + "s ti", + "su o", + "c chi", + "t an", + "z za", + "an che", + "p u", + "i o", + "t te", + "vo l", + "es s", + "s ci", + "co l", + "r u", + "p 
en", + "f u", + "al l", + "s so", + "s te", + "se m", + "s sa", + "d en", + "a d", + "t ri", + "de i", + "in e", + "ave va", + "men to", + "z z", + "a mo", + "g no", + "f o", + "un o", + "su a", + "g en", + "ri a", + "g e", + "st ra", + "s ì", + "c er", + "ch é", + "b u", + "a p", + "c en", + "d al", + "on a", + "s pe", + "g ni", + "b o", + "t t", + "del le", + "ques to", + "nel la", + "f f", + "d ere", + "an no", + "del l", + "un i", + "bb e", + "an ti", + "g ra", + "s p", + "en e", + "gi o", + "u to", + "qu al", + "gli a", + "qu ando", + "tu tto", + "c an", + "gli o", + "zi oni", + "ca m", + "h o", + "es so", + "s s", + "mo l", + "a t", + "lo ro", + "per ché", + "co sa", + "du e", + "po i", + "ca r", + "s co", + "ci o", + "to r", + "c co", + "c re", + "a m", + "g na", + "te m", + "pri ma", + "lu i", + "co sì", + "qu e", + "gu ar", + "ess ere", + "an i", + "con o", + "b ra", + "al le", + "m on", + "ri o", + "an co", + "cu i", + "s pi", + "vi a", + "g ran", + "gi or", + "a i", + "bi le", + "u l", + "ggi o", + "f e", + "an te", + "ma i", + "ta re", + "in ter", + "in di", + "re bbe", + "sen za", + "so lo", + "zi o", + "e d", + "en te", + "tu tti", + "sta to", + "zi a", + "d alla", + "tu ra", + "mi a", + "vi ta", + "quel la", + "qu a", + "ma r", + "do ve", + "g h", + "al lo", + "sem pre", + "zz o", + "si a", + "mo r", + "do po", + "por ta", + "d re", + "c cia", + "er ano", + "an ni", + "di o", + "chi a", + "en za", + "pro pri", + "qu i", + "m u", + "m b", + "an da", + "c ca", + "o cchi", + "ques ta", + "f fi", + "le i", + "par te", + "d on", + "r on", + "mi o", + "tan to", + "ri s", + "o gni", + "di s", + "r in", + "fa r", + "men ti", + "t el", + "anco ra", + "f ra", + "fa tto", + "man i", + "sen ti", + "p ra", + "tem po", + "es si", + "b bi", + "f in", + "a re", + "la re", + "per s", + "f on", + "b el", + "so r", + "d er", + "pre n", + "an za", + "di re", + "pi e", + "o ra", + "ver so", + "se gu", + "al tro", + "ta to", + "ca to", + "a to", + "vol ta", + "c c", + "fa re", + "pa re", + "ci ò", + "li b", + "bi li", + "n uo", + "s er", + "quel lo", + "co lo", + "p po", + "ca sa", + "tro va", + "o re", + "f er", + "r ono", + "d es", + "mol to", + "al mente", + "s ca", + "vo le", + "t ali", + "sul la", + "s ce", + "men o", + "an to", + "p un", + "s tu", + "ca pi", + "so l", + "gi u", + "m ini", + "m ano", + "z e", + "pi a", + "par ti", + "s al", + "la vo", + "ver o", + "r si", + "al tri", + "es ti", + "s cia", + "suo i", + "gli e", + "so tto", + "b ene", + "sc ri", + "t ale", + "de gli", + "n u", + "al c", + "uo mo", + "p el", + "f re", + "po te", + "es sa", + "s cu", + "si gno", + "el e", + "st ro", + "u ti", + "di a", + "si one", + "g re", + "f ini", + "ar ri", + "l un", + "c ri", + "e si", + "pa ssa", + "r à", + "men tre", + "an d", + "h anno", + "el o", + "u sci", + "gi a", + "gi à", + "di e", + "m ina", + "b e", + "ti ca", + "gior no", + "t in", + "es se", + "mo do", + "c al", + "s pa", + "propri o", + "l en", + "o ri", + "con tro", + "st ru", + "di ven", + "di sse", + "ra to", + "no i", + "v ere", + "pu ò", + "di ce", + "s an", + "es a", + "c ci", + "se con", + "re n", + "c cio", + "qual che", + "tu tta", + "g g", + "mon do", + "for ma", + "p li", + "m ma", + "pen sa", + "de va", + "tu r", + "fo sse", + "so pra", + "ta mente", + "n ess", + "qu anto", + "ra ga", + "un que", + "ca re", + "st re", + "gran de", + "pi cco", + "guar da", + "b en", + "nel l", + "a ff", + "po ssi", + "pre sen", + "r ò", + "pa ro", + "tu a", + "v in", + "an e", + "a s", + "ste sso", + "da v", + "ne i", + "nel le", + 
"gh i", + "pi o", + "ta r", + "an a", + "la to", + "si d", + "f ine", + "f uo", + "m er", + "z o", + "qua si", + "ul ti", + "i to", + "su e", + "si e", + "f il", + "allo ra", + "m in", + "ven i", + "t ano", + "el lo", + "d e", + "r a", + "e s", + "d o", + "e n", + "q u", + "c o", + "a s", + "o s", + "e r", + "a r", + "s e", + "qu e", + "a n", + "i n", + "i s", + "t o", + "ã o", + "t e", + "d a", + "m a", + "e l", + "t a", + "o r", + "i a", + "r e", + "e m", + "a l", + "co m", + "p a", + "o u", + "c a", + "u m", + "r o", + "v a", + "t i", + "s o", + "m en", + "n ão", + "h a", + "co n", + "m e", + "r i", + "pa ra", + "p o", + "d i", + "s a", + "v o", + "u ma", + "c i", + "n a", + "p or", + "n o", + "g u", + "s u", + "h o", + "an do", + "t ra", + "e i", + "v i", + "e u", + "i m", + "do s", + "el e", + "r es", + "m o", + "en t", + "f i", + "l a", + "e ra", + "l e", + "de s", + "el a", + "men te", + "l h", + "p er", + "l i", + "ç ão", + "m as", + "t er", + "m u", + "es t", + "v e", + "g o", + "l o", + "u s", + "ma is", + "v er", + "c ê", + "in ha", + "vo cê", + "f a", + "t u", + "c u", + "p ar", + "com o", + "p ro", + "s i", + "m os", + "e c", + "p re", + "d as", + "ç a", + "es ta", + "s er", + "u n", + "da de", + "d is", + "f o", + "e x", + "c h", + "i r", + "ra n", + "t ar", + "en te", + "g a", + "t r", + "p e", + "t os", + "b o", + "c ia", + "p en", + "c ar", + "s en", + "su a", + "se m", + "c as", + "f or", + "to u", + "n os", + "te m", + "r ia", + "m es", + "se u", + "co r", + "o n", + "a o", + "p os", + "ra m", + "v el", + "é m", + "t en", + "po de", + "t es", + "esta va", + "c e", + "b a", + "qu ando", + "m i", + "qu er", + "men to", + "se gu", + "t as", + "is so", + "mu i", + "g ar", + "t ro", + "d u", + "fa z", + "õ es", + "p es", + "an to", + "l u", + "p i", + "i x", + "ve z", + "s im", + "j a", + "p r", + "m in", + "b e", + "ra s", + "m an", + "p res", + "est á", + "c er", + "b re", + "p as", + "d ia", + "m b", + "dis se", + "n i", + "r os", + "es se", + "v ia", + "o lh", + "is a", + "an te", + "ê n", + "z a", + "qu i", + "b i", + "t inha", + "me u", + "s ão", + "m inha", + "a c", + "ri o", + "m ar", + "a t", + "p el", + "mui to", + "ta l", + "to r", + "fo i", + "h or", + "j o", + "b em", + "g i", + "f al", + "vo l", + "po n", + "di z", + "l ar", + "gu n", + "m or", + "r u", + "par ec", + "ç o", + "do r", + "pes so", + "n e", + "f er", + "b er", + "p u", + "po is", + "in a", + "es p", + "d ar", + "en do", + "de n", + "so bre", + "co s", + "p ri", + "al i", + "mes mo", + "ç ões", + "g ra", + "se us", + "me i", + "b ra", + "vi da", + "an tes", + "b ri", + "at é", + "ên cia", + "lh e", + "ti v", + "m ã", + "al g", + "qu anto", + "s ó", + "g os", + "de r", + "t ão", + "tu do", + "ent ão", + "r ou", + "es s", + "in da", + "b al", + "in do", + "ci o", + "n do", + "j á", + "va m", + "re i", + "l es", + "ei to", + "v is", + "tem po", + "de pois", + "c ha", + "m el", + "ch e", + "l ha", + "a inda", + "faz er", + "con tra", + "p ou", + "per gun", + "de ix", + "ta mb", + "ra r", + "al a", + "v en", + "t in", + "pel o", + "tamb ém", + "fi ca", + "pre c", + "el es", + "tra n", + "ha via", + "l á", + "to dos", + "j u", + "qu al", + "c an", + "ta do", + "cas a", + "es sa", + "n as", + "g em", + "m em", + "se i", + "na da", + "sen ti", + "c ri", + "ó s", + "de u", + "ei ro", + ". 
.", + "f un", + "as sim", + "s ou", + "ent re", + "com e", + "i or", + "h ar", + "f e", + "por que", + "s or", + "f in", + "ta mente", + "a qui", + "cu l", + "t ó", + "for ma", + "s ar", + "ou tra", + "olh os", + "i ma", + "m im", + "a go", + "in s", + "co u", + "g ran", + "v al", + "pesso as", + "era m", + "ei ra", + "a que", + "com p", + "de i", + "p ela", + "co isa", + "m ão", + "con h", + "ca da", + "ago ra", + "ia m", + "h á", + "con s", + "su as", + "gu ém", + "o b", + "l an", + "es ti", + "á s", + "la do", + "in ter", + "ca be", + "por ta", + "n em", + "í vel", + "r is", + "j e", + "n un", + "sem pre", + "con segu", + "h as", + "tra bal", + "f u", + "le v", + "l em", + "l as", + "va i", + "tr os", + "t ante", + "te i", + "pr ó", + "que m", + "tu ra", + "on de", + "cabe ça", + "nun ca", + "men tos", + "h um", + "de le", + "ver dade", + "t á", + "h os", + "el i", + "ent es", + "m er", + "alg um", + "diz er", + "s in", + "pen as", + "n ós", + "en quanto", + "ou tro", + "l ho", + "es te", + "mel hor", + "est ar", + "g an", + "b ar", + "pri mei", + "a u", + "i u", + "pen sa", + "a penas", + "p ra", + "es tou", + "con te", + "res pon", + "ho mem", + "do is", + "a do", + "c al", + "a b", + "l os", + "ç as", + "pou co", + "sen hor", + "t ando", + "esp era", + "pa i", + "ri os", + "no i", + "i da", + "ba ix", + "as e", + "is as", + "f r", + "ho ra", + "mu ndo", + "pas sa", + "fi car", + "to do", + "se ja", + "al mente", + "â n", + "c lar", + "a d", + "in c", + "f os", + "lo n", + "g ri", + "ou vi", + "v em", + "g e", + "ta va", + "á rio", + "mo n", + "s os", + "in ho", + "ma l", + "t an", + "t re", + "gran de", + "ran do", + "b u", + "v ou", + "ê s", + "co isas", + "a conte", + "lh er", + "g en", + "ci on", + "an os", + "i do", + "tal vez", + "est ão", + "li v", + "sa b", + "su r", + "ou tros", + "c re", + "qual quer", + "g ou", + "t ri", + "l í", + "tiv esse", + "ra do", + "prec isa", + "mã e", + "su s", + "t anto", + "de la", + "men os", + "s al", + "en tra", + "p é", + "ma ior", + "noi te", + "ti va", + "p ala", + "so n", + "ra ção", + "de us", + "s as", + "un i", + "l or", + "u l", + "in te", + "f ei", + "an o", + "par ti", + "pala v", + "tr ás", + "par te", + "b el", + "ci dade", + "lu gar", + "v os", + "vez es", + "do u", + "en contra", + "tr u", + "e ci", + "a r", + "e r", + "a n", + "e n", + "i n", + "i r", + "o r", + "d e", + "a k", + "ı n", + "a l", + "d i", + "d a", + "b u", + "b ir", + "y or", + "i l", + "e k", + "y a", + "m a", + "l a", + "e l", + "u n", + "k a", + "l ar", + "i m", + "d ı", + "e t", + "o n", + "d u", + "o l", + "e y", + "t ı", + "m i", + "h a", + "b a", + "l er", + "ü n", + "m ı", + "i z", + "l e", + "ı r", + "m e", + "i s", + "n e", + "o k", + "t a", + "s a", + "u m", + "r a", + "g ö", + "i k", + "s ı", + "d en", + "e s", + "b il", + "t i", + "l ı", + "ü z", + "i ç", + "ü r", + "g i", + "u r", + "t e", + "b en", + "d an", + "i y", + "ı m", + "u z", + "v e", + "c ak", + "a y", + "c e", + "i ş", + "ın ı", + "i yor", + "ba ş", + "d ü", + "a t", + "a m", + "g el", + "de ğ", + "k ar", + "i ̇", + "m u", + "e v", + "ö y", + "bu n", + "v ar", + "ya p", + "s en", + "an a", + "s un", + "in i", + "gö r", + "y ı", + "k i", + "l i", + "ar a", + "al ı", + "on u", + "ç ı", + "ş ey", + "s ın", + "k ı", + "ka d", + "s e", + "t an", + "a ğ", + "değ il", + "s in", + "ü k", + "a z", + "ç ok", + "s on", + "ş ı", + "b i", + "ü l", + "t u", + "v er", + "iç in", + "g e", + "k en", + "ey e", + "ol du", + "mı ş", + "y e", + "k al", + "m ek", + "l an", + "öy le", + "yor du", + "er i", + 
"y üz", + "mi ş", + "b e", + "m ak", + "o la", + "in e", + "y an", + "h er", + "c ek", + "yor um", + "b ak", + "ü m", + "ö n", + "lar ı", + "o ğ", + "d er", + "kad ar", + "h al", + "ar ı", + "s t", + "s an", + "ın da", + "du r", + "g ün", + "v a", + "y ok", + "y er", + "dı m", + "k o", + "da ha", + "l u", + "ın a", + "di m", + "e m", + "bil ir", + "ik i", + "s iz", + "s i", + "n a", + "di ğ", + "s u", + "b ü", + "ha y", + "s or", + "dü ş", + "ü ç", + "un u", + "ö r", + "d ir", + "m ü", + "c a", + "am an", + "f ak", + "a da", + "e de", + "son ra", + "h iç", + "ak i", + "ğ ı", + "bu l", + "r u", + "ma z", + "an la", + "bu ra", + "ge ç", + "ma ya", + "l en", + "k onu", + "c i", + "c u", + "d in", + "t ek", + "z aman", + "el er", + "ö z", + "dı r", + "gi bi", + "o t", + "ş a", + "g er", + "ler i", + "k im", + "k u", + "fak at", + "y ar", + "gö z", + "c ı", + "yor sun", + "b ek", + "in de", + "r o", + "p ek", + "bun u", + "l ik", + "m an", + "il er", + "e di", + "ö l", + "s ür", + "b in", + "s ır", + "çı k", + "sı l", + "al ar", + "k es", + "y ak", + "ç ek", + "yı l", + "e cek", + "ı z", + "gi t", + "ka p", + "a ma", + "ı l", + "lar ın", + "b iz", + "tı r", + "o y", + "an cak", + "d oğ", + "ç a", + "b ana", + "ş im", + "baş la", + "l ü", + "ma dı", + "ben i", + "t ir", + "y ük", + "lı k", + "be ş", + "b el", + "b er", + "m er", + "na sıl", + "tı k", + "k e", + "t ür", + "a v", + ". .", + "d aki", + "p ar", + "t er", + "ce ğ", + "t en", + "z ı", + "iy i", + "d ok", + "ben im", + "c ağ", + "n er", + "y en", + "ş u", + "me z", + "düş ün", + "ken di", + "şim di", + "y ol", + "y u", + "de v", + "is te", + "s ek", + "ma m", + "s öyle", + "di k", + "t o", + "k ur", + "oldu ğ", + "s ını", + "t ar", + "bil iyor", + "k an", + "y al", + "m eye", + "mu ş", + "f a", + "ka ç", + "bil e", + "iy e", + "t ü", + "e f", + "tı m", + "ev et", + "ç o", + "y et", + "g en", + "bura da", + "t im", + "bir az", + "es i", + "k or", + "doğ ru", + "in in", + "kı z", + "di ye", + "d ör", + "et ti", + "on un", + "is ti", + "ğ i", + "h e", + "s ana", + "ü ş", + "ar ka", + "hay ır", + "kar şı", + "h ar", + "il e", + "h ak", + "ı yor", + "ne den", + "s ev", + "sı z", + "ço cu", + "me m", + "ç alı", + "ol ur", + "b ır", + "g ir", + "is e", + "i h", + "c an", + "k ır", + "d ön", + "b öyle", + "sen i", + "! 
\"", + "al t", + "dör t", + "s öy", + "o ş", + "mu sun", + "la ş", + "h an", + "i p", + "ka y", + "h em", + "bü yük", + "a ç", + "bır ak", + "mi sin", + "s öz", + "u l", + "değ iş", + "ün ü", + "g ül", + "k ö", + "kar ı", + "ta mam", + "ol u", + "r ar", + "yen i", + "la m", + "mış tı", + "ya ş", + "al a", + "in iz", + "kad ın", + "bun un", + "m ey", + "al tı", + "y i", + "s o", + "in den", + "sen in", + "ya t", + "to p", + "s er", + "is i", + "d ün", + "s es", + "hiç bir", + "y on", + "d ın", + "t ün", + "baş ka", + "a s", + "he p", + "i t", + "ir mi", + "dev am", + "ola cak", + "ar tık", + "r e", + "dur um", + "im iz", + "üz el", + "ler ini", + "sa ğ", + "p ro", + "ger ek", + "y irmi", + "ş ek", + "ba ğ", + "me di", + "lar a", + "a h", + "t ur", + "y ür", + "ma sı", + "ka tı", + "de di", + "g ü", + "sor un", + "el i", + "ün e", + "mı z", + "yap ı", + "m il", + "ğ ını", + "t ara", + "m en", + "ha t", + "var dı", + "m et", + "konu ş", + "ar ak", + "lar ak", + "çocu k", + "bü tün", + "l ey", + "d ür", + "g üzel", + "ay ı", + "yap a", + "n ı", + "ay r", + "ö ne", + "yordu m", + "b an", + "i̇ ş", + "du m", + "un a", + "on a", + "yor lar", + "lar ını", + "çı kar", + "z an", + "se ç", + "l iyor", + "t ak", + "şı k", + "tek rar", + "a ş", + "e ş", + "miş ti", + "f ar", + "k in", + "im i", + "i f", + "e ğ", + "gi di", + "le ş", + "başla dı", + "gi de", + "ot ur", + "d de", + "ın dan", + "üz er", + "ın ın", + "n ız", + "u y", + "ye di", + "ka t", + "o larak", + "la dı", + "yal nız", + "ba h", + "iy et", + "m al", + "s ak", + "a çık", + "sın da", + ".. .", + "in san", + "ay nı", + "e der", + "is tan", + "uz un", + "sa h", + "d o", + "g eri", + "er ek", + "ol an", + "ger çek", + "f en", + "al an", + "dı ş", + "alı k", + "far k", + "ü st", + "sa de", + "r i", + "k iş", + "l dı", + "z or", + "et ir", + "her kes", + "s al", + "ö mer", + "s el", + "un da", + "ha f", + "bun a", + "y dı", + "pek i", + "ada m", + "ha z", + "sın a", + "kap ı", + "gör üş", + "sade ce", + "al dı", + "gel di", + "i e", + "n ie", + "n a", + "r z", + "s z", + "c z", + "p o", + "s t", + "c h", + "i ę", + "d z", + "n i", + "a ł", + "r a", + "j e", + "r o", + "d o", + "s ię", + "z a", + "g o", + "e m", + "w i", + "c i", + "rz e", + "k o", + "l e", + "l i", + "w a", + "t o", + "k a", + "m i", + "ż e", + "t a", + "w ie", + "b y", + "m o", + "w y", + "rz y", + "ł a", + "j a", + "n o", + "ł o", + "w o", + "p a", + "m a", + "t e", + "t y", + "n y", + "k i", + "d a", + "n e", + "dz ie", + "dz i", + "cz y", + "c ie", + "m y", + "p rze", + "d y", + "o d", + "l a", + "k ie", + "r y", + "st a", + "j ą", + "ó w", + "c e", + "p rzy", + "c o", + "k u", + "m ie", + "sz y", + "cz e", + "r e", + "b a", + "s i", + "b ie", + "m u", + "w e", + "c y", + "ni a", + "ś ci", + "sz e", + "je st", + "k t", + "s a", + "b o", + "t u", + "ż y", + "n ą", + "b i", + "r u", + "a le", + "kt ó", + "p ra", + "ał a", + "m nie", + "p ie", + "ł y", + "cz a", + "ja k", + "ro z", + "r ó", + "l u", + "z na", + "g a", + "ra z", + "ł u", + "ta k", + "j u", + "p i", + "ś ć", + "s o", + "wi a", + "m ó", + "ch o", + "w szy", + "p e", + "s po", + "c a", + "g dy", + "w ał", + "w ię", + "d e", + "b e", + "p ro", + "ł em", + "j ę", + "s k", + "z e", + "l o", + "g i", + "r ę", + "do b", + "d u", + "ju ż", + "st o", + "b ę", + "ał em", + "sz a", + "m e", + "po d", + "d la", + "pa n", + "n ę", + "z o", + "mo że", + "ś li", + "s ie", + "ał o", + "t em", + "l ko", + "ny ch", + "po wie", + "c ię", + "s u", + "ty lko", + "i n", + "b u", + "na j", + "ch a", + "te go", + "p u", + "s ki", + 
"ne go", + "wszy st", + "sz cze", + "je d", + "je j", + "t wo", + "ą d", + "ś my", + "cz ę", + "wa ć", + "je go", + "ż a", + "i m", + "s y", + "pra w", + "ty m", + "któ ry", + "ał y", + "t rze", + "nie j", + "s e", + "ny m", + "i ch", + "o b", + ". .", + "g ło", + "ją c", + "mó wi", + "s ka", + "o n", + "ne j", + "s łu", + "w ła", + "bę dzie", + "d ę", + "p ó", + "be z", + "ni c", + "p ła", + "ś cie", + "mi a", + "s ą", + "t rzy", + "kie m", + "by ł", + "mo g", + "ro bi", + "ta m", + "c u", + "te n", + "m ię", + "z y", + "pe w", + "ci a", + "my ś", + "prze d", + "s ko", + "n u", + "któ re", + "a l", + "l ę", + "w sze", + "ą c", + "by ło", + "so bie", + "p y", + "ci ą", + "ba r", + "je szcze", + "h a", + "t ę", + "b ra", + "cza s", + "sz ę", + "g ł", + "k ę", + "ma r", + "cz u", + "prze z", + "f i", + "s ło", + "w z", + "k to", + "k ów", + "cz o", + "li śmy", + "st ra", + "wię c", + "r ą", + "ma m", + "w ó", + "rz a", + "g ro", + "no ści", + "f a", + "we t", + "ną ł", + "ś mie", + "na wet", + "mu si", + "s wo", + "te j", + "w ą", + "w u", + "wi ą", + "ni u", + "cz ą", + "b li", + "dz o", + "s kie", + "n em", + "je śli", + "cze go", + "ch y", + "d ł", + "ty ch", + "by m", + "ż o", + "e ś", + "si ą", + "kie dy", + "na s", + "w ró", + "dz e", + "d ro", + "t ra", + "r ów", + "pa ni", + "z ie", + "ku l", + "na d", + "ch wi", + "ni m", + "t ro", + "by ć", + "cho dzi", + "ni o", + "dob rze", + "te raz", + "wo kul", + "co ś", + "k ł", + "pie r", + "h e", + "g dzie", + "dz y", + "p ię", + "d ź", + "k ą", + "g ó", + "z da", + "ch ce", + "st ę", + "o r", + "ś wia", + "wszyst ko", + "st ro", + "pe ł", + "wie m", + "wie l", + "ka ż", + "ki m", + "rz u", + "s ły", + "jed na", + "z u", + "myś l", + "mó j", + "g u", + "wa r", + "jest em", + "ó ż", + "mie j", + "mo ż", + "k ła", + "re sz", + "d łu", + "st wo", + "n ię", + "ma sz", + "że by", + "nie m", + "ja kie", + "st y", + "ni ą", + "we j", + "o j", + "g ra", + "s ła", + "no ść", + "z ło", + "sz czę", + ".. 
.", + "r i", + "le j", + "we go", + "c ał", + "dzi ał", + "ki ch", + "dz a", + "dz ię", + "o czy", + "zo sta", + "cz ło", + "na m", + "ki l", + "o na", + "sz u", + "w ę", + "pa r", + "mi ał", + "st rze", + "ce j", + "e j", + "zna j", + "da ć", + "miej s", + "k ró", + "k ry", + "bar dzo", + "si a", + "z i", + "ś nie", + "l ą", + "g ie", + "cie bie", + "d ni", + "st u", + "po trze", + "wokul ski", + "u wa", + "u mie", + "jedna k", + "k ra", + "wró ci", + "czło wie", + "czy ć", + "by ła", + "że li", + "m ę", + "c ę", + "z robi", + "mog ę", + "pro wa", + "r em", + "nie ch", + "cz nie", + "k ro", + "t ą", + "ch ci", + "b ro", + "dzie ć", + "sz ą", + "pa d", + "t rz", + "t ru", + "je m", + "a ni", + "t ów", + "a r", + "d ru", + "ta j", + "rze kł", + "sa m", + "st e", + "nie go", + "ta kie", + "w ała", + "to wa", + "ka pła", + "wi dzi", + "po dob", + "dz ę", + "t ał", + "stę p", + "b ą", + "po ko", + "w em", + "g ę", + "a by", + "g e", + "al bo", + "s pra", + "z no", + "de n", + "s mo", + "je sz", + "k się", + "jest eś", + "po z", + "ni gdy", + "k sią", + "c óż", + "w s", + "po w", + "t ka", + "ś wie", + "sz ka", + "sa mo", + "s ł", + "rz ę", + "na le", + "chce sz", + "ni k", + "p ę", + "chy ba", + "cią g", + "ją cy", + "wo j", + "na sze", + "mnie j", + "wię cej", + "z wy", + "o sta", + "f e", + "wa ż", + "h o", + "se r", + "śmie r", + "wie r", + "dz ą", + "za ś", + "gdy by", + "ja ki", + "wo l", + "wi n", + "d ą", + "ści a", + "roz ma", + "wa l", + "pa nie", + "sta r", + "ka z", + "je żeli", + "d em", + "w ra", + "ko ń", + "sie bie", + "zno wu", + "p ró", + "cz em", + "st wa", + "i sto", + "pó ł", + "d ał", + "ko bie", + "ała m", + "wy ch", + "ce sa", + "ni ch", + "za wsze", + "dzi ć", + "te ż", + "le pie", + "pro szę", + "k re", + "t wa", + "o t", + "ł ą", + "ch u", + "c ą", + "p rz", + "ł e", + "sze dł", + "od powie", + "my śli", + "ś wią", + "e n", + "e r", + "d e", + "a n", + "e t", + "i j", + "i n", + "e l", + "a a", + "s t", + "o r", + "g e", + "i s", + "a t", + "i e", + "c h", + "o n", + "e en", + "h et", + "i t", + "v er", + "aa r", + "a l", + "o or", + "g en", + "v an", + "o p", + "d en", + "h e", + "o m", + "t e", + "w e", + "i k", + "r e", + "z e", + "ij n", + "d at", + "b e", + "d er", + "in g", + "o e", + "ij k", + "a an", + "ch t", + "v oor", + "l e", + "i et", + "r o", + "m o", + "k en", + "z ijn", + "m en", + "i g", + "j e", + "n iet", + "a r", + "o o", + "i d", + "u n", + "i l", + "s ch", + "mo et", + "st e", + "u r", + "o l", + "he b", + "u it", + "g el", + "w ij", + "a s", + "m e", + "t en", + "w or", + "o u", + "v en", + "l en", + "aa t", + "d it", + "m et", + "r a", + "b en", + "s p", + "o ver", + "d ie", + "n o", + "w er", + "l ijk", + "f t", + "s l", + "an d", + "v e", + "t er", + "i er", + "i en", + "t o", + "d aar", + "g r", + "b el", + "de ze", + "d u", + "a g", + "k an", + "wor den", + "in gen", + "moet en", + "n en", + "on der", + "heb ben", + "r u", + "oo k", + "s en", + "c t", + "k t", + "no g", + "aa l", + "w as", + "u l", + "e er", + "b ij", + "m ijn", + "p ro", + "v ol", + "d o", + "k om", + "at ie", + "e ft", + "k el", + "al s", + "r ij", + "he id", + "a f", + "st el", + "m aar", + "a p", + "we e", + "a d", + "he eft", + "w aar", + "i cht", + "d an", + "er en", + "n e", + "w el", + "w at", + "w il", + "a cht", + "aa g", + "ge b", + "c on", + "z o", + "k e", + "b et", + "h ij", + "d ig", + "k un", + "u w", + "d t", + "d oor", + "t ij", + "a m", + "an g", + "on d", + "er s", + "is ch", + "ge en", + "i ge", + "ge v", + "ve el", + "n u", + "m a", + "on s", + "o f", 
+ "b l", + "n aar", + "g ro", + "p l", + "an der", + "at en", + "kun nen", + "e cht", + "h ier", + "g oe", + "an t", + "u s", + "t wee", + "on t", + "de lijk", + "el e", + "u ur", + "al le", + "t oe", + "me er", + "i st", + "n a", + "n ie", + "on ze", + "l o", + "i m", + "p en", + "h ad", + "tij d", + "h oe", + "to t", + "z ou", + "a k", + "aa k", + "a men", + "d r", + "w oor", + "s e", + "wor dt", + "o t", + "gel ijk", + "g aan", + "i c", + "g er", + "k er", + "el d", + "e m", + "h ou", + "de l", + "z en", + "z el", + "te gen", + "b o", + "kom en", + "c om", + "i gen", + "e it", + "wer k", + "goe d", + "z al", + "z ij", + "sl ag", + "e s", + "z ien", + "a st", + "echt er", + "it ie", + "t ie", + "el ijk", + "m is", + "isch e", + "bel an", + "h aar", + "i ch", + "b er", + "h an", + "v r", + "al e", + "c i", + "gr ijk", + "in d", + "do en", + "l and", + "belan grijk", + "p un", + "op en", + "ct ie", + "zel f", + "m ij", + "it eit", + "ste m", + "me e", + "ar en", + "al l", + "b r", + "re cht", + "d ien", + "h u", + "g aat", + "pro b", + "m oe", + "p er", + "a u", + "ul len", + "z ich", + "daar om", + "or m", + "k l", + "v o", + "en t", + "st aat", + "z it", + "du i", + "n at", + "du s", + "d s", + "ver slag", + "kel ijk", + "prob le", + "w et", + "ge m", + "c r", + "i on", + "p r", + "sch ap", + "g d", + "h un", + "z a", + "er d", + "z et", + "st aan", + "st r", + "m aal", + "in der", + "e id", + "st en", + "p ar", + "k ken", + "ge d", + "z ullen", + "re s", + "men sen", + "j aar", + "re gel", + "ie der", + "vol gen", + "ge ven", + "e ven", + "l u", + "bl ij", + "i ë", + "k o", + "u we", + "m an", + "ma ken", + "l ie", + "g a", + "oe k", + "nie uwe", + "b aar", + "h o", + "h er", + "in ter", + "ander e", + "ru ik", + "s u", + "a gen", + "or t", + "m er", + "ou w", + "st er", + "wil len", + "aa kt", + "h oo", + "an den", + "f f", + "l ig", + "t re", + "s amen", + "ze er", + "dui delijk", + "ant woor", + "he el", + "men t", + "pun t", + "hou den", + "we g", + "vr aag", + "gel e", + "een s", + "be sch", + "om en", + "er g", + "do el", + "d ag", + "sp e", + "ur en", + "ing s", + "or en", + "l ang", + "de len", + "m ar", + "ste un", + "in nen", + "p ol", + "o on", + "i de", + "s n", + "s ie", + "r icht", + "z onder", + "no dig", + "all een", + "m id", + "ra gen", + "iet s", + "ver sch", + "geb ruik", + "st u", + "ro uw", + "stel len", + "be g", + "men ten", + "v in", + "eer ste", + "l aat", + "gro ot", + "oo d", + "to ch", + "l aten", + "aar d", + "s le", + "de el", + "st and", + "pl aat", + "re e", + "bet re", + "d i", + "l id", + "uit en", + "ra cht", + "bel eid", + "g et", + "ar t", + "st ie", + "st aten", + "g gen", + "re ken", + "e in", + "al en", + "m ing", + "mo gelijk", + "gro te", + "al tijd", + "z or", + "en kel", + "w ik", + "pol itie", + "e igen", + "el k", + "han del", + "g t", + "k we", + "m aat", + "el en", + "i p", + "v rij", + "s om", + "je s", + "aa m", + "hu is", + "v al", + "we er", + "lid staten", + "k ing", + "k le", + "be d", + "gev al", + "stel l", + "a i", + "wik kel", + "kwe stie", + "t al", + "ste e", + "a b", + "h el", + "kom st", + "p as", + "s s", + "it u", + "i den", + "eer d", + "m in", + "c e", + "p o", + "twee de", + "proble em", + "w aren", + "us sen", + "sn el", + "t ig", + "ge w", + "j u", + "ul t", + "ne men", + "com mis", + "versch il", + "k on", + "z oek", + "k rij", + "gr aag", + "den k", + "l anden", + "re den", + "be sl", + "oe g", + "bet er", + "he den", + "m ag", + "p e", + "bo ven", + "a c", + "con t", + "f d", + "h ele", + "k r", + "v ier", + "w 
in", + "ge z", + "k w", + "m il", + "v or", + "he m", + "ra m", + "aa s", + "ont wikkel", + "dr ie", + "v aak", + "plaat s", + "l a", + "g ang", + "ij f", + "f in", + "nat uur", + "t ussen", + "u g", + "in e", + "d a", + "b at", + "kom t", + "w acht", + "aa d", + "u t", + "é n", + "acht er", + "geb ie", + "ver k", + "lig t", + "c es", + "nie uw", + "van d", + "s t", + "n í", + "j e", + "p o", + "c h", + "r o", + "n a", + "s e", + "t o", + "n e", + "l e", + "k o", + "l a", + "d o", + "r a", + "n o", + "t e", + "h o", + "n ě", + "v a", + "l i", + "l o", + "ř e", + "c e", + "d e", + "v e", + "b y", + "n i", + "s k", + "t a", + "n á", + "z a", + "p ro", + "v o", + "v ě", + "m e", + "v á", + "s o", + "k a", + "r á", + "v y", + "z e", + "m i", + "p a", + "t i", + "st a", + "m ě", + "n é", + "ř i", + "ř í", + "m o", + "ž e", + "m a", + "j í", + "v ý", + "j i", + "d ě", + "r e", + "d a", + "k u", + "j a", + "c i", + "r u", + "č e", + "o b", + "t ě", + "m u", + "k y", + "d i", + "š e", + "k é", + "š í", + "t u", + "v i", + "p ře", + "v í", + "s i", + "n ý", + "o d", + "so u", + "v é", + "n y", + "r i", + "d y", + "b u", + "b o", + "t y", + "l á", + "l u", + "n u", + "ž i", + "m á", + "st i", + "c í", + "z á", + "p ra", + "sk é", + "m í", + "c o", + "d u", + "d á", + "by l", + "st o", + "s a", + "t í", + "je d", + "p ří", + "p ři", + "t é", + "s í", + "č i", + "v ní", + "č a", + "d í", + "z i", + "st u", + "p e", + "b a", + "d ní", + "ro z", + "va l", + "l í", + "s po", + "k á", + "b e", + "p i", + "no u", + "ta k", + "st e", + "r y", + "l é", + "vě t", + "se m", + "p ě", + "ko n", + "ne j", + "l y", + "ko u", + "ý ch", + "b ě", + "p r", + "f i", + "p rá", + "a le", + "ja ko", + "po d", + "ž í", + "z í", + "j sou", + "j sem", + "ch o", + "l ní", + "c ké", + "t á", + "m y", + "a k", + "h u", + "va t", + "pře d", + "h la", + "k e", + "st á", + "č í", + "š i", + "s le", + "k la", + "š tě", + "lo u", + "m ů", + "z na", + "ch á", + "o r", + "p ů", + "h a", + "b i", + "ta ké", + "d ů", + "no st", + "t ře", + "te r", + "p u", + "i n", + "v r", + "ve l", + "sk u", + "v še", + "t ní", + "do b", + "by la", + "č ní", + "ja k", + "v u", + "je ho", + "b ý", + "vá ní", + "ný ch", + "po u", + "te n", + "t ři", + "v z", + "st ře", + "d va", + "h le", + "č á", + "no sti", + "c k", + "v š", + "vo u", + "s u", + "h e", + "h ra", + "je n", + "s y", + "da l", + "po z", + "s lo", + "te l", + "d ru", + "de n", + "vš ak", + "g i", + "k dy", + "by lo", + "bu de", + "st ra", + "j ší", + "m é", + "me n", + "vý ch", + "ní m", + "s m", + "ko li", + "r ů", + "t ra", + "mů že", + "ne ní", + "ho d", + "b í", + "do u", + "sk a", + "t ý", + "st ě", + "u je", + "s á", + "pě t", + "ne s", + "k rá", + "to m", + "st ví", + "v ně", + "se d", + "s vé", + "p í", + "z o", + "mu sí", + "u ž", + "tí m", + "jí cí", + "jed no", + "t r", + "ča s", + "e v", + "č ty", + "sk ý", + "ni c", + "ev ro", + "to ho", + "h y", + "k ter", + "r ní", + "st í", + "s vě", + "pa k", + "vše ch", + "k ů", + "n g", + "á d", + "chá zí", + "a ni", + "a r", + "jed na", + "bý t", + "t ro", + "k ra", + "pr vní", + "m no", + "ské ho", + "p á", + "p la", + "le m", + "ne bo", + "ke m", + "st ro", + "s la", + "né ho", + "z de", + "dal ší", + "ř a", + "čty ři", + "h rá", + "dru h", + "l ně", + "v la", + "sk ých", + "š ko", + "pů so", + "pro to", + "v ů", + "sk á", + "ve n", + "še st", + "d ně", + "je ště", + "me zi", + "te k", + "s ko", + "ch a", + "ně koli", + "be z", + "g ra", + "ji ž", + "č ně", + "j á", + "s lu", + "z ná", + "ve r", + "sed m", + "k ro", + "ta m", + "a 
no", + "v lá", + "o sm", + "byl y", + "vá m", + "ck ý", + "te ch", + "dě ji", + "vel mi", + "le ži", + "va la", + "l ý", + "t vo", + "spo le", + "ch u", + "stu p", + "mo ž", + "evro p", + "g e", + "sta l", + "j de", + "ch y", + "ro di", + "je jí", + "po li", + "de vět", + "s me", + "a ž", + "té to", + "re m", + "d é", + "f or", + "u ni", + "f o", + "ten to", + "a u", + "ka ž", + "nu la", + "na d", + "by ch", + "mo c", + "sto u", + "e x", + "le n", + "k do", + "z d", + "pra co", + "to mu", + "ný m", + "ži vo", + "ze m", + "f e", + "f u", + "ná sle", + "j o", + "sk y", + "ji ch", + "h á", + "mě l", + "dě la", + "j sme", + "p re", + "ni ce", + "ste j", + "ne m", + "st ní", + "he m", + "ná ro", + "z u", + "b li", + "ni t", + "pa r", + "a l", + "poz ději", + "ta ko", + "n ce", + "če r", + "ší m", + "ně co", + "vá l", + "ře j", + "krá t", + "á lní", + "u r", + ". .", + "a si", + "kter é", + "sta v", + "ma jí", + "my s", + "do bě", + "s ně", + "ce n", + "z y", + "z ku", + "t ů", + "ch od", + "s pě", + "je jich", + "sou čas", + "d r", + "va li", + "ri e", + "k te", + "pr ů", + "ze ní", + "pa t", + "a n", + "po tře", + "de m", + "d nes", + "ze mí", + "sa mo", + "zna m", + "b ra", + "má m", + "te dy", + "g o", + "hla vní", + "pou ží", + "b ní", + "ve de", + "le p", + "je k", + "pra v", + "poli ti", + "d ne", + "je m", + "le t", + "če ní", + "pro b", + "ne ž", + "dě l", + "fi l", + "č o", + "cí ch", + "st é", + "d lou", + "h i", + "a by", + "to u", + "několi k", + "d la", + "vy u", + "vi t", + "ho u", + "ck ých", + "no vé", + "či n", + "st y", + "dě lá", + "k ý", + "ob la", + "pod le", + "ra n", + "dů leži", + "ta to", + "po ku", + "ko ne", + "d ý", + "d vě", + "ž ád", + "nou t", + "t ku", + "t vr", + "cké ho", + "ro v", + "r é", + "te le", + "p sa", + "s vět", + "ti vní", + "do sta", + "te m", + "še l", + "druh é", + "s kou", + "ž o", + "jed ná", + "vý znam", + "prob lé", + "pu bli", + "vá n", + "od po", + "pod po", + "d le", + "ja ké", + "še ní", + "ví m", + "bě hem", + "na chází", + "s lou", + "pou ze", + "o tá", + "p lo", + "to vé", + "vět ši", + "ko mi", + "va jí", + "ty to", + "zá pa", + "z mě", + "mo h", + "ví ce", + "spole č", + "au to", + "pro ti", + "st ru", + "dě t", + "chá ze", + "že l", + "с т", + "е н", + "н о", + "н а", + "п р", + "т о", + "п о", + "р а", + "г о", + "к о", + "н е", + "в о", + "в а", + "е т", + "е р", + "н и", + "е л", + "и т", + "н ы", + "з а", + "р о", + "ен и", + "к а", + "л и", + "е м", + "д а", + "о б", + "л а", + "д о", + "с я", + "т ь", + "о т", + "л о", + "л ь", + "е д", + "с о", + "м и", + "р е", + "м о", + "ц и", + "пр о", + "т а", + "э то", + "к и", + "р у", + "пр и", + "т и", + "с е", + "ст а", + "в ы", + "м ы", + "в и", + "б ы", + "м а", + "е с", + "л я", + "ст и", + "л е", + "ч то", + "м е", + "р и", + "ч а", + "о д", + "е й", + "ел ь", + "ени я", + "г а", + "н у", + "с и", + "п а", + "ра з", + "б о", + "ст о", + "с у", + "с а", + "д у", + "е го", + "е ст", + "и н", + "ит ь", + "и з", + "ж е", + "м у", + "п ер", + "по д", + "ени е", + "с ь", + "к у", + "пр ед", + "но го", + "ны х", + "в ер", + "т е", + "но й", + "ци и", + "д е", + "р ы", + "д ел", + "л ю", + "в е", + "о н", + "м ен", + "г и", + "н я", + "б у", + "пр а", + "в се", + "ет ся", + "ст ь", + "ж а", + "до л", + "ж и", + "б е", + "ко н", + "с л", + "ш и", + "д и", + "ст в", + "с ко", + "ны е", + "ч и", + "ю т", + "д ер", + "ст ра", + "т ы", + "х од", + "щ и", + "з о", + "з на", + "но сти", + "ч ес", + "в ля", + "ва ть", + "о р", + "по л", + "в ет", + "та к", + "ш а", + "т у", + "с во", + "пр е", + 
"о на", + "ит ель", + "ны й", + "с ло", + "ка к", + "в л", + "но сть", + "х о", + "мо ж", + "п е", + "д ля", + "ни я", + "но е", + "ра с", + "дол ж", + "да р", + "т ель", + "с ка", + "п у", + "ст во", + "ко то", + "ра б", + "е е", + "ро д", + "э ти", + "с об", + "о ру", + "ж ен", + "ны м", + "ит и", + "ни е", + "ко м", + "д ет", + "ст у", + "г у", + "п и", + "ме ж", + "ени ю", + "т ер", + "раб от", + "во з", + "ци я", + "ко й", + "щ ест", + "г ра", + "з и", + "р я", + "меж ду", + "ст ва", + "в с", + "ел о", + "ш е", + "м ер", + "б а", + "з ы", + "л у", + "а ль", + "д ей", + "г ла", + "на род", + "к ти", + "пред ста", + "л ся", + "я вля", + "с ки", + "но в", + "ед ин", + "ро в", + "и с", + "ни ма", + "р ем", + "ход и", + "так же", + "д ру", + "а ть", + "сл ед", + "го во", + "на я", + "ю щи", + "ен ь", + "кото ры", + "х от", + "в у", + "и х", + "ем у", + "ч ит", + "ва ж", + "ор га", + "чес ки", + "щ е", + "к е", + "х а", + "по с", + "то м", + "бо ль", + "м не", + "па с", + "об ъ", + "пра в", + "кон ф", + "сл у", + "под дер", + "ст ви", + "на ш", + "ль ко", + "сто я", + "ну ю", + "л ем", + "ен ных", + "к ра", + "д ы", + "между народ", + "г да", + "не об", + "го су", + "ств у", + "ени и", + "госу дар", + "к то", + "и м", + "ч ест", + "р ет", + "во про", + "л ен", + "ел и", + "ро ва", + "ци й", + "на м", + "это й", + "ж ения", + "необ ходи", + "мен я", + "бы ло", + "си ли", + "ф и", + "в я", + "ш ь", + "это го", + "о ни", + "орга ни", + "бе зо", + "пр об", + "и ме", + "ре ш", + "б и", + "безо пас", + "ют ся", + "о ста", + "ен но", + "го д", + "ел а", + "предста в", + "ть ся", + "сло во", + "органи за", + "долж ны", + "это м", + "б ла", + "ч е", + "ч у", + "бла го", + "это му", + "в рем", + "с пе", + "но м", + "ени й", + "с по", + "на с", + "не т", + "з у", + "в ед", + "е ще", + "ска за", + "се й", + "ер ен", + "да н", + "са м", + "ел я", + "ра н", + "зы ва", + "явля ется", + "бу дет", + "кти в", + "т ре", + "дел е", + "м от", + "конф ерен", + "ла сь", + "ча с", + "сто ро", + "ко го", + "е з", + "не й", + "о с", + "ли сь", + "раз ору", + "пер е", + "с си", + "ны ми", + "про ц", + "го ло", + "ч ело", + "бо ле", + "чело ве", + "с ер", + "п л", + "ч ет", + "стра н", + "п я", + "бы л", + "к ла", + "то в", + "ж д", + "дел а", + "е ра", + "у же", + "со вет", + "г ен", + "безопас ности", + "ц а", + "се да", + "по з", + "от вет", + "проб лем", + "на ко", + "т ем", + "до ста", + "п ы", + "щ а", + "во й", + "су щест", + "необходи мо", + "бы ть", + "мож ет", + "д ем", + "что бы", + "е к", + "ч ер", + "у сили", + "ре с", + "ру д", + "един енных", + "д об", + "до сти", + "ств ен", + "я дер", + "год ня", + "ка за", + "се годня", + "сей час", + "то лько", + "во д", + "ес ь", + "м ного", + "бу ду", + "е в", + "ест ь", + "т ри", + "об щест", + ". 
.", + "я вл", + "вы сту", + "р ед", + "с чит", + "с ит", + "деле га", + "ло ж", + "это т", + "ф ор", + "к лю", + "воз мож", + "ва ния", + "б ли", + "и ли", + "в з", + "на ций", + "ско го", + "при ня", + "п ла", + "о ч", + "ить ся", + "ст е", + "на ши", + "которы е", + "а р", + "име ет", + "с от", + "зна ч", + "пер ь", + "след у", + "ен ы", + "та ки", + "объ единенных", + "ст ро", + "те перь", + "б ле", + "благо дар", + "раз в", + "а н", + "жи ва", + "оч ень", + "я т", + "бе з", + "об ес", + "г ро", + "ло сь", + "с ы", + "организа ции", + "ч лен", + "то го", + "она ль", + "ж да", + "все х", + "с вя", + "боле е", + "со в", + "ко гда", + "во т", + "к ре", + "к ры", + "по этому", + "во ль", + "о й", + "ген ера", + "ч ем", + "л ы", + "пол ити", + "в ен", + "конферен ции", + "проц ес", + "б я", + "ит е", + "от но", + "разв ити", + "а ф", + "ю щ", + "в но", + "ми р", + "ни и", + "ка я", + "а с", + "итель но", + "в то", + "ени ем", + "генера ль", + "пр от", + "вс ем", + "сам бле", + "ас самбле", + "о м", + "з д", + "с мот", + "ре ги", + "ч его", + "од нако", + "усили я", + "дей стви", + "ч но", + "у ча", + "об раз", + "во с", + "э та", + "пер его", + "гово р", + "ва м", + "мо ло", + "врем я", + "д ь", + "хот ел", + "г ру", + "за явл", + "пре доста", + "по ль", + "не е", + "ре зо", + "перего во", + "резо лю", + "к рет", + "поддер ж", + "обес пе", + "не го", + "представ ит", + "на де", + "к ри", + "ч ь", + "про ек", + "л ет", + "дру ги", + "ا ل", + "َ ا", + "و َ", + "ّ َ", + "ِ ي", + "أ َ", + "ل َ", + "ن َ", + "ال ْ", + "ه ُ", + "ُ و", + "م ا", + "ن ْ", + "م ن", + "ع َ", + "ن ا", + "ل ا", + "م َ", + "ت َ", + "ف َ", + "أ ن", + "ل ي", + "م ِ", + "ا ن", + "ف ي", + "ر َ", + "ي َ", + "ه ِ", + "م ْ", + "ق َ", + "ب ِ", + "ل ى", + "ي ن", + "إ ِ", + "ل ِ", + "و ا", + "ك َ", + "ه ا", + "ً ا", + "م ُ", + "و ن", + "ال م", + "ب َ", + "ي ا", + "ذ ا", + "س ا", + "ال ل", + "م ي", + "ي ْ", + "ر ا", + "ر ي", + "ل ك", + "م َا", + "ن َّ", + "ل م", + "إ ن", + "س ت", + "و م", + "ّ َا", + "ل َا", + "ه م", + "ّ ِ", + "ك ُ", + "ك ان", + "س َ", + "ب ا", + "د ي", + "ح َ", + "ع ْ", + "ب ي", + "ال أ", + "و ل", + "ف ِي", + "ر ِ", + "د ا", + "مِ نْ", + "ُو نَ", + "و ْ", + "ه َا", + "ّ ُ", + "ال س", + "ال َ", + "ن ي", + "ل ْ", + "ت ُ", + "ه ل", + "ر ة", + "د َ", + "س ْ", + "ت ِ", + "ن َا", + "ر ْ", + "الل َّ", + "سا مي", + "ك ن", + "ك ل", + "ه َ", + "عَ لَ", + "ع لى", + "م ع", + "إ لى", + "ق د", + "ال ر", + "ُو ا", + "ي ر", + "ع ن", + "ي ُ", + "ن ِ", + "ب ْ", + "ال ح", + "هُ مْ", + "ق ا", + "ذ ه", + "ال ت", + "ِي نَ", + "ج َ", + "ه ذا", + "ع د", + "ال ع", + "د ْ", + "قَ الَ", + "ر ُ", + "ي م", + "ي ة", + "ن ُ", + "خ َ", + "ر ب", + "ال ك", + "و َا", + "أ نا", + "ة ِ", + "ال ن", + "ح د", + "ع ِ", + "ت ا", + "ه و", + "ف ا", + "ع ا", + "ال ش", + "ل ُ", + "ي ت", + "ذ َا", + "ي ع", + "ال ذ", + "ح ْ", + "ال ص", + "إِ نَّ", + "ج ا", + "ع لي", + "ك َا", + "ب ُ", + "ت ع", + "و ق", + "م ل", + "ل َّ", + "ي د", + "أ خ", + "ر ف", + "ت ي", + "ال ِ", + "ّ ا", + "ذ لك", + "أَ نْ", + "س ِ", + "ت وم", + "م ر", + "مَ نْ", + "ب ل", + "ال ق", + "الل ه", + "ِي َ", + "ك م", + "ذ َ", + "ع ل", + "ح ب", + "س ي", + "ع ُ", + "ال ج", + "ال د", + "ش َ", + "ت ك", + "ف ْ", + "ص َ", + "ل ل", + "د ِ", + "ب ر", + "ف ِ", + "ت ه", + "أ ع", + "ت ْ", + "ق ْ", + "الْ أَ", + "ئ ِ", + "عَ نْ", + "و ر", + "ح ا", + "ال َّ", + "م ت", + "ف ر", + "د ُ", + "ه نا", + "وَ أَ", + "ت ب", + "ة ُ", + "أ ي", + "س ب", + "ري د", + "و ج", + "كُ مْ", + "ح ِ", + "ك ْ", + "د ر", + "َا ء", + "ه ذه", + "ال ط", + "الْ مُ", + "د ة", + "ق ل", + "غ َ", + "ي وم", + "الَّ ذ", + "ك ر", + "ت ر", + "ك 
ِ", + "ك ي", + "عَلَ ى", + "رَ ب", + "ع ة", + "ق ُ", + "ج ْ", + "ف ض", + "ل ة", + "ه ْ", + "ر َا", + "وَ لَ", + "الْ مَ", + "أَ نَّ", + "ي َا", + "أ ُ", + "ش ي", + "اللَّ هُ", + "لَ ى", + "ق ِ", + "أ ت", + "عَلَ يْ", + "اللَّ هِ", + "ال ب", + "ض َ", + "ة ً", + "ق ي", + "ا ر", + "ب د", + "خ ْ", + "سْ تَ", + "ط َ", + "قَ دْ", + "ذه ب", + "أ م", + "ما ذا", + "وَ إِ", + "ة ٌ", + "و نَ", + "لي لى", + "و لا", + "ح ُ", + "ه ي", + "ص ل", + "ال خ", + "و د", + "لي س", + "ل دي", + "ق ال", + "كَا نَ", + "م َّ", + "ح ي", + "ت م", + "ل ن", + "وَ لَا", + "ب ع", + "يم كن", + "س ُ", + "ة َ", + "ح ت", + "ر ًا", + "ك ا", + "ش ا", + "هِ مْ", + "لَ هُ", + "ز َ", + "دا ً", + "م س", + "ك ث", + "الْ عَ", + "ج ِ", + "ص ْ", + "ف َا", + "ل ه", + "و ي", + "ع َا", + "هُ وَ", + "ب ِي", + "ب َا", + "أ س", + "ث َ", + "ل ِي", + "ر ض", + "الر َّ", + "لِ كَ", + "ت َّ", + "ف ُ", + "ق ة", + "ف عل", + "مِ ن", + "ال آ", + "ث ُ", + "س م", + "م َّا", + "بِ هِ", + "ت ق", + "خ ر", + "ل قد", + "خ ل", + "ش ر", + "أن ت", + "ل َّا", + "س ن", + "الس َّ", + "الذ ي", + "س َا", + "و ما", + "ز ل", + "و ب", + "أ ْ", + "إ ذا", + "ر ِي", + "ح ة", + "ن ِي", + "الْ حَ", + "وَ قَالَ", + "ب ه", + "ة ٍ", + "س أ", + "ر ٌ", + "ب ال", + "م ة", + "ش ْ", + "و ت", + "عن د", + "ف س", + "بَ عْ", + "ه ر", + "ق ط", + "أ ح", + "إن ه", + "و ع", + "ف ت", + "غ ا", + "هنا ك", + "ب ت", + "مِ نَ", + "س ر", + "ذَ لِكَ", + "ر س", + "حد ث", + "غ ْ", + "ّ ِي", + "ال إ", + "وَ يَ", + "ج ل", + "ا ست", + "ق ِي", + "ع ب", + "و س", + "ي ش", + "الَّذ ِينَ", + "تا ب", + "د ِي", + "ج ب", + "ك ون", + "ب ن", + "ال ث", + "لَ يْ", + "ب عد", + "وَ الْ", + "فَ أَ", + "ع م", + "هُ م", + "ت ن", + "ذ ْ", + "أ ص", + "أ ين", + "رَب ِّ", + "الذ ين", + "إِ ن", + "ب ين", + "ج ُ", + "عَلَيْ هِ", + "ح َا", + "ل و", + "ست ط", + "ظ ر", + "لَ مْ", + "ء ِ", + "كُ ل", + "ط ل", + "ت َا", + "ض ُ", + "كن ت", + "ل ًا", + "م ٌ", + "ق بل", + "ـ ـ", + "ذ ِ", + "قَ وْ", + "ص ِ", + "م ًا", + "كان ت", + "ص ا", + "ي ق", + "ال ف", + "ال نا", + "م ٍ", + "إِ نْ", + "ال نَّ", + "ج د", + "وَ مَا", + "ت ت", + "ب ح", + "م كان", + "كي ف", + "ّ ة", + "ال ا", + "ج َا", + "أ و", + "سا عد", + "ض ِ", + "إ لا", + "را ً", + "ق َا", + "ر أ", + "ع ت", + "أ حد", + "ه د", + "ض ا", + "ط ر", + "أ ق", + "ما ء", + "د َّ", + "ال با", + "م ُو", + "أَ وْ", + "ط ا", + "ق ُو", + "خ ِ", + "ت ل", + "ستط يع", + "د َا", + "الن َّا", + "إ لَى", + "وَ تَ", + "هَ ذَا", + "ب ة", + "علي ك", + "ج ر", + "ال من", + "ز ا", + "ر ٍ", + "د ع", + "ّ ًا", + "س ة", + "ثُ مَّ", + "شي ء", + "ال غ", + "ت ح", + "ر ُونَ", + "ال يوم", + "م ِي", + "ن ُوا", + "أ ر", + "تُ مْ", + "ع ر", + "ي ف", + "أ ب", + "د ًا", + "ص َا", + "الت َّ", + "أ ريد", + "ال ز", + "يَ وْ", + "إ لي", + "ج ي", + "يَ عْ", + "فض ل", + "ال إن", + "أن ه", + "n g", + "i 4", + "a n", + "s h", + "z h", + "i 2", + "ng 1", + "u 4", + "i 1", + "ng 2", + "d e", + "j i", + "a o", + "x i", + "u 3", + "de 5", + "e 4", + "i 3", + "ng 4", + "an 4", + "e n", + "u o", + "sh i4", + "an 2", + "u 2", + "c h", + "u 1", + "ng 3", + "a 1", + "an 1", + "e 2", + "a 4", + "e i4", + "o ng1", + "a i4", + "ao 4", + "h u", + "a ng1", + "l i", + "y o", + "an 3", + "w ei4", + "uo 2", + "n 1", + "en 2", + "ao 3", + "e 1", + "y u", + "q i", + "e ng2", + "zh o", + "a ng3", + "a ng4", + "a ng2", + "uo 4", + "m i", + "g e4", + "y i1", + "g uo2", + "e r", + "b i", + "a 3", + "h e2", + "e 3", + "y i2", + "d i4", + "zh ong1", + "b u4", + "g u", + "a i2", + "n 2", + "z ai4", + "sh i2", + "e ng1", + "r en2", + "o ng2", + "xi an4", + "y i", + "x u", + "n 4", + "l i4", + "en 4", + "y u2", + "e i2", + "yi2 ge4", + "o u4", + "e i3", 
+ "d i", + "u i4", + "a 2", + "yo u3", + "ao 1", + "d a4", + "ch eng2", + "en 1", + "e ng4", + "y i4", + "s i1", + "zh i4", + "ji a1", + "yu an2", + "n i", + "t a1", + "de5 yi2ge4", + "k e1", + "sh u3", + "x i1", + "j i2", + "ao 2", + "t i", + "o u3", + "o ng4", + "xi a4", + "a i1", + "g ong1", + "zh i1", + "en 3", + "w ei2", + "j u", + "xu e2", + "q u1", + "zho u1", + "er 3", + "mi ng2", + "zho ng3", + "l i3", + "w u4", + "y i3", + "uo 1", + "e 5", + "j i4", + "xi ng2", + "ji an4", + "hu a4", + "y u3", + "uo 3", + "j i1", + "a i3", + "z uo4", + "h ou4", + "hu i4", + "e i1", + "ni an2", + "q i2", + "p i", + "d ao4", + "sh eng1", + "de 2", + "d ai4", + "u an2", + "zh e4", + "zh eng4", + "b en3", + "sh ang4", + "zh u3", + "b ei4", + "y e4", + "ch u1", + "zh an4", + "l e5", + "l ai2", + "sh i3", + "n an2", + "r en4", + "yo u2", + "k e4", + "b a1", + "f u4", + "d ui4", + "y a4", + "m ei3", + "z i4", + "xi n1", + "ji ng1", + "zh u", + "n 3", + "yo ng4", + "m u4", + "ji ao4", + "y e3", + "ji n4", + "bi an4", + "l u4", + "q i1", + "sh e4", + "xi ang1", + "o ng3", + "sh u4", + "d ong4", + "s uo3", + "gu an1", + "s an1", + "b o", + "t e4", + "d uo1", + "f u2", + "mi n2", + "l a1", + "zh i2", + "zh en4", + "o u1", + "w u3", + "m a3", + "i 5", + "z i5", + "j u4", + "er 4", + "y ao4", + "xia4 de5yi2ge4", + "s i4", + "t u2", + "sh an1", + "z ui4", + "ch u", + "yi n1", + "er 2", + "t ong2", + "d ong1", + "y u4", + "y an2", + "qi an2", + "shu3 xia4de5yi2ge4", + "ju n1", + "k e3", + "w en2", + "f a3", + "l uo2", + "zh u4", + "x i4", + "k ou3", + "b ei3", + "ji an1", + "f a1", + "di an4", + "ji ang1", + "wei4 yu2", + "xi ang4", + "zh i3", + "e ng3", + "f ang1", + "l an2", + "sh u", + "r i4", + "li an2", + "sh ou3", + "m o", + "qi u2", + "ji n1", + "h uo4", + "shu3xia4de5yi2ge4 zhong3", + "f en1", + "n ei4", + "g ai1", + "mei3 guo2", + "u n2", + "g e2", + "b ao3", + "qi ng1", + "g ao1", + "t ai2", + "d u", + "xi ao3", + "ji e2", + "ti an1", + "ch ang2", + "q uan2", + "li e4", + "h ai3", + "f ei1", + "t i3", + "ju e2", + "o u2", + "c i3", + "z u2", + "n i2", + "bi ao3", + "zhong1 guo2", + "d u4", + "yu e4", + "xi ng4", + "sh eng4", + "ch e1", + "d an1", + "ji e1", + "li n2", + "pi ng2", + "f u3", + "g u3", + "ji e4", + "w o", + "v 3", + "sh eng3", + "n a4", + "yu an4", + "zh ang3", + "gu an3", + "d ao3", + "z u3", + "di ng4", + "di an3", + "c eng2", + "ren2 kou3", + "t ai4", + "t ong1", + "g uo4", + "n eng2", + "ch ang3", + "hu a2", + "li u2", + "yi ng1", + "xi ao4", + "c i4", + "bian4 hua4", + "li ang3", + "g ong4", + "zho ng4", + "de5 yi1", + "s e4", + "k ai1", + "w ang2", + "ji u4", + "sh i1", + "sh ou4", + "m ei2", + "k u", + "s u", + "f eng1", + "z e2", + "tu2 shi4", + "t i2", + "q i4", + "ji u3", + "sh en1", + "zh e3", + "ren2kou3 bian4hua4", + "ren2kou3bian4hua4 tu2shi4", + "di4 qu1", + "y ang2", + "m en", + "men 5", + "l ong2", + "bi ng4", + "ch an3", + "zh u1", + "w ei3", + "w ai4", + "xi ng1", + "bo 1", + "b i3", + "t ang2", + "hu a1", + "bo 2", + "shu i3", + "sh u1", + "d ou1", + "s ai4", + "ch ao2", + "b i4", + "li ng2", + "l ei4", + "da4 xue2", + "f en4", + "shu3 de5", + "m u3", + "ji ao1", + "d ang1", + "ch eng1", + "t ong3", + "n v3", + "q i3", + "y an3", + "mi an4", + "l uo4", + "ji ng4", + "g e1", + "r u4", + "d an4", + "ri4 ben3", + "p u3", + "yu n4", + "hu ang2", + "wo 3", + "l v", + "h ai2", + "shi4 yi1", + "xi e1", + "yi ng3", + "w u2", + "sh en2", + "w ang3", + "gu ang3", + "li u4", + "s u4", + "shi4 zhen4", + "c an1", + "c ao3", + "xi a2", + "k a3", + "d a2", + "h u4", + "b an4", + "d 
ang3", + "h u2", + "z ong3", + "de ng3", + "de5yi2ge4 shi4zhen4", + "ch uan2", + "mo 4", + "zh ang1", + "b an1", + "mo 2", + "ch a2", + "c e4", + "zhu3 yao4", + "t ou2", + "j u2", + "shi4 wei4yu2", + "s a4", + "u n1", + "ke3 yi3", + "d u1", + "h an4", + "li ang4", + "sh a1", + "ji a3", + "z i1", + "lv 4", + "f u1", + "xi an1", + "x u4", + "gu ang1", + "m eng2", + "b ao4", + "yo u4", + "r ong2", + "zhi1 yi1", + "w ei1", + "m ao2", + "guo2 jia1", + "c ong2", + "g ou4", + "ti e3", + "zh en1", + "d u2", + "bi an1", + "c i2", + "q u3", + "f an4", + "xi ang3", + "m en2", + "j u1", + "h ong2", + "z i3", + "ta1 men5", + "ji 3", + "z ong1", + "zhou1 de5yi2ge4shi4zhen4", + "t uan2", + "ji ng3", + "gong1 si1", + "xi e4", + "l i2", + "li4 shi3", + "b ao1", + "g ang3", + "gu i1", + "zh eng1", + "zhi2 wu4", + "ta1 de5", + "pi n3", + "zhu an1", + "ch ong2", + "shi3 yong4", + "w a3", + "sh uo1", + "chu an1", + "l ei2", + "w an1", + "h uo2", + "q u", + "s u1", + "z ao3", + "g ai3", + "q u4", + "g u4", + "l u", + "x i2", + "h ang2", + "yi ng4", + "c un1", + "g en1", + "yi ng2", + "ti ng2", + "cheng2 shi4", + "ji ang3", + "li ng3", + "l un2", + "bu4 fen4", + "de ng1", + "xu an3", + "dong4 wu4", + "de2 guo2", + "xi an3", + "f an3", + "zh e5", + "h an2", + "h ao4", + "m i4", + "r an2", + "qi n1", + "ti ao2", + "zh an3", + "h i", + "k a", + "n o", + "t e", + "s u", + "s hi", + "t a", + "t o", + "n a", + "w a", + "o u", + "r u", + "n i", + "k u", + "k i", + "g a", + "d e", + "k o", + "m a", + "r e", + "r a", + "m o", + "t su", + "w o", + "e n", + "r i", + "s a", + "d a", + "s e", + "j i", + "h a", + "c hi", + "k e", + "te ki", + "m i", + "y ou", + "s h", + "s o", + "y o", + "y a", + "na i", + "t te", + "a ru", + "b a", + "u u", + "t ta", + "ka i", + "ka n", + "shi te", + "m e", + "d o", + "mo no", + "se i", + "r o", + "ko to", + "ka ra", + "shi ta", + "b u", + "m u", + "c h", + "su ru", + "k ou", + "g o", + "ma su", + "ta i", + "f u", + "k en", + "i u", + "g en", + "wa re", + "shi n", + "z u", + "a i", + "o n", + "o ku", + "g i", + "d ou", + "n e", + "y uu", + "i ru", + "i te", + "ji ko", + "de su", + "j u", + "ra re", + "sh u", + "b e", + "sh ou", + "s ha", + "se kai", + "s ou", + "k you", + "ma shita", + "s en", + "na ra", + "sa n", + "ke i", + "i ta", + "a ri", + "i tsu", + "ko no", + "j ou", + "na ka", + "ch ou", + "so re", + "g u", + "na ru", + "ga ku", + "re ba", + "g e", + "h o", + "i n", + "hi to", + "sa i", + "na n", + "da i", + "tsu ku", + "shi ki", + "sa re", + "na ku", + "p p", + "bu n", + "ju n", + "so no", + "ka ku", + "z ai", + "b i", + "to u", + "wa ta", + "sh uu", + "i i", + "te i", + "ka re", + "y u", + "shi i", + "ma de", + "sh o", + "a n", + "ke reba", + "shi ka", + "i chi", + "ha n", + "de ki", + "ni n", + "ware ware", + "na kereba", + "o ite", + "h ou", + "ya ku", + "ra i", + "mu jun", + "l e", + "yo ku", + "bu tsu", + "o o", + "ko n", + "o mo", + "ga e", + "nara nai", + "ta chi", + "z en", + "ch uu", + "kan gae", + "ta ra", + "to ki", + "ko ro", + "mujun teki", + "z e", + "na ga", + "ji n", + "shi ma", + "te n", + "i ki", + "i ku", + "no u", + "i masu", + "r ou", + "h on", + "ka e", + "t to", + "ko re", + "ta n", + "ki ta", + "i s", + "da tta", + "ji tsu", + "ma e", + "i e", + "me i", + "da n", + "h e", + "to ku", + "dou itsu", + "ri tsu", + "k yuu", + "h you", + "rare ta", + "kei sei", + "k kan", + "rare ru", + "m ou", + "do ko", + "r you", + "da ke", + "naka tta", + "so ko", + "ta be", + "e r", + "ha na", + "c o", + "fu ku", + "p a", + "so n", + "ya su", + "ch o", + "wata ku", + "ya ma", 
+ "z a", + "k yo", + "gen zai", + "b oku", + "a ta", + "j a", + "ka wa", + "ma sen", + "j uu", + "ro n", + "b o", + "na tte", + "wataku shi", + "yo tte", + "ma i", + "g ou", + "ha i", + "mo n", + "ba n", + "ji shin", + "c a", + "re te", + "n en", + "o ka", + "ka gaku", + "na tta", + "p o", + "ka ru", + "na ri", + "m en", + "ma ta", + "e i", + "ku ru", + "ga i", + "ka ri", + "sha kai", + "kou i", + "yo ri", + "se tsu", + "j o", + "re ru", + "to koro", + "ju tsu", + "i on", + "sa ku", + "tta i", + "c ha", + "nin gen", + "n u", + "c e", + "ta me", + "kan kyou", + "de n", + "o oku", + "i ma", + "wata shi", + "tsuku ru", + "su gi", + "b en", + "ji bun", + "shi tsu", + "ke ru", + "ki n", + "ki shi", + "shika shi", + "mo to", + "ma ri", + "i tte", + "de shita", + "n de", + "ari masu", + "te r", + "z ou", + "ko e", + "ze ttai", + "kkan teki", + "h en", + "re kishi", + "deki ru", + "tsu ka", + "l a", + "i tta", + "o i", + "ko butsu", + "mi ru", + "sh oku", + "shi masu", + "gi jutsu", + "g you", + "jou shiki", + "a tta", + "ho do", + "ko ko", + "tsuku rareta", + "z oku", + "hi tei", + "ko ku", + "rekishi teki", + "ke te", + "o ri", + "i mi", + "ka ko", + "naga ra", + "ka karu", + "shu tai", + "ha ji", + "ma n", + "ta ku", + "ra n", + "douitsu teki", + "z o", + "me te", + "re i", + "tsu u", + "sare te", + "gen jitsu", + "p e", + "s t", + "ba i", + "na wa", + "ji kan", + "wa ru", + "r t", + "a tsu", + "so ku", + "koui teki", + "a ra", + "u ma", + "a no", + "i de", + "ka ta", + "te tsu", + "ga wa", + "ke do", + "re ta", + "mi n", + "sa you", + "tte ru", + "to ri", + "p u", + "ki mi", + "b ou", + "mu ra", + "sare ru", + "ma chi", + "k ya", + "o sa", + "kon na", + "a ku", + "a l", + "sare ta", + "i pp", + "shi ku", + "u chi", + "hito tsu", + "ha tara", + "tachi ba", + "shi ro", + "ka tachi", + "to mo", + "e te", + "me ru", + "ni chi", + "da re", + "ka tta", + "e ru", + "su ki", + "a ge", + "oo ki", + "ma ru", + "mo ku", + "o ko", + "kangae rareru", + "o to", + "tan ni", + "ta da", + "tai teki", + "mo tte", + "ki nou", + "shi nai", + "k ki", + "u e", + "ta ri", + "l i", + "ra nai", + "k kou", + "mi rai", + "pp on", + "go to", + "hi n", + "hi tsu", + "te ru", + "mo chi", + "ka tsu", + "re n", + "n yuu", + "su i", + "zu ka", + "tsu ite", + "no mi", + "su gu", + "ku da", + "tetsu gaku", + "i ka", + "ron ri", + "o ki", + "ni ppon", + "p er", + "shi mashita", + "chi shiki", + "cho kkanteki", + "su ko", + "t ion", + "ku u", + "a na", + "a rou", + "ka tte", + "ku ri", + "i nai", + "hyou gen", + "i shiki", + "do ku", + "a tte", + "a tara", + "to n", + "wa ri", + "ka o", + "sei san", + "hana shi", + "s i", + "ka ke", + "na ji", + "su nawa", + "sunawa chi", + "u go", + "su u", + "ba ra", + "le v", + "hi ro", + "i wa", + "be tsu", + "yo i", + "se ru", + "shite ru", + "rare te", + "to shi", + "se ki", + "tai ritsu", + "wa kara", + "to kyo", + "k ka", + "k yoku", + "u n", + "i ro", + "mi te", + "sa ki", + "kan ji", + "mi ta", + "su be", + "r yoku", + "ma tta", + "kuda sai", + "omo i", + "ta no", + "ware ru", + "co m", + "hitsu you", + "ka shi", + "re nai", + "kan kei", + "a to", + "ga tte", + "o chi", + "mo tsu", + "in g", + "son zai", + "l l", + "o re", + "tai shite", + "a me", + "sei mei", + "ka no", + "gi ri", + "kangae ru", + "yu e", + "a sa", + "o naji", + "yo ru", + "ni ku", + "osa ka", + "suko shi", + "c k", + "ta ma", + "kano jo", + "ki te", + "mon dai", + "a mari", + "e ki", + "ko jin", + "ha ya", + "i t", + "de te", + "atara shii", + "a wa", + "ga kkou", + "tsu zu", + "shu kan", + "i mashita", + "mi na", + 
"ata e", + "da rou", + "hatara ku", + "ga ta", + "da chi", + "ma tsu", + "ari masen", + "sei butsu", + "mi tsu", + "he ya", + "yasu i", + "d i", + "de ni", + "no ko", + "ha ha", + "do mo", + "ka mi", + "su deni", + "na o", + "ra ku", + "i ke", + "a ki", + "me ta", + "l o", + "ko domo", + "so shite", + "ga me", + "ba kari", + "to te", + "ha tsu", + "mi se", + "moku teki", + "da kara", + "s z", + "e l", + "g y", + "e n", + "t t", + "e m", + "a n", + "a k", + "e r", + "a z", + "a l", + "e t", + "o l", + "e g", + "e k", + "m i", + "o n", + "é s", + "c s", + "a t", + "á r", + "h o", + "e z", + "á l", + "i s", + "á n", + "o r", + "a r", + "e gy", + "e s", + "é r", + "á t", + "o tt", + "e tt", + "m eg", + "t a", + "o k", + "o s", + "ho gy", + "n em", + "é g", + "n y", + "k i", + "é l", + "h a", + "á s", + "ü l", + "i n", + "mi n", + "n a", + "e d", + "o m", + "i k", + "k ö", + "m a", + "n i", + "v a", + "v ol", + "é t", + "b b", + "f el", + "i g", + "l e", + "r a", + "é n", + "t e", + "d e", + "a d", + "ó l", + "b e", + "on d", + "j a", + "r e", + "u l", + "b en", + "n ek", + "u t", + "vol t", + "b an", + "ö r", + "o g", + "a p", + "o d", + "á g", + "n k", + "é k", + "v al", + "k or", + "a m", + "i l", + "í t", + "á k", + "b a", + "u d", + "sz er", + "min d", + "o z", + "é p", + "el l", + "ér t", + "m ond", + "i t", + "sz t", + "n ak", + "a mi", + "n e", + "ő l", + "cs ak", + "n é", + "ma g", + "ol y", + "m er", + "ál l", + "án y", + "ö n", + "ö l", + "min t", + "m ár", + "ö tt", + "na gy", + "é sz", + "az t", + "el ő", + "t ud", + "o t", + "é ny", + "á z", + "m ég", + "kö z", + "el y", + "s ég", + "en t", + "s em", + "ta m", + "h et", + "h al", + "f i", + "a s", + "v an", + "ho z", + "v e", + "u k", + "k ez", + "á m", + "v el", + "b er", + "a j", + "u nk", + "i z", + "va gy", + "m os", + "sz em", + "em ber", + "f og", + "mer t", + "ü k", + "l en", + "ö s", + "e j", + "t al", + "h at", + "t ak", + "h i", + "m ás", + "s ág", + "ett e", + "l eg", + "ü nk", + "h át", + "sz a", + "on y", + "ez t", + "mind en", + "en d", + "ül t", + "h an", + "j ó", + "k is", + "á j", + "in t", + "ú gy", + "i d", + "mos t", + "ar t", + "í r", + "k er", + "i tt", + "a tt", + "el t", + "mond ta", + "k ell", + "l á", + "ak i", + "ál t", + "ér d", + "t ö", + "l an", + "v ár", + "h ol", + "t el", + "l át", + "ő k", + "v et", + "s e", + "ut án", + "k ét", + "na p", + "í v", + "ál y", + "v ég", + "ö k", + "i r", + "d ul", + "v is", + "né z", + "t er", + "á ban", + "k ül", + "ak kor", + "k ap", + "sz él", + "y en", + "ú j", + "i m", + "oly an", + "es en", + "k ed", + "h ely", + "t ör", + "b ól", + "el m", + "r á", + "ár a", + "r ó", + "l ó", + "vol na", + "t an", + "le het", + "e bb", + "t en", + "t ek", + "s ok", + "k al", + "f or", + "u g", + "ol t", + "k a", + "ek et", + "b or", + "f ej", + "g ond", + "a g", + "ak ar", + "f él", + "ú l", + "b el", + "ott a", + "mi t", + "val ami", + "j el", + "é d", + "ar c", + "u r", + "hal l", + "t i", + "f öl", + "á ba", + "ol g", + "ki r", + "ol d", + "m ar", + "k érd", + "j ár", + "ú r", + "sz e", + "z s", + "él et", + "j át", + "o v", + "u s", + "é z", + "v il", + "v er", + "ő r", + "á d", + "ö g", + "le sz", + "on t", + "b iz", + "k oz", + "á bb", + "kir ály", + "es t", + "a b", + "en g", + "ig az", + "b ar", + "ha j", + "d i", + "o b", + "k od", + "r ól", + "v ez", + "tö bb", + "sz ó", + "é ben", + "ö t", + "ny i", + "t á", + "sz ól", + "gond ol", + "eg ész", + "í gy", + "ő s", + "o bb", + "os an", + "b ől", + "a bb", + "c i", + "ő t", + "n ál", + "k ép", + "azt án", + "v i", + "t 
art", + "be szél", + "m en", + "elő tt", + "a szt", + "ma j", + "kö r", + "han g", + "í z", + "in cs", + "a i", + "é v", + "ó d", + "ó k", + "hoz z", + "t em", + "ok at", + "an y", + "nagy on", + "h áz", + "p er", + "p ed", + "ez te", + "et len", + "nek i", + "maj d", + "sz ony", + "án ak", + "fel é", + "egy szer", + "j e", + "ad t", + "gy er", + "ami kor", + "f oly", + "sz ak", + "ő d", + "h ú", + "á sz", + "am ely", + "h ar", + "ér e", + "il yen", + "od a", + "j ák", + "t ár", + "á val", + "l ak", + "t ó", + "m ent", + "gy an", + "él y", + "ú t", + "v ar", + "kez d", + "m ell", + "mi kor", + "h ez", + "val ó", + "k o", + "m es", + "szer et", + "r end", + "l et", + "vis sza", + "ig en", + "f ő", + "va s", + "as szony", + "r ől", + "ped ig", + "p i", + "sz ép", + "t ák", + "ö v", + "an i", + "vil ág", + "p en", + "mag a", + "t et", + "sz ik", + "é j", + "én t", + "j ött", + "s an", + "sz í", + "i de", + "g at", + "ett em", + "ul t", + "h ány", + "ás t", + "a hol", + "ők et", + "h ár", + "k el", + "n ő", + "cs i", + "tal ál", + "el te", + "lá tt", + "tör t", + "ha gy", + "e sz", + "s en", + "n él", + "p ar", + "v ál", + "k ut", + "l ány", + "ami t", + "s ő", + "ell en", + "mag át", + "in k", + "u gyan", + "kül ön", + "a sz", + "mind ig", + "l ép", + "tal án", + "u n", + "sz or", + "k e", + "il lan", + "n incs", + "z et", + "vagy ok", + "tel en", + "is mer", + "s or", + "is ten", + "ít ott", + "j obb", + "v es", + "dul t", + "j uk", + "sz en", + "r o", + "ö m", + "l ett", + "k ar", + "egy ik", + "b ár", + "sz i", + "sz ív", + "az on", + "e szt", + "föl d", + "kut y", + "p illan", + "f ér", + "k om", + "t ől", + "t ű", + "é be", + "t ött", + "bar át", + "í g", + "a hogy", + "e h", + "e p", + "s o", + "v en", + "jel ent", + "t at", + "sz eg", + "mint ha", + "f al", + "egy en", + "mi l", + "sza b", + "r i", + "é m", + "biz ony", + "j on", + "ör eg", + "d olg", + "cs ap", + "ti szt", + "áll t", + "an cs", + "id ő", + "k at", + "ü gy", + "mi ért", + "ó t", + "ü r", + "cs in", + "h az", + "b et", + "én ek", + "v ér", + "j ól", + "al att", + "m ely", + "l o", + "sem mi", + "ny ug", + "v ág", + "kö vet", + "ös sze", + "ma d", + "l i", + "a cs", + "fi ú", + "kö n", + "más ik", + "j ön", + "sz ám", + "g er", + "s ó", + "r ész", + "k ér", + "z el", + "é vel", + "e o", + "e u", + "a n", + "eu l", + "eu n", + "eo n", + "a e", + "d a", + "a l", + "s s", + "i n", + "i l", + "a g", + "an g", + "y eon", + "y eo", + "d o", + "c h", + "n g", + "j i", + "h an", + "g a", + "g o", + "u i", + "h ae", + "a m", + "u l", + "u n", + "g eo", + "s i", + "n eun", + "ss da", + "s eo", + "eon g", + "y o", + "i da", + "t t", + "k k", + "j eo", + "d eul", + "w a", + "eu m", + "g e", + "o n", + "o g", + "s al", + "m an", + "yeon g", + "geo s", + "h ag", + "an eun", + "j a", + "g i", + "s u", + "i ss", + "o l", + "d ae", + "eo b", + "h a", + "j u", + "eo l", + "g eu", + "j eong", + "s ae", + "do e", + "g eul", + "s eu", + "s in", + "eul o", + "b n", + "s ang", + "bn ida", + "h al", + "b o", + "han eun", + "m al", + "i m", + "m o", + "b u", + "jeo g", + "sae ng", + "in eun", + "an h", + "m a", + "sal am", + "j o", + "s a", + "eo m", + "n ae", + "w i", + "l o", + "g wa", + "yeo l", + "n a", + "e seo", + "y e", + "m yeon", + "tt ae", + "h w", + "j e", + "eob s", + "j ang", + "g u", + "g w", + "il eul", + "yeo g", + "j eon", + "si g", + "j ag", + "j in", + "y u", + "o e", + "s e", + "hag o", + "d eun", + "y a", + "m un", + "s eong", + "g ag", + "h am", + "d ang", + "b a", + "l eul", + "s il", + "do ng", + "kk a", + "b al", + "da 
l", + "han da", + "eo ssda", + "ae g", + "l i", + "ha ji", + "s eon", + "o ng", + "hae ssda", + "d e", + "i ssda", + "e ge", + "b un", + "m ul", + "ju ng", + "ji g", + "m u", + "iss neun", + "b i", + "g eun", + "seu bnida", + "w on", + "p p", + "d aneun", + "eo h", + "d eo", + "ga m", + "j al", + "hae ng", + "ag o", + "y ang", + "b ul", + "b ang", + "u m", + "s o", + "h i", + "j ae", + "si m", + "saeng gag", + "hag e", + "s og", + "eo ss", + "d an", + "ja sin", + "j il", + "eo g", + "g yeong", + "doe n", + "go ng", + "m i", + "ch i", + "d eu", + "d eon", + "hae ss", + "d u", + "n am", + "eun g", + "jo h", + "n al", + "m yeong", + "w o", + "eon a", + "i go", + "g yeol", + "y ag", + "gw an", + "ul i", + "yo ng", + "n o", + "l yeo", + "j og", + "eoh ge", + "ga t", + "b og", + "mo s", + "t ong", + "ch a", + "man h", + "jeo l", + "geo l", + "h oe", + "ag a", + "n aneun", + "g an", + "un eun", + "ch eol", + "ch e", + "do l", + "b on", + "b an", + "ba d", + "ch u", + "ham yeon", + "yeo ssda", + "i bnida", + "g ye", + "eo s", + "hw al", + "salam deul", + "ji man", + "dang sin", + "ji b", + "ttae mun", + "m ae", + "i b", + "e neun", + "eu g", + "jeo m", + "geul eon", + "h wa", + "a ssda", + "b eob", + "bu t", + "b ae", + "yeo ss", + "ch in", + "ch aeg", + "g eon", + "g ae", + "nae ga", + "i ga", + "m og", + "sig an", + "g il", + "h yeon", + "l yeog", + "gu g", + "p yeon", + "s an", + "w ae", + "j ul", + "s eul", + "deun g", + "haji man", + "eum yeon", + "p il", + "m ol", + "n eu", + "a ss", + "n yeon", + "t ae", + "h u", + "p yo", + "s ul", + "g ang", + "j ineun", + "b eon", + "ha da", + "seo l", + "si p", + "dal eun", + "a p", + "sal m", + "g yo", + "ch eon", + "hag i", + "in a", + "cheol eom", + "g al", + "il a", + "kka ji", + "anh neun", + "ha bnida", + "tt eon", + "n u", + "hae seo", + "doen da", + "s ol", + "tt al", + "l a", + "il o", + "seu b", + "b yeon", + "m yeo", + "b eol", + "s on", + "n un", + "j un", + "j am", + "j eung", + "tt o", + "e n", + "mo m", + "h o", + "ch im", + "hw ang", + "eun eun", + "jo ng", + "bo da", + "n ol", + "n eom", + "but eo", + "jig eum", + "eobs da", + "dae lo", + "i g", + "y ul", + "p yeong", + "seon eun", + "sal ang", + "seu t", + "h im", + "n an", + "h eom", + "h yang", + "p i", + "gw ang", + "eobs neun", + "hw ag", + "ge ss", + "jag i", + "il eon", + "wi hae", + "dae han", + "ga ji", + "m eog", + "j yeo", + "cha j", + "b yeong", + "eo d", + "g yeo", + "do n", + "eo ji", + "g ul", + "mo deun", + "j on", + "in saeng", + "geul ae", + "h ang", + "sa sil", + "si b", + "ch al", + "il ago", + "doe l", + "g eum", + "doe neun", + "b ol", + "ga jang", + "geul igo", + "e l", + "h yeong", + "haeng bog", + "ch ul", + "h on", + "ch ae", + "s am", + "m ang", + "in da", + "da m", + "w ol", + "ch oe", + "d ul", + "si jag", + "ch eong", + "il aneun", + "ul ineun", + "ae n", + "kk e", + "mun je", + "a do", + "t eu", + "g un", + "geun eun", + "b ge", + "ch eo", + "b aeg", + "ju g", + "t a", + "sang dae", + "geu geos", + "do g", + "eu s", + "deu s", + "ja b", + "h yeo", + "tt eohge", + "u g", + "ma j", + "ch il", + "s wi", + "j ileul", + "ch ang", + "g aneun", + "m ag", + "i ji", + "da go", + "m in", + "yo han", + "t eug", + "pp un", + "al eul", + "haeng dong", + "p o", + "m il", + "ch am", + "se sang", + "e do", + "p an", + "man deul", + "am yeon", + "a b", + "kk ae", + "b ag", + "i deul", + "p um", + "m eol", + "s un", + "n eul", + "ham kke", + "chu ng", + "da b", + "yu g", + "s ag", + "gwang ye", + "il eohge", + "bal o", + "neun de", + "ham yeo", + "go s", + "geul eoh", + "an 
ila", + "bang beob", + "da si", + "b yeol", + "g yeon", + "gam jeong", + "on eul", + "j aneun", + "yeo m", + "l ago", + "i gi", + "hw an", + "t eul", + "eo seo", + "si k", + "ch o", + "jag a", + "geul eom", + "geul eona", + "jeong do", + "g yeog", + "geul eohge", + "geu deul", + "eu t", + "im yeon", + "j jae", + "k eun", + "i sang", + "mal haessda", + "eu ge", + "no p", + "in gan", + "bo myeon", + "t aeg", + "seu s", + "d wi", + "s aneun", + "w an", + "anh go", + "t an", + "nu gu", + "su ng", + "da myeon", + "a deul", + "p eul", + "ttal a", + "d i", + "geos do", + "a ji", + "m eon", + "eum yeo", + "dol og", + "neun g", + "mo du", + "क े", + "ह ै", + "े ं", + "् र", + "ा र", + "न े", + "य ा", + "म ें", + "स े", + "क ी", + "क ा", + "ो ं", + "त ा", + "क र", + "स ्", + "क ि", + "क ो", + "र ्", + "न ा", + "क ्", + "ह ी", + "औ र", + "प र", + "त े", + "ह ो", + "प ्र", + "ा न", + "् य", + "ल ा", + "व ा", + "ल े", + "स ा", + "है ं", + "ल ि", + "ज ा", + "ह ा", + "भ ी", + "व ि", + "इ स", + "त ी", + "न ्", + "र ा", + "म ा", + "द े", + "द ि", + "ब ा", + "त ि", + "थ ा", + "न ि", + "क ार", + "ए क", + "ही ं", + "ह ु", + "ं ग", + "ै ं", + "न ी", + "स ी", + "अ प", + "त ्", + "न हीं", + "र ी", + "म े", + "म ु", + "ि त", + "त ो", + "प ा", + "ल ी", + "लि ए", + "ग ा", + "ल ्", + "र ह", + "र े", + "क् ष", + "म ैं", + "स म", + "उ स", + "ज ि", + "त ्र", + "म ि", + "च ा", + "ो ग", + "स ं", + "द ्", + "स ि", + "आ प", + "त ु", + "द ा", + "क ु", + "य ों", + "व े", + "ज ी", + "् या", + "उ न", + "ि क", + "य े", + "भ ा", + "् ट", + "ह म", + "स् ट", + "श ा", + "ड ़", + "ं द", + "ख ा", + "म ्", + "श ्", + "य ह", + "स क", + "प ू", + "कि या", + "अप ने", + "र ू", + "स ु", + "म ी", + "ह ि", + "ज ो", + "थ े", + "र ि", + "द ी", + "थ ी", + "ग ी", + "ल ोग", + "ग या", + "त र", + "न् ह", + "च ्", + "व ार", + "ब ी", + "प ्", + "द ो", + "ट ी", + "श ि", + "कर ने", + "ग े", + "ै से", + "इ न", + "ं ड", + "सा थ", + "प ु", + "ब े", + "ब ार", + "व ी", + "अ न", + "ह र", + "उ न्ह", + "हो ता", + "ज ब", + "कु छ", + "म ान", + "क ्र", + "ब ि", + "प ह", + "फ ि", + "स र", + "ार ी", + "र ो", + "द ू", + "क हा", + "त क", + "श न", + "ब ्", + "स् थ", + "व ह", + "बा द", + "ओ ं", + "ग ु", + "ज ्", + "्र े", + "ग र", + "रह े", + "व र्", + "ह ू", + "ार ्", + "प ी", + "ब हु", + "मु झ", + "्र ा", + "दि या", + "स ब", + "कर ते", + "अप नी", + "बहु त", + "क ह", + "ट े", + "हु ए", + "कि सी", + "र हा", + "ष ्ट", + "ज ़", + "ब ना", + "स ो", + "ड ि", + "को ई", + "व ्य", + "बा त", + "र ु", + "व ो", + "मुझ े", + "द् ध", + "च ार", + "मे रे", + "व र", + "्र ी", + "जा ता", + "न ों", + "प्र ा", + "दे ख", + "ट ा", + "क् या", + "अ ध", + "ल ग", + "ल ो", + "प ि", + "य ु", + "च े", + "जि स", + "ं त", + "ान ी", + "प ै", + "ज न", + "ार े", + "च ी", + "मि ल", + "द ु", + "दे श", + "च् छ", + "ष ्", + "स ू", + "ख े", + "च ु", + "ि या", + "ल गा", + "ब ु", + "उन के", + "ज् ञ", + "क्ष ा", + "त रह", + "्या दा", + "वा ले", + "पू र्", + "मैं ने", + "का म", + "रू प", + "हो ती", + "उ प", + "ज ान", + "प्र कार", + "भ ार", + "म न", + "हु आ", + "ट र", + "हू ँ", + "पर ि", + "पा स", + "अन ु", + "रा ज", + "लोग ों", + "अ ब", + "सम झ", + "ड ी", + "म ौ", + "श ु", + "च ि", + "प े", + "क ृ", + "सक ते", + "म ह", + "य ोग", + "द र्", + "उ से", + "ं ध", + "ड ा", + "जा ए", + "ब ो", + "ू ल", + "म ो", + "ों ने", + "ं स", + "तु म", + "पह ले", + "ब ता", + "त था", + "य ो", + "ग ई", + "उ त्", + "सक ता", + "क म", + "ज ्यादा", + "र ख", + "सम य", + "ार ा", + "अ गर", + "स् त", + "च ल", + "फि र", + "वार ा", + "कर ना", + "श ी", + "ग ए", + "ब न", + "ौ र", + "हो ने", + "चा ह", + "ख ु", + "हा ँ", + "उन्ह ें", + "उन्ह 
ोंने", + "छ ो", + "म् ह", + "प्र ति", + "नि क", + "व न", + "्य ू", + "र ही", + "तु म्ह", + "ज ैसे", + "ि यों", + "क् यों", + "ल ों", + "फ ़", + "ं त्र", + "हो ते", + "क् ति", + "त ्य", + "कर ्", + "क ई", + "व ं", + "कि न", + "प ो", + "कार ण", + "ड़ ी", + "भ ि", + "इस के", + "ब र", + "उस के", + "द् वारा", + "श े", + "क ॉ", + "दि न", + "न् न", + "ड़ ा", + "स् व", + "नि र्", + "मु ख", + "लि या", + "ट ि", + "ज्ञ ान", + "क् त", + "द ्र", + "ग ्", + "क् स", + "म ै", + "ग ो", + "ज े", + "ट ्र", + "म ार", + "त् व", + "ध ार", + "भा व", + "कर ता", + "ख ि", + "क ं", + "चा हि", + "य र", + "प् त", + "क ों", + "ं च", + "ज ु", + "म त", + "अ च्छ", + "हु ई", + "क भी", + "ले किन", + "भ ू", + "अप ना", + "दू स", + "चाहि ए", + "य ू", + "घ र", + "सब से", + "मे री", + "ना म", + "ढ ़", + "ं ट", + "ें गे", + "ब ै", + "फ ा", + "ए वं", + "य ी", + "ग ्र", + "क्ष े", + "आ ज", + "आप को", + "भा ग", + "ठ ा", + "क ै", + "भार त", + "उन की", + "प हु", + "स भी", + "ध ा", + "ण ा", + "स ान", + "हो गा", + "त ब", + "स ंग", + "प र्", + "अ व", + "त ना", + "ग ि", + "य न", + "स् था", + "च ित", + "ट ्", + "छ ा", + "जा ने", + "क्षे त्र", + "वा ली", + "पूर् ण", + "स मा", + "कार ी" + ] + } +} \ No newline at end of file diff --git a/diffsynth_engine/models/ace_step/lyric_tokenizer/zh_num2words.py b/diffsynth_engine/models/ace_step/lyric_tokenizer/zh_num2words.py new file mode 100644 index 00000000..d47ed694 --- /dev/null +++ b/diffsynth_engine/models/ace_step/lyric_tokenizer/zh_num2words.py @@ -0,0 +1,1113 @@ +# Authors: +# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) +# 2019.9 - 2022 Jiayu DU +# copy from https://github.com/coqui-ai/TTS/blob/dbf1a08a0d4e47fdad6172e433eeb34bc6b13b4e/TTS/tts/layers/xtts/zh_num2words.py +import re +import string +import sys + +# fmt: off + +# ================================================================================ # +# basic constant +# ================================================================================ # +CHINESE_DIGIS = "零一二三四五六七八九" +BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" +BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" + +ZERO_ALT = "〇" +ONE_ALT = "幺" +TWO_ALTS = ["两", "兩"] + +POSITIVE = ["正", "正"] +NEGATIVE = ["负", "負"] +POINT = ["点", "點"] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +FILLER_CHARS = ["呃", "啊"] + +ER_WHITELIST = ( + "(儿女|儿子|儿孙|女儿|儿媳|妻儿|" + "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|" + "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|" + "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)" +) +ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST) + +# 中文数字系统类型 +NUMBERING_TYPES = ["low", "mid", "high"] + +CURRENCY_NAMES = "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" +CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" +COM_QUANTIFIERS = ( + "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" + "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" + "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" + "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" + "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" + "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)" +) + + +# Punctuation information are based on Zhon 
project (https://github.com/tsroten/zhon.git) +CN_PUNCS_STOP = "!?。。" +CN_PUNCS_NONSTOP = ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-" +CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP + +PUNCS = CN_PUNCS + string.punctuation +PUNCS_TRANSFORM = str.maketrans(PUNCS, "," * len(PUNCS), "") # replace puncs with English comma + + +# https://zh.wikipedia.org/wiki/全行和半行 +QJ2BJ = { + " ": " ", + "!": "!", + """: '"', + "#": "#", + "$": "$", + "%": "%", + "&": "&", + "'": "'", + "(": "(", + ")": ")", + "*": "*", + "+": "+", + ",": ",", + "-": "-", + ".": ".", + "/": "/", + "0": "0", + "1": "1", + "2": "2", + "3": "3", + "4": "4", + "5": "5", + "6": "6", + "7": "7", + "8": "8", + "9": "9", + ":": ":", + ";": ";", + "<": "<", + "=": "=", + ">": ">", + "?": "?", + "@": "@", + "A": "A", + "B": "B", + "C": "C", + "D": "D", + "E": "E", + "F": "F", + "G": "G", + "H": "H", + "I": "I", + "J": "J", + "K": "K", + "L": "L", + "M": "M", + "N": "N", + "O": "O", + "P": "P", + "Q": "Q", + "R": "R", + "S": "S", + "T": "T", + "U": "U", + "V": "V", + "W": "W", + "X": "X", + "Y": "Y", + "Z": "Z", + "[": "[", + "\": "\\", + "]": "]", + "^": "^", + "_": "_", + "`": "`", + "a": "a", + "b": "b", + "c": "c", + "d": "d", + "e": "e", + "f": "f", + "g": "g", + "h": "h", + "i": "i", + "j": "j", + "k": "k", + "l": "l", + "m": "m", + "n": "n", + "o": "o", + "p": "p", + "q": "q", + "r": "r", + "s": "s", + "t": "t", + "u": "u", + "v": "v", + "w": "w", + "x": "x", + "y": "y", + "z": "z", + "{": "{", + "|": "|", + "}": "}", + "~": "~", +} +QJ2BJ_TRANSFORM = str.maketrans("".join(QJ2BJ.keys()), "".join(QJ2BJ.values()), "") + + +# 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources: +# https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total +CN_CHARS_COMMON = ( + "一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举" + "乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互" + "亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从" + "仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优" + "伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚" + "佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣" + "侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯" + "俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌" + "偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚" + "僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六" + "兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况" + "冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈" + "刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐" + "剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼" + "劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹" + "区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵" + "卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔" + "叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊" + "同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆" + "呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎" + "咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌" + "响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛" + "唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴" + "啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌" + "嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡" + "嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓" + "嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢" + "圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥" + "坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩" + "垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基" + "埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填" + "塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复" + "夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖" + "套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮" + "妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈" + "娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺婻" + "婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱" + 
"嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽" + "宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾" + "宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝" + "尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山" + "屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃" + "峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧" + "崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉" + "巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡" + "带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐" + "庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷" + "建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖" + "彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循" + "徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀" + "态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓" + "恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟" + "悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭" + "惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥" + "慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我" + "戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔" + "托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡" + "抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥" + "拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫" + "振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎" + "掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭" + "揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘" + "摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢" + "擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整" + "敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗" + "旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝" + "星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡" + "晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜" + "曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽" + "杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅" + "枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔" + "柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩" + "株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯" + "桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘" + "棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂" + "楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰榱" + "榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞" + "橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正" + "此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒" + "毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮" + "氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽" + "汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽" + "沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼" + "泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿" + "流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸" + "浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅" + "淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟" + "渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆" + "溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕" + "滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶" + "漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧" + "澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵" + "灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈" + "烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯" + "焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵" + "熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍" + "牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷" + "犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎" + "猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎" + "玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊" + "珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊" + "琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙" + "瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱" + "璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯" + "田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐" + "疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒" + "痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩" + "瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙" + "皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾" + "省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦" + "睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知" + "矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮" + "砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍" + "碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷" + "磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭" + "祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣" + 
"秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗" + "穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立" + "竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯" + "笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅" + "箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾" + "簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮" + "粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢" + "縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁" + "绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩" + "绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕" + "编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐" + "网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲" + "羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者" + "耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋" + "职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸" + "肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳" + "胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒" + "腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻" + "臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般" + "舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊" + "芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈" + "苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆" + "茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐" + "荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛" + "莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥" + "菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著" + "葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺" + "蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷" + "蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸" + "薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱" + "虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆" + "蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗" + "蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃" + "螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡" + "蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒" + "袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂" + "褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉" + "觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认" + "讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词" + "诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵请" + "诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡" + "谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹" + "豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼" + "贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤" + "赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑" + "跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮" + "踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇" + "躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较" + "辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄" + "迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆" + "选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒" + "道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱" + "邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴" + "郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝" + "酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭" + "醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒" + "钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼" + "钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨" + "铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐" + "锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹" + "锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣" + "镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼" + "闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶" + "阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈" + "隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳" + "零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰" + "靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵" + "韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额" + "颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰" + "饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥" + "馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑" + "骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高" + "髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾" + "鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨" + "鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓" + "鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶" + "鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟" + "鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄" + "黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷" + 
"鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦㛃" + "㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡" + "䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽" + "𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯" + "𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟" + "𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟" + "𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟" + "𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓" +) +CN_CHARS_EXT = "吶诶屌囧飚屄" + +CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT +IN_CH_CHARS = {c: True for c in CN_CHARS} + +EN_CHARS = string.ascii_letters + string.digits +IN_EN_CHARS = {c: True for c in EN_CHARS} + +VALID_CHARS = CN_CHARS + EN_CHARS + " " +IN_VALID_CHARS = {c: True for c in VALID_CHARS} + + +# ================================================================================ # +# basic class +# ================================================================================ # +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + # self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. '陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return "10^{}".format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + if small_unit: + return ChineseNumberUnit( + power=index + 1, simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1] + ) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit( + power=index + 8, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1] + ) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit( + power=(index + 2) * 4, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1] + ) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit( + power=pow(2, index + 3), simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1] + ) + else: + raise ValueError("Counting type should be in {0} ({1} provided).".format(NUMBERING_TYPES, numbering_type)) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. 
+    positive = ['正', '正']
+    negative = ['负', '負']
+    point = ['点', '點']
+    """
+
+    def __init__(self, positive, negative, point):
+        self.positive = positive
+        self.negative = negative
+        self.point = point
+
+    def __iter__(self):
+        for v in self.__dict__.values():
+            yield v
+
+
+# class OtherSymbol(object):
+#     """
+#     Other symbols
+#     """
+#
+#     def __init__(self, sil):
+#         self.sil = sil
+#
+#     def __iter__(self):
+#         for v in self.__dict__.values():
+#             yield v
+
+
+# ================================================================================ #
+#                                   basic utils
+# ================================================================================ #
+def create_system(numbering_type=NUMBERING_TYPES[1]):
+    """
+    Create the Chinese number system for the given numbering type (default: mid).
+    NUMBERING_TYPES = ['low', 'mid', 'high'] are the supported system types:
+        low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
+        mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
+        high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
+    Returns the corresponding number system.
+    """
+
+    # chinese number units of '亿' and larger
+    all_larger_units = zip(LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
+    larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)]
+    # chinese number units of '十, 百, 千, 万'
+    all_smaller_units = zip(SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
+    smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)]
+    # digits
+    chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
+    digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
+    digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
+    digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
+    digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
+
+    # symbols
+    positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
+    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
+    point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
+    # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
+    system = NumberSystem()
+    system.units = smaller_units + larger_units
+    system.digits = digits
+    system.math = MathSymbol(positive_cn, negative_cn, point_cn)
+    # system.symbols = OtherSymbol(sil_cn)
+    return system
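+
+
+# Illustrative note (not from the upstream file; powers follow CNU.create above):
+# the unit powers each numbering type yields are
+#     low:  亿 = 10^8, 兆 = 10^9,  京 = 10^10, ...  (each larger unit = previous * 10)
+#     mid:  亿 = 10^8, 兆 = 10^12, 京 = 10^16, ...  (each larger unit = previous * 10^4)
+#     high: 亿 = 10^8, 兆 = 10^16, 京 = 10^32, ...  (each larger unit = previous squared)
+# so create_system("mid").units holds powers 1-4 (十百千万) followed by 8, 12, 16, ...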
+
+
+def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
+    def get_symbol(char, system):
+        for u in system.units:
+            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
+                return u
+        for d in system.digits:
+            if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
+                return d
+        for m in system.math:
+            if char in [m.traditional, m.simplified]:
+                return m
+
+    def string2symbols(chinese_string, system):
+        int_string, dec_string = chinese_string, ""
+        for p in [system.math.point.simplified, system.math.point.traditional]:
+            if p in chinese_string:
+                int_string, dec_string = chinese_string.split(p)
+                break
+        return [get_symbol(c, system) for c in int_string], [get_symbol(c, system) for c in dec_string]
+
+    def correct_symbols(integer_symbols, system):
+        """
+        Normalize the symbol sequence, e.g.
+        一百八 to 一百八十 (pad the implied trailing unit),
+        一亿一千三百万 to 一亿 一千万 三百万 (merge consecutive units).
+        """
+
+        if integer_symbols and isinstance(integer_symbols[0], CNU):
+            if integer_symbols[0].power == 1:
+                integer_symbols = [system.digits[1]] + integer_symbols
+
+        if len(integer_symbols) > 1:
+            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
+                integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))
+
+        result = []
+        unit_count = 0
+        for s in integer_symbols:
+            if isinstance(s, CND):
+                result.append(s)
+                unit_count = 0
+            elif isinstance(s, CNU):
+                current_unit = CNU(s.power, None, None, None, None)
+                unit_count += 1
+
+                if unit_count == 1:
+                    result.append(current_unit)
+                elif unit_count > 1:
+                    for i in range(len(result)):
+                        if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
+                            result[-i - 1] = CNU(result[-i - 1].power + current_unit.power, None, None, None, None)
+        return result
+
+    def compute_value(integer_symbols):
+        """
+        Compute the value.
+        When the current unit is larger than the previous unit, all previous units
+        are multiplied by the current one,
+        e.g. '两千万' = 2000 * 10000, not 2000 + 10000.
+        """
+        value = [0]
+        last_power = 0
+        for s in integer_symbols:
+            if isinstance(s, CND):
+                value[-1] = s.value
+            elif isinstance(s, CNU):
+                value[-1] *= pow(10, s.power)
+                if s.power > last_power:
+                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
+                    last_power = s.power
+                value.append(0)
+        return sum(value)
+
+    system = create_system(numbering_type)
+    int_part, dec_part = string2symbols(chinese_string, system)
+    int_part = correct_symbols(int_part, system)
+    int_str = str(compute_value(int_part))
+    dec_str = "".join([str(d.value) for d in dec_part])
+    if dec_part:
+        return "{0}.{1}".format(int_str, dec_str)
+    else:
+        return int_str
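+
+
+# Illustrative expected behaviour (hand-checked against the logic above, not from
+# the upstream file):
+#     chn2num("两千万")   -> "20000000"   # correct_symbols merges 千万 into one 10^7 unit
+#     chn2num("一百八")   -> "180"        # the implied trailing 十 is padded in
+#     chn2num("三点一四") -> "3.14"       # integer and decimal parts split on 点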
+
+
+def num2chn(
+    number_string,
+    numbering_type=NUMBERING_TYPES[1],
+    big=False,
+    traditional=False,
+    alt_zero=False,
+    alt_one=False,
+    alt_two=True,
+    use_zeros=True,
+    use_units=True,
+):
+    def get_value(value_string, use_zeros=True):
+        striped_string = value_string.lstrip("0")
+
+        # record nothing if all zeros
+        if not striped_string:
+            return []
+
+        # record one digit
+        elif len(striped_string) == 1:
+            if use_zeros and len(value_string) != len(striped_string):
+                return [system.digits[0], system.digits[int(striped_string)]]
+            else:
+                return [system.digits[int(striped_string)]]
+
+        # recursively record multiple digits
+        else:
+            result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
+            result_string = value_string[: -result_unit.power]
+            return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power :])
+
+    system = create_system(numbering_type)
+
+    int_dec = number_string.split(".")
+    if len(int_dec) == 1:
+        int_string = int_dec[0]
+        dec_string = ""
+    elif len(int_dec) == 2:
+        int_string = int_dec[0]
+        dec_string = int_dec[1]
+    else:
+        raise ValueError("invalid input num string with more than one dot: {}".format(number_string))
+
+    if use_units and len(int_string) > 1:
+        result_symbols = get_value(int_string)
+    else:
+        result_symbols = [system.digits[int(c)] for c in int_string]
+    dec_symbols = [system.digits[int(c)] for c in dec_string]
+    if dec_string:
+        result_symbols += [system.math.point] + dec_symbols
+
+    if alt_two:
+        liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, system.digits[2].big_s, system.digits[2].big_t)
+        for i, v in enumerate(result_symbols):
+            if isinstance(v, CND) and v.value == 2:
+                next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
+                previous_symbol = result_symbols[i - 1] if i > 0 else None
+                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
+                    if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
+                        result_symbols[i] = liang
+
+    # if big is True, '两' will not be used and `alt_two` has no impact on output
+    if big:
+        attr_name = "big_"
+        if traditional:
+            attr_name += "t"
+        else:
+            attr_name += "s"
+    else:
+        if traditional:
+            attr_name = "traditional"
+        else:
+            attr_name = "simplified"
+
+    result = "".join([getattr(s, attr_name) for s in result_symbols])
+
+    # if not use_zeros:
+    #     result = result.strip(getattr(system.digits[0], attr_name))
+
+    if alt_zero:
+        result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)
+
+    if alt_one:
+        result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)
+
+    for i, p in enumerate(POINT):
+        if result.startswith(p):
+            return CHINESE_DIGIS[0] + result
+
+    # ^10, 11, .., 19
+    if (
+        len(result) >= 2
+        and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]]
+        and result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
+    ):
+        result = result[1:]
+
+    return result
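+
+
+# Illustrative examples for the inverse direction (hand-checked, not from the
+# upstream file):
+#     num2chn("180")      -> "一百八十"
+#     num2chn("3.14")     -> "三点一四"
+#     num2chn("20000000") -> "两千万"   # alt_two=True renders 二 as 两 before a unit
+#     num2chn("12")       -> "十二"     # the leading 一 of 一十* is dropped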
[SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]]
+        and result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
+    ):
+        result = result[1:]
+
+    return result
+
+
+# ================================================================================ #
+#                        different types of rewriters                              #
+# ================================================================================ #
+class Cardinal:
+    """
+    CARDINAL class
+    """
+
+    def __init__(self, cardinal=None, chntext=None):
+        self.cardinal = cardinal
+        self.chntext = chntext
+
+    def chntext2cardinal(self):
+        return chn2num(self.chntext)
+
+    def cardinal2chntext(self):
+        return num2chn(self.cardinal)
+
+
+class Digit:
+    """
+    DIGIT class
+    """
+
+    def __init__(self, digit=None, chntext=None):
+        self.digit = digit
+        self.chntext = chntext
+
+    # def chntext2digit(self):
+    #     return chn2num(self.chntext)
+
+    def digit2chntext(self):
+        return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+class TelePhone:
+    """
+    TELEPHONE class
+    """
+
+    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+        self.telephone = telephone
+        self.raw_chntext = raw_chntext
+        self.chntext = chntext
+
+    # def chntext2telephone(self):
+    #     sil_parts = self.raw_chntext.split('<SIL>')
+    #     self.telephone = '-'.join([
+    #         str(chn2num(p)) for p in sil_parts
+    #     ])
+    #     return self.telephone
+
+    def telephone2chntext(self, fixed=False):
+        if fixed:
+            sil_parts = self.telephone.split("-")
+            self.raw_chntext = "<SIL>".join([num2chn(part, alt_two=False, use_units=False) for part in sil_parts])
+            self.chntext = self.raw_chntext.replace("<SIL>", "")
+        else:
+            sp_parts = self.telephone.strip("+").split()
+            self.raw_chntext = "<SP>".join([num2chn(part, alt_two=False, use_units=False) for part in sp_parts])
+            self.chntext = self.raw_chntext.replace("<SP>", "")
+        return self.chntext
+
+
+class Fraction:
+    """
+    FRACTION class
+    """
+
+    def __init__(self, fraction=None, chntext=None):
+        self.fraction = fraction
+        self.chntext = chntext
+
+    def chntext2fraction(self):
+        denominator, numerator = self.chntext.split("分之")
+        return chn2num(numerator) + "/" + chn2num(denominator)
+
+    def fraction2chntext(self):
+        numerator, denominator = self.fraction.split("/")
+        return num2chn(denominator) + "分之" + num2chn(numerator)
+
+
+class Date:
+    """
+    DATE class
+    """
+
+    def __init__(self, date=None, chntext=None):
+        self.date = date
+        self.chntext = chntext
+
+    # def chntext2date(self):
+    #     chntext = self.chntext
+    #     try:
+    #         year, other = chntext.strip().split('年', maxsplit=1)
+    #         year = Digit(chntext=year).digit2chntext() + '年'
+    #     except ValueError:
+    #         other = chntext
+    #         year = ''
+    #     if other:
+    #         try:
+    #             month, day = other.strip().split('月', maxsplit=1)
+    #             month = Cardinal(chntext=month).chntext2cardinal() + '月'
+    #         except ValueError:
+    #             day = chntext
+    #             month = ''
+    #         if day:
+    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+    #     else:
+    #         month = ''
+    #         day = ''
+    #     date = year + month + day
+    #     self.date = date
+    #     return self.date
+
+    def date2chntext(self):
+        date = self.date
+        try:
+            year, other = date.strip().split("年", 1)
+            year = Digit(digit=year).digit2chntext() + "年"
+        except ValueError:
+            other = date
+            year = ""
+        if other:
+            try:
+                month, day = other.strip().split("月", 1)
+                month = Cardinal(cardinal=month).cardinal2chntext() + "月"
+            except ValueError:
+                day = date
+                month = ""
+            if day:
+                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+        else:
+            month = ""
+            day = ""
+        chntext = year + month + day
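+        # Illustrative: date2chntext("2023年5月1日") -> "二零二三年五月一日";
+        # the year is read digit by digit (Digit), month and day as cardinals (Cardinal).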
+        self.chntext = chntext
+        return self.chntext
+
+
+class Money:
+    """
+    MONEY class
+    """
+
+    def __init__(self, money=None, chntext=None):
+        self.money = money
+        self.chntext = chntext
+
+    # def chntext2money(self):
+    #     return self.money
+
+    def money2chntext(self):
+        money = self.money
+        pattern = re.compile(r"(\d+(\.\d+)?)")
+        matchers = pattern.findall(money)
+        if matchers:
+            for matcher in matchers:
+                money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
+        self.chntext = money
+        return self.chntext
+
+
+class Percentage:
+    """
+    PERCENTAGE class
+    """
+
+    def __init__(self, percentage=None, chntext=None):
+        self.percentage = percentage
+        self.chntext = chntext
+
+    def chntext2percentage(self):
+        return chn2num(self.chntext.strip().strip("百分之")) + "%"
+
+    def percentage2chntext(self):
+        return "百分之" + num2chn(self.percentage.strip().strip("%"))
+
+
+def normalize_nsw(raw_text):
+    text = "^" + raw_text + "$"
+
+    # normalize dates
+    pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
+    matchers = pattern.findall(text)
+    if matchers:
+        # print('date')
+        for matcher in matchers:
+            text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+    # normalize money amounts
+    pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
+    matchers = pattern.findall(text)
+    if matchers:
+        # print('money')
+        for matcher in matchers:
+            text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
+
+    # normalize landline / mobile phone numbers
+    # mobile number prefixes: http://www.jihaoba.com/news/show/13680
+    # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
+    # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
+    # China Telecom: 133, 153, 189, 180, 181, 177
+    pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+    matchers = pattern.findall(text)
+    if matchers:
+        # print('telephone')
+        for matcher in matchers:
+            text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
+    # landline
+    pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+    matchers = pattern.findall(text)
+    if matchers:
+        # print('fixed telephone')
+        for matcher in matchers:
+            text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
+
+    # normalize fractions
+    pattern = re.compile(r"(\d+/\d+)")
+    matchers = pattern.findall(text)
+    if matchers:
+        # print('fraction')
+        for matcher in matchers:
+            text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
+
+    # normalize percentages (fold full-width ％ into ASCII % first)
+    text = text.replace("％", "%")
+    pattern = re.compile(r"(\d+(\.\d+)?%)")
+    matchers = pattern.findall(text)
+    if matchers:
+        # print('percentage')
+        for matcher in matchers:
+            text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
+
+    # normalize plain number + quantifier
+    pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?"
+ COM_QUANTIFIERS)
+    matchers = pattern.findall(text)
+    if matchers:
+        # print('cardinal+quantifier')
+        for matcher in matchers:
+            text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+    # normalize serial numbers (long digit strings, read digit by digit)
+    pattern = re.compile(r"(\d{4,32})")
+    matchers = pattern.findall(text)
+    if matchers:
+        # print('digit')
+        for matcher in matchers:
+            text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+    # normalize plain numbers
+    pattern = re.compile(r"(\d+(\.\d+)?)")
+    matchers = pattern.findall(text)
+    if matchers:
+        # print('cardinal')
+        for matcher in matchers:
+            text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+    # restore P2P, O2O, B2C, B2B etc.
+    pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+    matchers = pattern.findall(text)
+    if matchers:
+        # print('particular')
+        for matcher in matchers:
+            text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
+
+    return text.lstrip("^").rstrip("$")
+
+
+def remove_erhua(text):
+    """
+    Remove the erhua suffix 儿 from erhua words:
+        他女儿在那边儿 -> 他女儿在那边
+    """
+
+    new_str = ""
+    while re.search("儿", text):
+        a = re.search("儿", text).span()
+        remove_er_flag = 0
+
+        if ER_WHITELIST_PATTERN.search(text):
+            b = ER_WHITELIST_PATTERN.search(text).span()
+            if b[0] <= a[0]:
+                remove_er_flag = 1
+
+        if remove_er_flag == 0:
+            new_str = new_str + text[0 : a[0]]
+            text = text[a[1] :]
+        else:
+            new_str = new_str + text[0 : b[1]]
+            text = text[b[1] :]
+
+    text = new_str + text
+    return text
+
+
+def remove_space(text):
+    tokens = text.split()
+    new = []
+    for k, t in enumerate(tokens):
+        if k != 0:
+            # keep a space only between two runs of English characters
+            if IN_EN_CHARS.get(tokens[k - 1][-1]) and IN_EN_CHARS.get(t[0]):
+                new.append(" ")
+        new.append(t)
+    return "".join(new)
+
+
+class TextNorm:
+    def __init__(
+        self,
+        to_banjiao: bool = False,
+        to_upper: bool = False,
+        to_lower: bool = False,
+        remove_fillers: bool = False,
+        remove_erhua: bool = False,
+        check_chars: bool = False,
+        remove_space: bool = False,
+        cc_mode: str = "",
+    ):
+        self.to_banjiao = to_banjiao
+        self.to_upper = to_upper
+        self.to_lower = to_lower
+        self.remove_fillers = remove_fillers
+        self.remove_erhua = remove_erhua
+        self.check_chars = check_chars
+        self.remove_space = remove_space
+
+        self.cc = None
+        if cc_mode:
+            from opencc import OpenCC  # Open Chinese Convert: pip install opencc
+
+            self.cc = OpenCC(cc_mode)
+
+    def __call__(self, text):
+        if self.cc:
+            text = self.cc.convert(text)
+
+        if self.to_banjiao:
+            text = text.translate(QJ2BJ_TRANSFORM)
+
+        if self.to_upper:
+            text = text.upper()
+
+        if self.to_lower:
+            text = text.lower()
+
+        if self.remove_fillers:
+            for c in FILLER_CHARS:
+                text = text.replace(c, "")
+
+        if self.remove_erhua:
+            text = remove_erhua(text)
+
+        text = normalize_nsw(text)
+
+        text = text.translate(PUNCS_TRANSFORM)
+
+        if self.check_chars:
+            for c in text:
+                if not IN_VALID_CHARS.get(c):
+                    print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr)
+                    return ""
+
+        if self.remove_space:
+            text = remove_space(text)
+
+        return text
diff --git a/diffsynth_engine/models/ace_step/vae/dcae.py b/diffsynth_engine/models/ace_step/vae/dcae.py
index 822c3e57..d56e9626 100644
--- a/diffsynth_engine/models/ace_step/vae/dcae.py
+++ b/diffsynth_engine/models/ace_step/vae/dcae.py
@@ -1,7 +1,12 @@
+from typing import Any, Dict, Tuple, Union
+import json
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Tuple, Union
+
+from diffsynth_engine.models.base import PreTrainedModel
+from diffsynth_engine.utils.constants import ACE_VAE_DCAE_CONFIG_FILE
 
 
 class 
RMSNorm(nn.RMSNorm): @@ -572,7 +577,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class DCAE(nn.Module): +class DCAE(PreTrainedModel): def __init__( self, in_channels: int = 2, @@ -623,3 +628,24 @@ def encode(self, x: torch.Tensor) -> torch.Tensor: def decode(self, z: torch.Tensor) -> torch.Tensor: return self.decoder(z) + + @classmethod + def from_state_dict( + cls, + state_dict: Dict[str, torch.Tensor], + config: Dict[str, Any], + device: str = "cuda:0", + dtype: torch.dtype = torch.float32, + ) -> "DCAE": + model = cls(**config, device="meta", dtype=dtype) + model.requires_grad_(False) + model.load_state_dict(state_dict, assign=True) + model.to(device=device, dtype=dtype, non_blocking=True) + return model + + @staticmethod + def get_model_config() -> dict: + config_file = ACE_VAE_DCAE_CONFIG_FILE + with open(config_file, "r") as f: + config = json.load(f) + return config \ No newline at end of file diff --git a/diffsynth_engine/models/ace_step/vae/hifi_gan.py b/diffsynth_engine/models/ace_step/vae/hifi_gan.py index a1936a2d..536b27e3 100644 --- a/diffsynth_engine/models/ace_step/vae/hifi_gan.py +++ b/diffsynth_engine/models/ace_step/vae/hifi_gan.py @@ -1,5 +1,6 @@ -from typing import Callable, List, Tuple +from typing import Callable, List, Tuple, Dict, Any import math +import json from functools import partial import numpy as np @@ -9,6 +10,9 @@ from torch.nn.utils.parametrizations import weight_norm from torchaudio.transforms import MelScale +from diffsynth_engine.models.base import PreTrainedModel +from diffsynth_engine.utils.constants import ACE_VAE_VOCODER_CONFIG_FILE + class LinearSpectrogram(nn.Module): def __init__( @@ -543,7 +547,7 @@ def forward(self, x, template=None): return x -class ADaMoSHiFiGANV1(nn.Module): +class ADaMoSHiFiGANV1(PreTrainedModel): def __init__( self, input_channels: int = 128, @@ -621,3 +625,24 @@ def forward(self, mel): y = self.backbone(mel) y = self.head(y) return y + + @classmethod + def from_state_dict( + cls, + state_dict: Dict[str, torch.Tensor], + config: Dict[str, Any], + device: str = "cuda:0", + dtype: torch.dtype = torch.float32, + ) -> "ADaMoSHiFiGANV1": + model = cls(**config, device="meta", dtype=dtype) + model.requires_grad_(False) + model.load_state_dict(state_dict, assign=True) + model.to(device=device, dtype=dtype, non_blocking=True) + return model + + @staticmethod + def get_model_config() -> dict: + config_file = ACE_VAE_VOCODER_CONFIG_FILE + with open(config_file, "r") as f: + config = json.load(f) + return config diff --git a/diffsynth_engine/models/ace_step/vae/music_dcae.py b/diffsynth_engine/models/ace_step/vae/music_dcae.py index 8264e2e8..b6c9fda3 100644 --- a/diffsynth_engine/models/ace_step/vae/music_dcae.py +++ b/diffsynth_engine/models/ace_step/vae/music_dcae.py @@ -11,13 +11,13 @@ class MusicDCAE(PreTrainedModel): def __init__( self, - dace: DCAE, + dcae: DCAE, vocoder: ADaMoSHiFiGANV1, source_sample_rate: int = 48000, ): super(MusicDCAE, self).__init__() - self.dcae = dace + self.dcae = dcae self.vocoder = vocoder self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100) self.transform = transforms.Compose([transforms.Normalize(0.5, 0.5)]) @@ -73,7 +73,7 @@ def decode(self, latents: torch.Tensor, sr: int): pred_wavs = [] for latent in latents: - mels = self.dcae.decoder(latent.unsqueeze(0)) + mels = self.dcae.decoder(latent[None]) mels = mels * 0.5 + 0.5 mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value @@ -86,4 +86,4 @@ def 
decode(self, latents: torch.Tensor, sr: int):
         wav = resampler(wav.cpu().float())
         pred_wavs.append(wav)
 
-        return sr, pred_wavs
+        return pred_wavs
diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py
index 8d6b7b76..f91db3e7 100644
--- a/diffsynth_engine/pipelines/ace_step.py
+++ b/diffsynth_engine/pipelines/ace_step.py
@@ -1,29 +1,501 @@
-import math
-
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.distributed as dist
-from torchvision.transforms.functional import pil_to_tensor
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Tuple
 from tqdm import tqdm
-from PIL import Image
 
-from diffsynth_engine.configs import ACEStepPipelineConfig
+from diffsynth_engine.configs import ACEStepPipelineConfig, ACEStateDicts
 from diffsynth_engine.models.ace_step.ace_dit import ACEStepDiT
 from diffsynth_engine.models.ace_step.vae.music_dcae import MusicDCAE
+from diffsynth_engine.models.ace_step.vae.dcae import DCAE
+from diffsynth_engine.models.ace_step.vae.hifi_gan import ADaMoSHiFiGANV1
+from diffsynth_engine.models.ace_step.lyric_tokenizer.lyric_tokenizer import (
+    VoiceBpeTokenizer,
+    SUPPORT_LANGUAGES,
+    structure_pattern,
+)
+from diffsynth_engine.models.ace_step.lyric_tokenizer.lang_segment import LangSegment, default
+from diffsynth_engine.models.ace_step.ace_text_encoder import ACETextEncoder
+from diffsynth_engine.tokenizers import WanT5Tokenizer
+from diffsynth_engine.models.basic.lora import LoRAContext
+from diffsynth_engine.algorithm.noise_scheduler.flow_match import RecifitedFlowScheduler
+from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.pipelines import BasePipeline
+from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH
+from diffsynth_engine.utils.download import fetch_model
+from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
 from diffsynth_engine.utils import logging
 
 logger = logging.get_logger(__name__)
 
 
+def fwd_with_temperature(
+    inputs, model_fwd_func, get_hooked_layer_func, layer_start_idx, layer_end_idx, temperature=0.01
+):
+    # Run `model_fwd_func` while the outputs of selected layers are scaled down by
+    # `temperature` via forward hooks; the pipeline uses this to obtain a "null"
+    # (unconditional-like) forward pass without a separate negative branch.
+    def hook(module, input, output):
+        output[:] *= temperature
+        return output
+
+    handlers = []
+    for i in range(layer_start_idx, layer_end_idx):
+        layer = get_hooked_layer_func(i)
+        if isinstance(layer, list):
+            for sub_layer in layer:
+                handler = sub_layer.register_forward_hook(hook)
+                handlers.append(handler)
+        else:
+            handler = layer.register_forward_hook(hook)
+            handlers.append(handler)
+
+    with torch.no_grad():
+        # `inputs` may be a dict of keyword arguments or a tuple of positional ones
+        out = model_fwd_func(**inputs) if isinstance(inputs, dict) else model_fwd_func(*inputs)
+
+    for handler in handlers:
+        handler.remove()
+    return out
+
+
+class MomentumBuffer:
+    def __init__(self, momentum: float = -0.75):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+def project(
+    v0: torch.Tensor,  # [B, C, H, W]
+    v1: torch.Tensor,  # [B, C, H, W]
+    dims=[-1, -2],
+):
+    dtype = v0.dtype
+    if v0.device.type == "mps":  # float64 is not supported on MPS
+        v0, v1 = v0.cpu(), v1.cpu()
+
+    v0, v1 = v0.double(), v1.double()
+    v1 = F.normalize(v1, dim=dims)
+    v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    # cast back to the caller's dtype instead of leaking float64
+    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
+
+
+def apg_forward(
+    pred_cond: torch.Tensor,  # [B, C, H, W]
+    pred_uncond: torch.Tensor,  # [B, C, H, W]
+    guidance_scale: float,
+    momentum_buffer: MomentumBuffer,
+    eta: float = 0.0,
+    norm_threshold: float = 2.5,
+    dims=[-1, -2],
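+    # Adaptive projected guidance (APG): smooth the cond/uncond difference with a
+    # momentum buffer, clip its norm at `norm_threshold`, then keep only the
+    # component orthogonal to `pred_cond` (plus `eta` times the parallel one),
+    # which curbs the oversaturation that plain CFG causes at high guidance scales.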
+): + diff = pred_cond - pred_uncond + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + + diff_norm = diff.norm(p=2, dim=dims, keepdim=True) + scale_factor = torch.minimum(torch.ones_like(diff), norm_threshold / diff_norm) + diff *= scale_factor + + diff_parallel, diff_orthogonal = project(diff, pred_cond, dims) + normalized_update = diff_orthogonal + eta * diff_parallel + pred_guided = pred_cond + (guidance_scale - 1) * normalized_update + return pred_guided + + class ACEStepMusicPipeline(BasePipeline): def __init__( self, config: ACEStepPipelineConfig, + tokenizer: WanT5Tokenizer, + text_encoder: ACETextEncoder, dit: ACEStepDiT, vae: MusicDCAE, ): - pass \ No newline at end of file + super().__init__( + vae_tiled=config.vae_tiled, + vae_tile_size=config.vae_tile_size, + vae_tile_stride=config.vae_tile_stride, + device=config.device, + dtype=config.model_dtype, + ) + self.config = config + # sampler + self.noise_scheduler = RecifitedFlowScheduler(shift=3.0) + self.sampler = FlowMatchEulerSampler() + # models + self.lyric_tokenizer = VoiceBpeTokenizer() + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.dit = dit + self.vae = vae + self.model_names = ["text_encoder", "dit", "vae"] + # language segment + self.lang_segment = LangSegment() + self.lang_segment.setfilters(default) + + def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False): + # assert self.config.tp_degree is None or self.config.tp_degree == 1, ( + # "load LoRA is not allowed when tensor parallel is enabled; " + # "set tp_degree=None or tp_degree=1 during pipeline initialization" + # ) + # assert not (self.config.use_fsdp and fused), ( + # "load fused LoRA is not allowed when fully sharded data parallel is enabled; " + # "either load LoRA with fused=False or set use_fsdp=False during pipeline initialization" + # ) + super().load_loras(lora_list, fused, save_original_weight) + + def unload_loras(self): + self.dit.unload_loras() + self.text_encoder.unload_loras() + + def encode_prompt(self, prompt): + self.load_models_to_device(["text_encoder"]) + ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(self.device) + mask = mask.to(self.device) + prompt_emb = self.text_encoder(ids, mask) + return prompt_emb, mask + + def encode_prompt_null(self, prompt): + self.load_models_to_device(["text_encoder"]) + ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True) + ids = ids.to(self.device) + mask = mask.to(self.device) + prompt_emb = fwd_with_temperature( + inputs=(ids, mask), + model_fwd_func=self.text_encoder, + get_hooked_layer_func=lambda i: self.text_encoder.blocks[i].attn.q, + layer_start_idx=4, + layer_end_idx=6, + ) + return prompt_emb, mask + + def tokenize_lyric(self, lyrics: str): + lyric_token_idx = [261] + for line in lyrics.split("\n"): + line = line.strip() + if not line: + lyric_token_idx += [2] + continue + + try: + self.lang_segment.getTexts(line) + langCounts = self.lang_segment.getCounts() + language = langCounts[0][0] + if len(langCounts) > 1 and language == "en": + language = langCounts[1][0] + except Exception: + language = "en" + + if language not in SUPPORT_LANGUAGES: + language = "en" + if "zh" in language: + language = "zh" + if "spa" in language: + language = "es" + + try: + if structure_pattern.match(line): + token_idx = self.lyric_tokenizer.encode(line, "en") + else: + token_idx = self.lyric_tokenizer.encode(line, language) + lyric_token_idx += 
token_idx + [2]
+            except Exception as e:
+                logger.warning(f"tokenize error: {e} for line: {line}, major language: {language}")
+        lyric_mask = torch.tensor([1] * len(lyric_token_idx), device=self.device)[None]
+        lyric_token_idx = torch.tensor(lyric_token_idx, device=self.device)[None]
+        return lyric_token_idx, lyric_mask
+
+    def predict_noise_with_cfg(
+        self,
+        model: ACEStepDiT,
+        latents: torch.Tensor,
+        timestep: torch.Tensor,
+        context: torch.Tensor,
+        context_null: torch.Tensor,
+        attn_mask: torch.Tensor,
+        attn_mask_ctx: torch.Tensor,
+        attn_mask_ctx_null: torch.Tensor,
+        cfg_scale: float,
+        momentum_buffer: MomentumBuffer,
+    ):
+        # `model` (the module) is kept for hook registration below; predict_noise
+        # itself calls the decode entry point, matching the non-CFG branch.
+        if cfg_scale <= 1.0:
+            return self.predict_noise(
+                model=model.decode,
+                latents=latents,
+                timestep=timestep,
+                context=context,
+                attn_mask=attn_mask,
+                attn_mask_ctx=attn_mask_ctx,
+            )
+        # run the conditional and unconditional predictions separately
+        positive_noise_pred = self.predict_noise(
+            model=model.decode,
+            latents=latents,
+            timestep=timestep,
+            context=context,
+            attn_mask=attn_mask,
+            attn_mask_ctx=attn_mask_ctx,
+        )
+        negative_noise_pred = fwd_with_temperature(
+            inputs={
+                "model": model.decode,
+                "latents": latents,
+                "timestep": timestep,
+                "context": context_null,
+                "attn_mask": attn_mask,
+                "attn_mask_ctx": attn_mask_ctx_null,
+            },
+            model_fwd_func=self.predict_noise,
+            get_hooked_layer_func=lambda i: [
+                model.transformer_blocks[i].attn.q,
+                model.transformer_blocks[i].cross_attn.q,
+            ],
+            layer_start_idx=15,
+            layer_end_idx=20,
+        )
+        noise_pred = apg_forward(
+            pred_cond=positive_noise_pred,
+            pred_uncond=negative_noise_pred,
+            guidance_scale=cfg_scale,
+            momentum_buffer=momentum_buffer,
+        )
+        return noise_pred
+
+    def predict_noise(self, model, latents, timestep, context, attn_mask, attn_mask_ctx):
+        latents = latents.to(dtype=self.config.model_dtype, device=self.device)
+
+        noise_pred = model(
+            x=latents,
+            timestep=timestep,
+            context=context,
+            attn_mask=attn_mask,
+            attn_mask_ctx=attn_mask_ctx,
+        )
+        return noise_pred
+
+    def decode_audio(self, latents: torch.Tensor, sample_rate=48000) -> List[torch.Tensor]:
+        self.load_models_to_device(["vae"])
+        latents = latents.to(dtype=self.config.vae_dtype, device=self.device)
+        audios = self.vae.decode(latents, sr=sample_rate)
+        return audios
+
+    @torch.no_grad()
+    def text2audio(
+        self,
+        prompt: str,
+        audio_duration: float,
+        lyrics: str = "",
+        cfg_scale: float = 15.0,
+        omega_scale: float = 10.0,
+        num_inference_steps: int = 60,
+        seed: Optional[int] = None,
+        guidance_interval: float = 0.5,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
+    ):
+        prompt_emb, prompt_attn_mask = self.encode_prompt(prompt)
+        prompt_emb_null, prompt_attn_mask_null = self.encode_prompt_null(prompt)
+        if len(lyrics.strip()) > 0:
+            lyric_token_idx, lyric_mask = self.tokenize_lyric(lyrics)
+        else:
+            lyric_token_idx = torch.zeros((1, 1), device=self.device, dtype=torch.long)
+            lyric_mask = torch.zeros((1, 1), device=self.device, dtype=torch.long)
+
+        num_frames = int(audio_duration * 44100 / 512 / 8)
+        noise = self.generate_noise((1, 8, 16, num_frames), seed=seed, device="cpu", dtype=torch.float32).to(
+            self.device
+        )
+        attn_mask = torch.ones(1, num_frames, device=self.device, dtype=self.dtype)
+        _, latents, sigmas, timesteps = self.prepare_latents(
+            latents=noise,
+            input_video=None,
+            denoising_strength=None,
+            num_inference_steps=num_inference_steps,
+        )
+        # Initialize sampler
+        self.sampler.initialize(sigmas=sigmas)
+        # guidance interval: apply CFG only in the middle fraction of the schedule
+        cfg_start_step = int(num_inference_steps * ((1 - guidance_interval) / 2))
+        cfg_end_step = int(num_inference_steps
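+        # e.g. guidance_interval=0.5 and 60 steps: start=int(60*0.25)=15, end=int(60*0.75)=45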
* (guidance_interval / 2 + 0.5)) + momentum_buffer = MomentumBuffer() + + context, context_mask = self.dit.encode(prompt_emb, lyric_token_idx, prompt_attn_mask, lyric_mask) + context_null, context_mask_null = fwd_with_temperature( + inputs={ + "context_prompt": prompt_emb_null, + "context_lyric": lyric_token_idx, + "attn_mask_prompt": prompt_attn_mask_null, + "attn_mask_lyric": lyric_mask, + }, + model_fwd_func=self.dit.encode, + get_hooked_layer_func=lambda i: self.dit.lyric_encoder.encoders[i].self_attn.linear_q, + layer_start_idx=4, + layer_end_idx=6, + ) + + self.load_models_to_device(["dit"]) + hide_progress = dist.is_initialized() and dist.get_rank() != 0 + for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)): + timestep = timestep.to(dtype=self.dtype, device=self.device) + # Classifier-free guidance + if cfg_start_step <= i < cfg_end_step: + noise_pred = self.predict_noise_with_cfg( + model=self.dit, + latents=latents, + timestep=timestep, + context=context, + context_null=context_null, + attn_mask=attn_mask, + attn_mask_ctx=context_mask, + attn_mask_ctx_null=context_mask_null, + cfg_scale=cfg_scale, + momentum_buffer=momentum_buffer, + ) + else: + noise_pred = self.predict_noise( + model=self.dit.decode, + latents=latents, + timestep=timestep, + context=context, + attn_mask=attn_mask, + attn_mask_ctx=context_mask, + ) + # Scheduler + dx: torch.Tensor = noise_pred * (self.sampler.sigmas[i + 1] - self.sampler.sigmas[i]) + dx_mean = dx.mean(dim=(1, 2, 3), keepdim=True) + latents = latents.to(dtype=torch.float32) + latents += (dx - dx_mean) * omega_scale + dx + latents = latents.to(dtype=noise_pred.dtype) + if progress_callback is not None: + progress_callback(i + 1, len(timesteps), "DENOISING") + + # Decode + return self.decode_audio(latents) + + def audio2audio(self): + raise NotImplementedError + + @classmethod + def from_pretrained(cls, model_path_or_config: ACEStepPipelineConfig) -> "ACEStepMusicPipeline": + if isinstance(model_path_or_config, str): + config = ACEStepPipelineConfig(model_path=model_path_or_config) + else: + config = model_path_or_config + + dit_state_dict = None + if dit_state_dict is None: + logger.info(f"loading dit state dict from {config.model_path} ...") + dit_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype) + + if config.t5_path is None: + config.t5_path = fetch_model("ACE-Step/ACE-Step-v1-3.5B", path="umt5-base/model.safetensors") + logger.info(f"loading t5 state dict from {config.t5_path} ...") + t5_state_dict = cls.load_model_checkpoint(config.t5_path, device="cpu", dtype=config.t5_dtype) + + if config.dcae_path is None: + config.dcae_path = fetch_model( + "ACE-Step/ACE-Step-v1-3.5B", path="music_dcae_f8c8/diffusion_pytorch_model.safetensors" + ) + logger.info(f"loading vae/dcae state dict from {config.dcae_path} ...") + dcae_state_dict = cls.load_model_checkpoint(config.dcae_path, device="cpu", dtype=config.vae_dtype) + + if config.vocoder_path is None: + config.vocoder_path = fetch_model( + "ACE-Step/ACE-Step-v1-3.5B", path="music_vocoder/diffusion_pytorch_model.safetensors" + ) + logger.info(f"loading vae/vocoder state dict from {config.vocoder_path} ...") + vocoder_state_dict = cls.load_model_checkpoint(config.vocoder_path, device="cpu", dtype=config.vae_dtype) + + state_dicts = ACEStateDicts( + model=dit_state_dict, + t5=t5_state_dict, + dcae=dcae_state_dict, + vocoder=vocoder_state_dict, + ) + return cls.from_state_dict(state_dicts, config) + + @classmethod + def from_state_dict(cls, 
state_dicts: ACEStateDicts, config: ACEStepPipelineConfig) -> "ACEStepMusicPipeline": + # if config.parallelism > 1: + # pipe = ParallelWrapper( + # cfg_degree=config.cfg_degree, + # sp_ulysses_degree=config.sp_ulysses_degree, + # sp_ring_degree=config.sp_ring_degree, + # tp_degree=config.tp_degree, + # use_fsdp=config.use_fsdp, + # ) + # pipe.load_module(cls._from_state_dict, state_dicts=state_dicts, config=config) + # else: + pipe = cls._from_state_dict(state_dicts, config) + return pipe + + @classmethod + def _from_state_dict(cls, state_dicts: ACEStateDicts, config: ACEStepPipelineConfig) -> "ACEStepMusicPipeline": + # default params from model config + dcae_config: dict = DCAE.get_model_config() + vocoder_config: dict = ADaMoSHiFiGANV1.get_model_config() + dit_config: dict = ACEStepDiT.get_model_config() + t5_config: dict = ACETextEncoder.get_model_config() + config.shift = dit_config.pop("shift", 3.0) + config.cfg_scale = dit_config.pop("cfg_scale", 15) + config.num_inference_steps = dit_config.pop("num_inference_steps", 60) + + init_device = "cpu" if config.offload_mode is not None else config.device + tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=256, clean="whitespace") + text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, config=t5_config, device=init_device, dtype=config.t5_dtype) + dcae = DCAE.from_state_dict(state_dicts.dcae, config=dcae_config, device=init_device, dtype=config.vae_dtype) + hifi_gan = ADaMoSHiFiGANV1.from_state_dict(state_dicts.vocoder, config=vocoder_config, device=init_device, dtype=config.vae_dtype) + vae = MusicDCAE(dcae=dcae, vocoder=hifi_gan) + + with LoRAContext(): + attn_kwargs = { + "attn_impl": config.dit_attn_impl, + "sparge_smooth_k": config.sparge_smooth_k, + "sparge_cdfthreshd": config.sparge_cdfthreshd, + "sparge_simthreshd1": config.sparge_simthreshd1, + "sparge_pvthreshd": config.sparge_pvthreshd, + } + dit = ACEStepDiT.from_state_dict( + state_dicts.model, + config=dit_config, + device=init_device, + dtype=config.model_dtype, + attn_kwargs=attn_kwargs, + ) + if config.use_fp8_linear: + enable_fp8_linear(dit) + + pipe = cls( + config=config, + tokenizer=tokenizer, + text_encoder=text_encoder, + dit=dit, + vae=vae, + ) + pipe.eval() + + if config.offload_mode is not None: + pipe.enable_cpu_offload(config.offload_mode) + + if config.model_dtype == torch.float8_e4m3fn: + pipe.dtype = torch.bfloat16 # compute dtype + pipe.enable_fp8_autocast( + model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear + ) + + if config.t5_dtype == torch.float8_e4m3fn: + pipe.dtype = torch.bfloat16 # compute dtype + pipe.enable_fp8_autocast( + model_names=["text_encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear + ) + + if config.use_torch_compile: + pipe.compile() + return pipe + + def compile(self): + self.dit.compile_repeated_blocks(dynamic=True) diff --git a/diffsynth_engine/utils/constants.py b/diffsynth_engine/utils/constants.py index d001123a..6ebbba63 100644 --- a/diffsynth_engine/utils/constants.py +++ b/diffsynth_engine/utils/constants.py @@ -45,6 +45,11 @@ QWEN_IMAGE_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen_image_vae.json") QWEN_IMAGE_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen_image_vae_keymap.json") +ACE_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "ace", "t5.json") +ACE_DIT_CONFIG_FILE = os.path.join(CONF_PATH, "models", "ace", "dit.json") +ACE_VAE_DCAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "ace", 
"vae", "dcae.json") +ACE_VAE_VOCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "ace", "vae", "vocoder.json") + # data size KB = 1024 MB = 1024 * KB From c7222d7309253c3302573e93fa76817c0ebdea2e Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Fri, 12 Sep 2025 14:25:09 +0800 Subject: [PATCH 3/8] maybe finish dev, let's test --- diffsynth_engine/conf/models/ace/dit.json | 8 +--- diffsynth_engine/conf/models/ace/t5.json | 26 ++--------- diffsynth_engine/models/ace_step/ace_dit.py | 43 ++++++++++++++++++- .../models/ace_step/ace_lyric_encoder.py | 20 ++++----- diffsynth_engine/pipelines/ace_step.py | 6 +-- diffsynth_engine/utils/audio.py | 16 +++++++ examples/ace_text_to_music.py | 26 +++++++++++ 7 files changed, 102 insertions(+), 43 deletions(-) create mode 100644 diffsynth_engine/utils/audio.py create mode 100644 examples/ace_text_to_music.py diff --git a/diffsynth_engine/conf/models/ace/dit.json b/diffsynth_engine/conf/models/ace/dit.json index bc409f49..32994b54 100644 --- a/diffsynth_engine/conf/models/ace/dit.json +++ b/diffsynth_engine/conf/models/ace/dit.json @@ -1,5 +1,5 @@ { - "attention_head_dim": 128, + "head_dim": 128, "in_channels": 8, "inner_dim": 2560, "lyric_encoder_vocab_size": 6693, @@ -8,7 +8,7 @@ "max_position": 32768, "max_width": 32768, "mlp_ratio": 2.5, - "num_attention_heads": 20, + "num_heads": 20, "num_layers": 24, "out_channels": 8, "patch_size": [ @@ -25,9 +25,5 @@ 1024, 768 ], - "ssl_names": [ - "mert", - "m-hubert" - ], "text_embedding_dim": 768 } \ No newline at end of file diff --git a/diffsynth_engine/conf/models/ace/t5.json b/diffsynth_engine/conf/models/ace/t5.json index 0d6b2148..3ae0b6b4 100644 --- a/diffsynth_engine/conf/models/ace/t5.json +++ b/diffsynth_engine/conf/models/ace/t5.json @@ -1,25 +1,7 @@ { - "classifier_dropout": 0.0, - "d_ff": 2048, - "d_kv": 64, - "d_model": 768, - "decoder_start_token_id": 0, - "dense_act_fn": "gelu_new", - "dropout_rate": 0.1, - "eos_token_id": 1, - "feed_forward_proj": "gated-gelu", - "initializer_factor": 1.0, - "is_encoder_decoder": true, - "is_gated_act": true, - "layer_norm_epsilon": 1e-06, - "num_decoder_layers": 12, + "dim_ffn": 2048, + "dim": 768, + "dim_attn": 768, "num_heads": 12, - "num_layers": 12, - "output_past": true, - "pad_token_id": 0, - "relative_attention_max_distance": 128, - "relative_attention_num_buckets": 32, - "scalable_attention": true, - "tie_word_embeddings": false, - "vocab_size": 256384 + "num_layers": 12 } \ No newline at end of file diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py index bbd89ff9..0942e6e5 100644 --- a/diffsynth_engine/models/ace_step/ace_dit.py +++ b/diffsynth_engine/models/ace_step/ace_dit.py @@ -289,7 +289,48 @@ def forward(self, x, t): class ACEStepDiTStateDictConverter(StateDictConverter): - def convert(self, state_dict): # TODO + def convert(self, state_dict): + for key in list(state_dict.keys()): + # change all linear_q / linear_k / linear_v / linear_p to q / k / v / p + if "linear_q" in key: + new_key = key.replace("linear_q", "q") + state_dict[new_key] = state_dict.pop(key) + elif "linear_k" in key: + new_key = key.replace("linear_k", "k") + state_dict[new_key] = state_dict.pop(key) + elif "linear_v" in key: + new_key = key.replace("linear_v", "v") + state_dict[new_key] = state_dict.pop(key) + elif "linear_p" in key: + new_key = key.replace("linear_p", "p") + state_dict[new_key] = state_dict.pop(key) + elif "linear_out" in key: + new_key = key.replace("linear_out", "o") + state_dict[new_key] = 
state_dict.pop(key) + # change all to_q / to_k / to_v / to_out to q / k / v / o + elif "to_q" in key: + new_key = key.replace("to_q", "q") + state_dict[new_key] = state_dict.pop(key) + elif "to_k" in key: + new_key = key.replace("to_k", "k") + state_dict[new_key] = state_dict.pop(key) + elif "to_v" in key: + new_key = key.replace("to_v", "v") + state_dict[new_key] = state_dict.pop(key) + elif "to_out" in key: + new_key = key.replace("to_out", "o") + state_dict[new_key] = state_dict.pop(key) + # remove all add_{q/k/v}_proj + elif "add_q_proj" in key or "add_k_proj" in key or "add_v_proj" in key: + state_dict.pop(key) + # rename timestep_embedder.linear_1 into time_embedder.timestep_embedder.0 + elif "timestep_embedder.linear_1" in key: + new_key = key.replace("timestep_embedder.linear_1", "time_embedder.timestep_embedder.0") + state_dict[new_key] = state_dict.pop(key) + # rename timestep_embedder.linear_2 into time_embedder.timestep_embedder.2 + elif "timestep_embedder.linear_2" in key: + new_key = key.replace("timestep_embedder.linear_2", "time_embedder.timestep_embedder.2") + state_dict[new_key] = state_dict.pop(key) return state_dict diff --git a/diffsynth_engine/models/ace_step/ace_lyric_encoder.py b/diffsynth_engine/models/ace_step/ace_lyric_encoder.py index 275c6a3c..bb8b71e8 100644 --- a/diffsynth_engine/models/ace_step/ace_lyric_encoder.py +++ b/diffsynth_engine/models/ace_step/ace_lyric_encoder.py @@ -26,11 +26,11 @@ class RelPositionMultiHeadedAttention(nn.Module): def __init__(self, n_head: int, n_feat: int): self.n_head = n_head self.d_k = n_feat // n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + self.q = nn.Linear(n_feat, n_feat) + self.k = nn.Linear(n_feat, n_feat) + self.v = nn.Linear(n_feat, n_feat) + self.o = nn.Linear(n_feat, n_feat) + self.p = nn.Linear(n_feat, n_feat, bias=False) self.pos_bias_u = nn.Parameter(torch.Tensor(n_head, self.d_k)) self.pos_bias_v = nn.Parameter(torch.Tensor(n_head, self.d_k)) @@ -63,10 +63,10 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor) -> Returns: torch.Tensor: Output tensor (#batch, time1, d_model). 
""" - q = self.linear_q(x) - k = self.linear_k(x) - v = self.linear_v(x) - p = self.linear_pos(pos_emb) + q = self.q(x) + k = self.k(x) + v = self.v(x) + p = self.p(pos_emb) q = rearrange(q, "b t (h d) -> b h t d", h=self.n_head) k = rearrange(k, "b t (h d) -> b h d t", h=self.n_head) v = rearrange(v, "b t (h d) -> b h t d", h=self.n_head) @@ -85,7 +85,7 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor) -> attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) x = torch.matmul(attn, v) x = rearrange(x, "b h t d -> b t (h d)") - return self.linear_out(x) + return self.o(x) class ConformerEncoderLayer(nn.Module): diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py index f91db3e7..1f600ab7 100644 --- a/diffsynth_engine/pipelines/ace_step.py +++ b/diffsynth_engine/pipelines/ace_step.py @@ -306,9 +306,7 @@ def text2audio( lyric_mask = torch.zeros((1, 1), device=self.device, dtype=torch.long) num_frames = int(audio_duration * 44100 / 512 / 8) - noise = self.generate_noise((1, 8, 16, num_frames), seed=seed, device="cpu", dtype=torch.float32).to( - self.device - ) + noise = self.generate_noise((1, 8, 16, num_frames), seed=seed, device="cpu", dtype=torch.float32).to(self.device) attn_mask = torch.ones(1, num_frames, device=self.device, dtype=self.dtype) _, latents, sigmas, timesteps = self.prepare_latents( latents=noise, @@ -332,7 +330,7 @@ def text2audio( "attn_mask_lyric": lyric_mask, }, model_fwd_func=self.dit.encode, - get_hooked_layer_func=lambda i: self.dit.lyric_encoder.encoders[i].self_attn.linear_q, + get_hooked_layer_func=lambda i: self.dit.lyric_encoder.encoders[i].self_attn.q, layer_start_idx=4, layer_end_idx=6, ) diff --git a/diffsynth_engine/utils/audio.py b/diffsynth_engine/utils/audio.py new file mode 100644 index 00000000..da4e812a --- /dev/null +++ b/diffsynth_engine/utils/audio.py @@ -0,0 +1,16 @@ +from typing import List + +import torch +import torchaudio + + +def save_audio(audios: List[torch.Tensor], save_path: str, sample_rate: int = 48000, format: str = "wav"): + backend = "soundfile" if format != "ogg" else "sox" + output_audio_paths = [] + for i, audio in enumerate(audios): + output_audio_path = f"{save_path}_{i}.{format}" + torchaudio.save( + output_audio_path, audio, sample_rate=sample_rate, format=format, backend=backend + ) + output_audio_paths.append(output_audio_path) + return output_audio_paths \ No newline at end of file diff --git a/examples/ace_text_to_music.py b/examples/ace_text_to_music.py new file mode 100644 index 00000000..4e4f86b9 --- /dev/null +++ b/examples/ace_text_to_music.py @@ -0,0 +1,26 @@ +import random + +from diffsynth_engine.configs import ACEStepPipelineConfig +from diffsynth_engine.pipelines.ace_step import ACEStepMusicPipeline +from diffsynth_engine.utils.download import fetch_model +from diffsynth_engine.utils.audio import save_audio + + +if __name__ == "__main__": + config = ACEStepPipelineConfig( + model_path=fetch_model( + model_uri="ACE-Step/ACE-Step-v1-3.5B", + path="ace_step_transformer/diffusion_pytorch_model.safetensors", + ), + ) + seed = random.randint(0, 2**32 - 1) + + pipe = ACEStepMusicPipeline.from_pretrained(config) + audio = pipe.text2audio( + prompt="pop, rap, electronic, blues, hip-house, rhythm and blues", + 
lyrics="[verse]\n我走过深夜的街道\n冷风吹乱思念的漂亮外套\n你的微笑像星光很炫耀\n照亮了我孤独的每分每秒\n\n[chorus]\n愿你是风吹过我的脸\n带我飞过最远最遥远的山间\n愿你是风轻触我的梦\n停在心头不再飘散无迹无踪\n\n[verse]\n一起在喧哗避开世俗的骚动\n独自在天台探望月色的朦胧\n你说爱像音乐带点重节奏\n一拍一跳让我忘了心的温度多空洞\n\n[bridge]\n唱起对你的想念不隐藏\n像诗又像画写满藏不了的渴望\n你的影子挥不掉像风的倔强\n追着你飞扬穿越云海一样泛光\n\n[chorus]\n愿你是风吹过我的手\n暖暖的触碰像春日细雨温柔\n愿你是风盘绕我的身\n深情万万重不会有一天走远走\n\n[verse]\n深夜的钢琴弹起动人的旋律\n低音鼓砸进心底的每一次呼吸\n要是能将爱化作歌声传递\n你是否会听见我心里的真心实意", + audio_duration=170.63997916666668, + ) + save_audio(audio, f"tmp/ace_t2m_{seed}") + + del pipe From 01351b94cf76486b3f788e653f45522a12e9b8a8 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Fri, 12 Sep 2025 15:39:09 +0800 Subject: [PATCH 4/8] at least models can load now --- diffsynth_engine/conf/models/ace/dit.json | 12 ------ diffsynth_engine/conf/models/ace/t5.json | 8 ++-- .../conf/models/ace/vae/dcae.json | 1 - diffsynth_engine/models/ace_step/ace_dit.py | 41 ++++--------------- .../models/ace_step/ace_lyric_encoder.py | 1 + .../models/ace_step/ace_text_encoder.py | 21 +--------- diffsynth_engine/models/ace_step/vae/dcae.py | 2 + .../models/ace_step/vae/hifi_gan.py | 3 +- diffsynth_engine/pipelines/ace_step.py | 2 +- examples/ace_text_to_music.py | 5 ++- pyproject.toml | 5 +++ 11 files changed, 29 insertions(+), 72 deletions(-) diff --git a/diffsynth_engine/conf/models/ace/dit.json b/diffsynth_engine/conf/models/ace/dit.json index 32994b54..208531b9 100644 --- a/diffsynth_engine/conf/models/ace/dit.json +++ b/diffsynth_engine/conf/models/ace/dit.json @@ -1,12 +1,8 @@ { "head_dim": 128, - "in_channels": 8, - "inner_dim": 2560, "lyric_encoder_vocab_size": 6693, "lyric_hidden_size": 1024, - "max_height": 16, "max_position": 32768, - "max_width": 32768, "mlp_ratio": 2.5, "num_heads": 20, "num_layers": 24, @@ -17,13 +13,5 @@ ], "rope_theta": 1000000.0, "speaker_embedding_dim": 512, - "ssl_encoder_depths": [ - 8, - 8 - ], - "ssl_latent_dims": [ - 1024, - 768 - ], "text_embedding_dim": 768 } \ No newline at end of file diff --git a/diffsynth_engine/conf/models/ace/t5.json b/diffsynth_engine/conf/models/ace/t5.json index 3ae0b6b4..be589726 100644 --- a/diffsynth_engine/conf/models/ace/t5.json +++ b/diffsynth_engine/conf/models/ace/t5.json @@ -1,7 +1,7 @@ { - "dim_ffn": 2048, - "dim": 768, - "dim_attn": 768, + "d_ff": 2048, + "embed_dim": 768, "num_heads": 12, - "num_layers": 12 + "num_encoder_layers": 12, + "vocab_size": 256384 } \ No newline at end of file diff --git a/diffsynth_engine/conf/models/ace/vae/dcae.json b/diffsynth_engine/conf/models/ace/vae/dcae.json index a514a707..b7dac58c 100644 --- a/diffsynth_engine/conf/models/ace/vae/dcae.json +++ b/diffsynth_engine/conf/models/ace/vae/dcae.json @@ -61,6 +61,5 @@ ], "in_channels": 2, "latent_channels": 8, - "scaling_factor": 0.41407, "upsample_block_type": "interpolate" } \ No newline at end of file diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py index 0942e6e5..08ca2f5d 100644 --- a/diffsynth_engine/models/ace_step/ace_dit.py +++ b/diffsynth_engine/models/ace_step/ace_dit.py @@ -19,7 +19,7 @@ class Qwen2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cuda:0"): super().__init__() self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)) - self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device) + self._set_cos_sin_cache(seq_len=max_position_embeddings) def _set_cos_sin_cache(self, seq_len): self.max_seq_len_cached = seq_len @@ -302,7 +302,7 @@ def convert(self, 
state_dict): new_key = key.replace("linear_v", "v") state_dict[new_key] = state_dict.pop(key) elif "linear_p" in key: - new_key = key.replace("linear_p", "p") + new_key = key.replace("linear_pos", "p") state_dict[new_key] = state_dict.pop(key) elif "linear_out" in key: new_key = key.replace("linear_out", "o") @@ -317,11 +317,14 @@ def convert(self, state_dict): elif "to_v" in key: new_key = key.replace("to_v", "v") state_dict[new_key] = state_dict.pop(key) - elif "to_out" in key: - new_key = key.replace("to_out", "o") + elif "to_out.0" in key: + new_key = key.replace("to_out.0", "o") state_dict[new_key] = state_dict.pop(key) # remove all add_{q/k/v}_proj - elif "add_q_proj" in key or "add_k_proj" in key or "add_v_proj" in key: + elif "add_q_proj" in key or "add_k_proj" in key or "add_v_proj" in key or "to_add_out" in key: + state_dict.pop(key) + # remove all projectors. + elif "projectors" in key: state_dict.pop(key) # rename timestep_embedder.linear_1 into time_embedder.timestep_embedder.0 elif "timestep_embedder.linear_1" in key: @@ -349,12 +352,9 @@ def __init__( rope_theta: float = 1000000.0, speaker_embedding_dim: int = 512, text_embedding_dim: int = 768, - ssl_latent_dims: List[int] = [1024, 768], lyric_encoder_vocab_size: int = 6693, lyric_hidden_size: int = 1024, patch_size: List[int] = [16, 1], - max_height: int = 16, - max_width: int = 32768, attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, @@ -390,30 +390,7 @@ def __init__( self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, device=device, dtype=dtype) self.lyric_encoder = ConformerEncoder(input_size=lyric_hidden_size) self.lyric_proj = nn.Linear(lyric_hidden_size, inner_dim, device=device, dtype=dtype) - - projector_dim = 2 * inner_dim - self.projectors = nn.ModuleList( - [ - nn.Sequential( - nn.Linear(inner_dim, projector_dim, device=device, dtype=dtype), - nn.SiLU(), - nn.Linear(projector_dim, projector_dim, device=device, dtype=dtype), - nn.SiLU(), - nn.Linear(projector_dim, ssl_dim, device=device, dtype=dtype), - ) - for ssl_dim in ssl_latent_dims - ] - ) - - self.proj_in = PatchEmbed( - height=max_height, - width=max_width, - patch_size=patch_size, - embed_dim=inner_dim, - device=device, - dtype=dtype, - ) - + self.proj_in = PatchEmbed(patch_size=patch_size, embed_dim=inner_dim, device=device, dtype=dtype) self.final_layer = FinalLayer(inner_dim, patch_size=patch_size, out_channels=out_channels) def forward_lyric_encoder(self, context_lyric: torch.LongTensor, attn_mask_lyric: torch.LongTensor): diff --git a/diffsynth_engine/models/ace_step/ace_lyric_encoder.py b/diffsynth_engine/models/ace_step/ace_lyric_encoder.py index bb8b71e8..dcbf670b 100644 --- a/diffsynth_engine/models/ace_step/ace_lyric_encoder.py +++ b/diffsynth_engine/models/ace_step/ace_lyric_encoder.py @@ -24,6 +24,7 @@ class RelPositionMultiHeadedAttention(nn.Module): """ def __init__(self, n_head: int, n_feat: int): + super().__init__() self.n_head = n_head self.d_k = n_feat // n_head self.q = nn.Linear(n_feat, n_feat) diff --git a/diffsynth_engine/models/ace_step/ace_text_encoder.py b/diffsynth_engine/models/ace_step/ace_text_encoder.py index 3135479f..7b3a0614 100644 --- a/diffsynth_engine/models/ace_step/ace_text_encoder.py +++ b/diffsynth_engine/models/ace_step/ace_text_encoder.py @@ -1,27 +1,10 @@ -from typing import Dict, Any import json -import torch - -from diffsynth_engine.models.wan.wan_text_encoder import WanTextEncoder +from diffsynth_engine.models.text_encoder.t5 
import T5EncoderModel from diffsynth_engine.utils.constants import ACE_TEXT_ENCODER_CONFIG_FILE -class ACETextEncoder(WanTextEncoder): - @classmethod - def from_state_dict( - cls, - state_dict: Dict[str, torch.Tensor], - config: Dict[str, Any], - device: str = "cuda:0", - dtype: torch.dtype = torch.float32, - ) -> "ACETextEncoder": - model = cls(**config, device="meta", dtype=dtype) - model.requires_grad_(False) - model.load_state_dict(state_dict, assign=True) - model.to(device=device, dtype=dtype, non_blocking=True) - return model - +class ACETextEncoder(T5EncoderModel): @staticmethod def get_model_config() -> dict: config_file = ACE_TEXT_ENCODER_CONFIG_FILE diff --git a/diffsynth_engine/models/ace_step/vae/dcae.py b/diffsynth_engine/models/ace_step/vae/dcae.py index d56e9626..b519885a 100644 --- a/diffsynth_engine/models/ace_step/vae/dcae.py +++ b/diffsynth_engine/models/ace_step/vae/dcae.py @@ -595,6 +595,8 @@ def __init__( downsample_block_type: str = "Conv", decoder_norm_types: Union[str, Tuple[str]] = "rms_norm", decoder_act_fns: Union[str, Tuple[str]] = "silu", + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ) -> None: super().__init__() diff --git a/diffsynth_engine/models/ace_step/vae/hifi_gan.py b/diffsynth_engine/models/ace_step/vae/hifi_gan.py index 536b27e3..d6c22b02 100644 --- a/diffsynth_engine/models/ace_step/vae/hifi_gan.py +++ b/diffsynth_engine/models/ace_step/vae/hifi_gan.py @@ -576,6 +576,8 @@ def __init__( f_min: int = 40, f_max: int = 16000, n_mels: int = 128, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ): super().__init__() @@ -609,7 +611,6 @@ def __init__( f_max=f_max, n_mels=n_mels, ) - self.eval() @torch.no_grad() def decode(self, mel): diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py index 1f600ab7..e7af4978 100644 --- a/diffsynth_engine/pipelines/ace_step.py +++ b/diffsynth_engine/pipelines/ace_step.py @@ -444,7 +444,7 @@ def _from_state_dict(cls, state_dicts: ACEStateDicts, config: ACEStepPipelineCon init_device = "cpu" if config.offload_mode is not None else config.device tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=256, clean="whitespace") - text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, config=t5_config, device=init_device, dtype=config.t5_dtype) + text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype, **t5_config) dcae = DCAE.from_state_dict(state_dicts.dcae, config=dcae_config, device=init_device, dtype=config.vae_dtype) hifi_gan = ADaMoSHiFiGANV1.from_state_dict(state_dicts.vocoder, config=vocoder_config, device=init_device, dtype=config.vae_dtype) vae = MusicDCAE(dcae=dcae, vocoder=hifi_gan) diff --git a/examples/ace_text_to_music.py b/examples/ace_text_to_music.py index 4e4f86b9..f817485a 100644 --- a/examples/ace_text_to_music.py +++ b/examples/ace_text_to_music.py @@ -1,4 +1,4 @@ -import random +# import random from diffsynth_engine.configs import ACEStepPipelineConfig from diffsynth_engine.pipelines.ace_step import ACEStepMusicPipeline @@ -12,8 +12,9 @@ model_uri="ACE-Step/ACE-Step-v1-3.5B", path="ace_step_transformer/diffusion_pytorch_model.safetensors", ), + device="cuda:2", ) - seed = random.randint(0, 2**32 - 1) + seed = 3299954530 pipe = ACEStepMusicPipeline.from_pretrained(config) audio = pipe.text2audio( diff --git a/pyproject.toml b/pyproject.toml index 48eb38fd..8b4c1197 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,11 @@ dependencies = [ "librosa", "scikit-image", 
"trimesh" + "py3langid", + "pypinyin", + "hangul_romanize", + "num2words", + "spacy" ] [project.optional-dependencies] From 1896d1246e92cc0513af18671c40e3c25b2f4378 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Fri, 12 Sep 2025 17:19:28 +0800 Subject: [PATCH 5/8] fix a bunch of issue, but still a bunch of issue --- .../conf/models/ace/vae/dcae.json | 1 - .../conf/models/ace/vae/vocoder.json | 1 - diffsynth_engine/models/ace_step/ace_dit.py | 4 +- .../models/ace_step/ace_lyric_encoder.py | 2 +- .../models/ace_step/ace_text_encoder.py | 2 +- diffsynth_engine/models/ace_step/vae/dcae.py | 165 ++++++++++-------- .../models/ace_step/vae/hifi_gan.py | 128 +++++++------- .../models/ace_step/vae/music_dcae.py | 14 +- diffsynth_engine/pipelines/ace_step.py | 19 +- 9 files changed, 169 insertions(+), 167 deletions(-) diff --git a/diffsynth_engine/conf/models/ace/vae/dcae.json b/diffsynth_engine/conf/models/ace/vae/dcae.json index b7dac58c..3f2bb1c1 100644 --- a/diffsynth_engine/conf/models/ace/vae/dcae.json +++ b/diffsynth_engine/conf/models/ace/vae/dcae.json @@ -30,7 +30,6 @@ 5 ] ], - "downsample_block_type": "Conv", "encoder_block_out_channels": [ 128, 256, diff --git a/diffsynth_engine/conf/models/ace/vae/vocoder.json b/diffsynth_engine/conf/models/ace/vae/vocoder.json index 771e3446..ad768532 100644 --- a/diffsynth_engine/conf/models/ace/vae/vocoder.json +++ b/diffsynth_engine/conf/models/ace/vae/vocoder.json @@ -11,7 +11,6 @@ 384, 512 ], - "drop_path_rate": 0.0, "f_max": 16000, "f_min": 40, "hop_length": 512, diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py index 08ca2f5d..d9a91bcd 100644 --- a/diffsynth_engine/models/ace_step/ace_dit.py +++ b/diffsynth_engine/models/ace_step/ace_dit.py @@ -225,7 +225,7 @@ def __init__( def forward(self, x, context, t_mod, freqs, freqs_ctx, attn_mask, attn_mask_ctx): # msa: multi-head self-attention mlp: multi-layer perceptron shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [ - t.squeeze(1) for t in (self.scale_shift_table[None] + t_mod).chunk(6, dim=1) + t.squeeze(1) for t in (self.scale_shift_table[None] + rearrange(t_mod, "b (c d) -> b c d", c=6)).chunk(6, dim=1) ] input_x = modulate(self.norm1(x), shift_msa, scale_msa) x += gate_msa * self.attn(input_x, freqs, attn_mask) @@ -441,7 +441,7 @@ def decode( attn_mask: torch.Tensor, attn_mask_ctx: torch.Tensor, ): - t = self.time_embedder(timestep) + t = self.time_embedder(timestep, x.dtype) t_mod = self.t_block(t) x = self.proj_in(x) diff --git a/diffsynth_engine/models/ace_step/ace_lyric_encoder.py b/diffsynth_engine/models/ace_step/ace_lyric_encoder.py index dcbf670b..e1005ade 100644 --- a/diffsynth_engine/models/ace_step/ace_lyric_encoder.py +++ b/diffsynth_engine/models/ace_step/ace_lyric_encoder.py @@ -224,7 +224,7 @@ def __init__(self, idim: int, odim: int): nn.Linear(idim, odim), nn.LayerNorm(odim, eps=1e-5), ) - self.pos_enc = (EspnetRelPositionalEncoding(odim),) + self.pos_enc = EspnetRelPositionalEncoding(odim) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Input x. 
diff --git a/diffsynth_engine/models/ace_step/ace_text_encoder.py b/diffsynth_engine/models/ace_step/ace_text_encoder.py index 7b3a0614..55e438da 100644 --- a/diffsynth_engine/models/ace_step/ace_text_encoder.py +++ b/diffsynth_engine/models/ace_step/ace_text_encoder.py @@ -5,7 +5,7 @@ class ACETextEncoder(T5EncoderModel): - @staticmethod + @staticmethod # TODO: remove relative_position_embedding def get_model_config() -> dict: config_file = ACE_TEXT_ENCODER_CONFIG_FILE with open(config_file, "r") as f: diff --git a/diffsynth_engine/models/ace_step/vae/dcae.py b/diffsynth_engine/models/ace_step/vae/dcae.py index b519885a..44cdf7f2 100644 --- a/diffsynth_engine/models/ace_step/vae/dcae.py +++ b/diffsynth_engine/models/ace_step/vae/dcae.py @@ -7,35 +7,18 @@ from diffsynth_engine.models.base import PreTrainedModel from diffsynth_engine.utils.constants import ACE_VAE_DCAE_CONFIG_FILE - +# TODO: this is still shitty. need modification class RMSNorm(nn.RMSNorm): - def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False): - super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine) - if elementwise_affine: - self.bias = nn.Parameter(torch.empty(dim)) if bias else None + def __init__(self, dim, eps=1e-5, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16): + super().__init__(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) + self.bias = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) def forward(self, x): - x = super().forward(x) - if self.elementwise_affine: - if self.bias is not None: - x += self.bias, + x = super().forward(x) + self.bias return x -def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5): - if norm_type == "batch_norm": - return nn.BatchNorm2d(num_features) - elif norm_type == "group_norm": - return nn.GroupNorm(num_groups, num_features) - elif norm_type == "layer_norm": - return nn.LayerNorm(num_features) - elif norm_type == "rms_norm": - return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True) - else: - raise ValueError(f"Unknown normalization type: {norm_type}") - - def get_activation(activation_type): if activation_type == "relu": return nn.ReLU() @@ -56,14 +39,16 @@ def __init__( out_channels: int, norm_type: str = "batch_norm", act_fn: str = "relu6", + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ) -> None: super().__init__() self.norm_type = norm_type self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity() - self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) - self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False) - self.norm = get_normalization(norm_type, out_channels) + self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1, device=device, dtype=dtype) + self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False, device=device, dtype=dtype) + self.norm = RMSNorm(out_channels, device=device, dtype=dtype) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = hidden_states @@ -85,6 +70,8 @@ def __init__( in_channels: int, num_attention_heads: int, kernel_size: int, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ) -> None: super().__init__() @@ -96,8 +83,10 @@ def __init__( padding=kernel_size // 2, groups=channels, bias=False, + device=device, + dtype=dtype, ) - self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False) + self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False, 
device=device, dtype=dtype) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.proj_in(hidden_states) @@ -116,6 +105,8 @@ def __init__( kernel_sizes: tuple = (5,), eps: float = 1e-15, residual_connection: bool = False, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ): super().__init__() @@ -131,19 +122,19 @@ def __init__( ) inner_dim = num_attention_heads * attention_head_dim - self.to_q = nn.Linear(in_channels, inner_dim, bias=False) - self.to_k = nn.Linear(in_channels, inner_dim, bias=False) - self.to_v = nn.Linear(in_channels, inner_dim, bias=False) + self.to_q = nn.Linear(in_channels, inner_dim, bias=False, device=device, dtype=dtype) + self.to_k = nn.Linear(in_channels, inner_dim, bias=False, device=device, dtype=dtype) + self.to_v = nn.Linear(in_channels, inner_dim, bias=False, device=device, dtype=dtype) self.to_qkv_multiscale = nn.ModuleList() for kernel_size in kernel_sizes: self.to_qkv_multiscale.append( - SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size) + SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size, device=device, dtype=dtype) ) self.nonlinearity = nn.ReLU() - self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False) - self.norm_out = get_normalization(norm_type, out_channels) + self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False, device=device, dtype=dtype) + self.norm_out = RMSNorm(out_channels, device=device, dtype=dtype) def apply_linear_attention(self, query, key, value): value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) @@ -224,6 +215,8 @@ def __init__( attention_head_dim: int = 32, qkv_multiscales: tuple = (5,), norm_type: str = "batch_norm", + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ) -> None: super().__init__() @@ -235,12 +228,16 @@ def __init__( norm_type=norm_type, kernel_sizes=qkv_multiscales, residual_connection=True, + device=device, + dtype=dtype, ) self.conv_out = GLUMBConv( in_channels=in_channels, out_channels=in_channels, norm_type="rms_norm", + device=device, + dtype=dtype, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -257,6 +254,8 @@ def __init__( expand_ratio: float = 4, norm_type: str = None, residual_connection: bool = True, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ) -> None: super().__init__() @@ -265,13 +264,13 @@ def __init__( self.residual_connection = residual_connection self.nonlinearity = nn.SiLU() - self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0) - self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2) - self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False) + self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0, device=device, dtype=dtype) + self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2, device=device, dtype=dtype) + self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False, device=device, dtype=dtype) self.norm = None if norm_type == "rms_norm": - self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True) + self.norm = RMSNorm(out_channels) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.residual_connection: @@ -303,16 +302,20 @@ def get_block( attention_head_dim: int, norm_type: str, act_fn: str, - qkv_mutliscales: tuple = (), + qkv_multiscales: tuple = 
(), + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ): if block_type == "ResBlock": - block = ResBlock(in_channels, out_channels, norm_type, act_fn) + block = ResBlock(in_channels, out_channels, norm_type, act_fn, device=device, dtype=dtype) elif block_type == "EfficientViTBlock": block = EfficientViTBlock( in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, - qkv_multiscales=qkv_mutliscales + qkv_multiscales=qkv_multiscales, + device=device, + dtype=dtype, ) else: raise ValueError(f"Block with {block_type=} is not supported.") @@ -321,7 +324,7 @@ def get_block( class DCDownBlock2d(nn.Module): - def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None: + def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16) -> None: super().__init__() self.downsample = downsample @@ -341,6 +344,8 @@ def __init__(self, in_channels: int, out_channels: int, downsample: bool = False kernel_size=3, stride=self.stride, padding=1, + device=device, + dtype=dtype, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -367,6 +372,8 @@ def __init__( interpolate: bool = False, shortcut: bool = True, interpolation_mode: str = "nearest", + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ) -> None: super().__init__() @@ -380,7 +387,7 @@ def __init__( if not interpolate: out_channels = out_channels * out_ratio - self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1, device=device, dtype=dtype) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.interpolate: @@ -406,33 +413,32 @@ def __init__( in_channels: int, latent_channels: int, attention_head_dim: int = 32, - block_type: str | tuple = "ResBlock", - block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024), - layers_per_block: tuple = (2, 2, 2, 2, 2, 2), - qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)), - downsample_block_type: str = "pixel_unshuffle", - out_shortcut: bool = True, + block_type: Tuple[str] = ("ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"), + block_out_channels: Tuple[int] = (128, 256, 512, 1024), + layers_per_block: Tuple[int] = (2, 2, 3, 3), + qkv_multiscales: Tuple[Tuple[int, ...], ...] 
= ((), (), (5,), (5,)), + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ): super().__init__() - num_blocks = len(block_out_channels) - if isinstance(block_type, str): - block_type = (block_type,) * num_blocks - if layers_per_block[0] > 0: self.conv_in = nn.Conv2d( in_channels, - block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1], + block_out_channels[0], kernel_size=3, - stride=1, padding=1, + device=device, + dtype=dtype, ) else: self.conv_in = DCDownBlock2d( in_channels=in_channels, - out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1], - downsample=downsample_block_type == "pixel_unshuffle", + out_channels=block_out_channels[0], + downsample=False, shortcut=False, + device=device, + dtype=dtype, ) down_blocks = [] @@ -447,7 +453,9 @@ def __init__( attention_head_dim=attention_head_dim, norm_type="rms_norm", act_fn="silu", - qkv_mutliscales=qkv_multiscales[i], + qkv_multiscales=qkv_multiscales[i], + device=device, + dtype=dtype, ) down_block_list.append(block) @@ -455,32 +463,27 @@ def __init__( downsample_block = DCDownBlock2d( in_channels=out_channel, out_channels=block_out_channels[i + 1], - downsample=downsample_block_type == "pixel_unshuffle", + downsample=False, shortcut=True, + device=device, + dtype=dtype, ) down_block_list.append(downsample_block) down_blocks.append(nn.Sequential(*down_block_list)) self.down_blocks = nn.ModuleList(down_blocks) - - self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1) - - self.out_shortcut = out_shortcut - if out_shortcut: - self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1, device=device, dtype=dtype) + self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) for down_block in self.down_blocks: hidden_states = down_block(hidden_states) - if self.out_shortcut: - x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size)) - x = x.mean(dim=2) - hidden_states = self.conv_out(hidden_states) + x - else: - hidden_states = self.conv_out(hidden_states) + x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size)) + x = x.mean(dim=2) + hidden_states = self.conv_out(hidden_states) + x return hidden_states @@ -499,6 +502,8 @@ def __init__( act_fn: str | tuple = "silu", upsample_block_type: str = "pixel_shuffle", in_shortcut: bool = True, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ): super().__init__() @@ -527,6 +532,8 @@ def __init__( out_channel, interpolate=upsample_block_type == "interpolate", shortcut=True, + device=device, + dtype=dtype, ) up_block_list.append(upsample_block) @@ -538,7 +545,9 @@ def __init__( attention_head_dim=attention_head_dim, norm_type=norm_type[i], act_fn=act_fn[i], - qkv_mutliscales=qkv_multiscales[i], + qkv_multiscales=qkv_multiscales[i], + device=device, + dtype=dtype, ) up_block_list.append(block) @@ -548,15 +557,15 @@ def __init__( channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1] - self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True) + self.norm_out = RMSNorm(channels, device=device, dtype=dtype) self.conv_act = nn.ReLU() self.conv_out = None if layers_per_block[0] > 0: - self.conv_out = nn.Conv2d(channels, in_channels, 3, 1, 1) + self.conv_out = nn.Conv2d(channels, in_channels, 
3, 1, 1, device=device, dtype=dtype) else: self.conv_out = DCUpBlock2d( - channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False + channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False, device=device, dtype=dtype ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -583,8 +592,8 @@ def __init__( in_channels: int = 2, latent_channels: int = 8, attention_head_dim: int = 32, - encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"], - decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"], + encoder_block_types: Tuple[str] = ("ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"), + decoder_block_types: Tuple[str] = ("ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"), encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024), decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024), encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3), @@ -592,7 +601,6 @@ def __init__( encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)), decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)), upsample_block_type: str = "interpolate", - downsample_block_type: str = "Conv", decoder_norm_types: Union[str, Tuple[str]] = "rms_norm", decoder_act_fns: Union[str, Tuple[str]] = "silu", device: str = "cuda:0", @@ -608,7 +616,8 @@ def __init__( block_out_channels=encoder_block_out_channels, layers_per_block=encoder_layers_per_block, qkv_multiscales=encoder_qkv_multiscales, - downsample_block_type=downsample_block_type, + device=device, + dtype=dtype, ) self.decoder = Decoder( @@ -622,6 +631,8 @@ def __init__( norm_type=decoder_norm_types, act_fn=decoder_act_fns, upsample_block_type=upsample_block_type, + device=device, + dtype=dtype, ) diff --git a/diffsynth_engine/models/ace_step/vae/hifi_gan.py b/diffsynth_engine/models/ace_step/vae/hifi_gan.py index d6c22b02..5a6669d2 100644 --- a/diffsynth_engine/models/ace_step/vae/hifi_gan.py +++ b/diffsynth_engine/models/ace_step/vae/hifi_gan.py @@ -12,7 +12,7 @@ from diffsynth_engine.models.base import PreTrainedModel from diffsynth_engine.utils.constants import ACE_VAE_VOCODER_CONFIG_FILE - +# TODO: this is still shitty. 
need modification class LinearSpectrogram(nn.Module): def __init__( @@ -22,6 +22,8 @@ def __init__( hop_length=512, center=False, mode="pow2_sqrt", + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ): super().__init__() @@ -31,13 +33,13 @@ def __init__( self.center = center self.mode = mode - self.register_buffer("window", torch.hann_window(win_length)) + self.register_buffer("window", torch.hann_window(win_length, device=device, dtype=dtype)) def forward(self, y: torch.Tensor) -> torch.Tensor: if y.ndim == 3: y = y.squeeze(1) - y = torch.nn.functional.pad( + y = F.pad( y.unsqueeze(1), ( (self.win_length - self.hop_length) // 2, @@ -77,6 +79,8 @@ def __init__( center=False, f_min=0.0, f_max=None, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ): super().__init__() @@ -89,7 +93,7 @@ def __init__( self.f_min = f_min self.f_max = f_max or sample_rate // 2 - self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) + self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center, device=device, dtype=dtype) self.mel_scale = MelScale( self.n_mels, self.sample_rate, @@ -117,42 +121,6 @@ def forward(self, x: torch.Tensor, return_linear: bool = False) -> torch.Tensor: return x -def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for - changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use - 'survival rate' as the argument. - - """ # noqa: E501 - - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0 and scale_by_keep: - random_tensor.div_(keep_prob) - return x * random_tensor - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501 - - def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - self.scale_by_keep = scale_by_keep - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) - - def extra_repr(self): - return f"drop_prob={round(self.drop_prob, 3):0.3f}" - - class LayerNorm(nn.Module): r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with @@ -160,10 +128,10 @@ class LayerNorm(nn.Module): with shape (batch_size, channels, height, width). 
""" # noqa: E501 - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16): super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.weight = nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype)) + self.bias = nn.Parameter(torch.zeros(normalized_shape, device=device, dtype=dtype)) self.eps = eps self.data_format = data_format if self.data_format not in ["channels_last", "channels_first"]: @@ -199,11 +167,12 @@ class ConvNeXtBlock(nn.Module): def __init__( self, dim: int, - drop_path: float = 0.0, layer_scale_init_value: float = 1e-6, mlp_ratio: float = 4.0, kernel_size: int = 7, dilation: int = 1, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ): super().__init__() @@ -213,17 +182,18 @@ def __init__( kernel_size=kernel_size, padding=int(dilation * (kernel_size - 1) / 2), groups=dim, + device=device, + dtype=dtype, ) # depthwise conv - self.norm = LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim)) # pointwise/1x1 convs, implemented with linear layers + self.norm = LayerNorm(dim, eps=1e-6, device=device, dtype=dtype) + self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim), device=device, dtype=dtype) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() - self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) + self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim, device=device, dtype=dtype) self.gamma = ( - nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + nn.Parameter(layer_scale_init_value * torch.ones((dim), device=device, dtype=dtype)) if layer_scale_init_value > 0 else None ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x, apply_residual: bool = True): input = x @@ -239,7 +209,6 @@ def forward(self, x, apply_residual: bool = True): x = self.gamma * x x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) - x = self.drop_path(x) if apply_residual: x = input + x @@ -267,9 +236,10 @@ def __init__( input_channels=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], - drop_path_rate=0.0, layer_scale_init_value=1e-6, kernel_sizes: Tuple[int] = (7,), + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ): super().__init__() assert len(depths) == len(dims) @@ -282,26 +252,27 @@ def __init__( kernel_size=7, padding=3, padding_mode="replicate", + device=device, + dtype=dtype, ), - LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first", device=device, dtype=dtype), ) self.channel_layers.append(stem) for i in range(len(depths) - 1): mid_layer = nn.Sequential( - LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), - nn.Conv1d(dims[i], dims[i + 1], kernel_size=1), + LayerNorm(dims[i], eps=1e-6, data_format="channels_first", device=device, dtype=dtype), + nn.Conv1d(dims[i], dims[i + 1], kernel_size=1, device=device, dtype=dtype), ) self.channel_layers.append(mid_layer) block_fn = ( - partial(ConvNeXtBlock, kernel_size=kernel_sizes[0]) + partial(ConvNeXtBlock, kernel_size=kernel_sizes[0], device=device, dtype=dtype) if len(kernel_sizes) == 1 - else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes) + else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes, device=device, dtype=dtype) ) self.stages = nn.ModuleList() - 
drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] cur = 0 for i in range(len(depths)): @@ -309,16 +280,15 @@ def __init__( *[ block_fn( dim=dims[i], - drop_path=drop_path_rates[cur + j], layer_scale_init_value=layer_scale_init_value, ) - for j in range(depths[i]) + for _ in range(depths[i]) ] ) self.stages.append(stage) cur += depths[i] - self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") + self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first", device=device, dtype=dtype) self.apply(self._init_weights) def _init_weights(self, m): @@ -348,7 +318,7 @@ def get_padding(kernel_size, dilation=1): class ResBlock1(torch.nn.Module): - def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16): super().__init__() self.convs1 = nn.ModuleList( @@ -361,6 +331,8 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]), + device=device, + dtype=dtype, ) ), weight_norm( @@ -371,6 +343,8 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]), + device=device, + dtype=dtype, ) ), weight_norm( @@ -381,6 +355,8 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]), + device=device, + dtype=dtype, ) ), ] @@ -397,6 +373,8 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 1, dilation=1, padding=get_padding(kernel_size, 1), + device=device, + dtype=dtype, ) ), weight_norm( @@ -407,6 +385,8 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 1, dilation=1, padding=get_padding(kernel_size, 1), + device=device, + dtype=dtype, ) ), weight_norm( @@ -417,6 +397,8 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 1, dilation=1, padding=get_padding(kernel_size, 1), + device=device, + dtype=dtype, ) ), ] @@ -448,6 +430,8 @@ def __init__( pre_conv_kernel_size: int = 7, post_conv_kernel_size: int = 7, post_activation: Callable = partial(nn.SiLU, inplace=True), + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, ): super().__init__() @@ -460,6 +444,8 @@ def __init__( pre_conv_kernel_size, 1, padding=get_padding(pre_conv_kernel_size), + device=device, + dtype=dtype, ) ) @@ -480,6 +466,8 @@ def __init__( k, u, padding=(k - u) // 2, + device=device, + dtype=dtype, ) ) ) @@ -496,16 +484,18 @@ def __init__( kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2, + device=device, + dtype=dtype, ) ) else: - self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=1)) + self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=1, device=device, dtype=dtype)) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes): - self.resblocks.append(ResBlock1(ch, k, d)) + self.resblocks.append(ResBlock1(ch, k, d, device=device, dtype=dtype)) self.activation_post = post_activation() self.conv_post = weight_norm( @@ -515,6 +505,8 @@ def __init__( post_conv_kernel_size, 1, padding=get_padding(post_conv_kernel_size), + device=device, + dtype=dtype, ) ) self.ups.apply(init_weights) @@ -553,7 +545,6 @@ def __init__( input_channels: int = 128, depths: List[int] = [3, 3, 9, 3], dims: List[int] = [128, 256, 384, 
512], - drop_path_rate: float = 0.0, kernel_sizes: Tuple[int] = (7,), upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2), upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4), @@ -581,12 +572,14 @@ def __init__( ): super().__init__() + self.sampling_rate = sampling_rate self.backbone = ConvNeXtEncoder( input_channels=input_channels, depths=depths, dims=dims, - drop_path_rate=drop_path_rate, kernel_sizes=kernel_sizes, + device=device, + dtype=dtype, ) self.head = HiFiGANGenerator( @@ -600,8 +593,9 @@ def __init__( use_template=use_template, pre_conv_kernel_size=pre_conv_kernel_size, post_conv_kernel_size=post_conv_kernel_size, + device=device, + dtype=dtype, ) - self.sampling_rate = sampling_rate self.mel_transform = LogMelSpectrogram( sample_rate=sampling_rate, n_fft=n_fft, @@ -610,6 +604,8 @@ def __init__( f_min=f_min, f_max=f_max, n_mels=n_mels, + device=device, + dtype=dtype, ) @torch.no_grad() diff --git a/diffsynth_engine/models/ace_step/vae/music_dcae.py b/diffsynth_engine/models/ace_step/vae/music_dcae.py index b6c9fda3..c7157252 100644 --- a/diffsynth_engine/models/ace_step/vae/music_dcae.py +++ b/diffsynth_engine/models/ace_step/vae/music_dcae.py @@ -5,7 +5,7 @@ from diffsynth_engine.models.base import PreTrainedModel from .hifi_gan import ADaMoSHiFiGANV1 -from .dcae import DCAE # TODO: rewrite above 2 files +from .dcae import DCAE class MusicDCAE(PreTrainedModel): @@ -27,16 +27,10 @@ def __init__( self.scale_factor = 0.1786 self.shift_factor = -1.9091 - def load_audio(self, audio_path): - audio, sr = torchaudio.load(audio_path) - if audio.shape[0] == 1: - audio = audio.repeat(2, 1) - return audio, sr - def forward_mel(self, audios): mels = [] - for i in range(len(audios)): - image = self.vocoder.mel_transform(audios[i]) + for audio in audios: + image = self.vocoder.mel_transform(audio) mels.append(image) mels = torch.stack(mels) return mels @@ -60,7 +54,7 @@ def encode(self, audios: torch.Tensor, sr: int): mels = self.transform(mels) latents = [] for mel in mels: - latent = self.dcae.encoder(mel.unsqueeze(0)) + latent = self.dcae.encoder(mel[None]) latents.append(latent) latents = torch.cat(latents, dim=0) latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long() diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py index e7af4978..91730405 100644 --- a/diffsynth_engine/pipelines/ace_step.py +++ b/diffsynth_engine/pipelines/ace_step.py @@ -155,20 +155,23 @@ def unload_loras(self): def encode_prompt(self, prompt): self.load_models_to_device(["text_encoder"]) ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True) - ids = ids.to(self.device) - mask = mask.to(self.device) + ids = ids.to(self.device)[:, :17] + mask = mask.to(self.device)[:, :17] prompt_emb = self.text_encoder(ids, mask) return prompt_emb, mask def encode_prompt_null(self, prompt): self.load_models_to_device(["text_encoder"]) ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True) - ids = ids.to(self.device) - mask = mask.to(self.device) + ids = ids.to(self.device)[:, :17] + mask = mask.to(self.device)[:, :17] # for test purpose. I don't know why this tokenizer will enforce padding. 
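# The fwd_with_temperature call below relies on the hook helper defined near the top of
# this file: register a forward hook on each layer in [layer_start_idx, layer_end_idx),
# run the wrapped forward once under no_grad, then remove every hook. A minimal,
# self-contained sketch of that pattern; the toy model, hooked layer, and 0.5 scale
# are illustrative stand-ins, not the pipeline's actual temperature logic:
import torch
import torch.nn as nn

def fwd_with_scaled_layers(model, layers, x, scale=0.5):
    handlers = [m.register_forward_hook(lambda mod, inp, out: out * scale) for m in layers]
    try:
        with torch.no_grad():
            return model(x)  # the hooks rescale the selected layers' outputs in flight
    finally:
        for h in handlers:
            h.remove()  # always detach hooks, even if the forward raises

mlp = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8))
out = fwd_with_scaled_layers(mlp, [mlp[1]], torch.randn(1, 8))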
prompt_emb = fwd_with_temperature( - inputs=(ids, mask), + inputs={ + "input_ids": ids, + "attention_mask": mask + }, model_fwd_func=self.text_encoder, - get_hooked_layer_func=lambda i: self.text_encoder.blocks[i].attn.q, + get_hooked_layer_func=lambda i: self.text_encoder.encoders[i].attn.to_q, layer_start_idx=4, layer_end_idx=6, ) @@ -310,7 +313,7 @@ def text2audio( attn_mask = torch.ones(1, num_frames, device=self.device, dtype=self.dtype) _, latents, sigmas, timesteps = self.prepare_latents( latents=noise, - input_video=None, + input_image=None, denoising_strength=None, num_inference_steps=num_inference_steps, ) @@ -338,7 +341,7 @@ def text2audio( self.load_models_to_device(["dit"]) hide_progress = dist.is_initialized() and dist.get_rank() != 0 for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)): - timestep = timestep.to(dtype=self.dtype, device=self.device) + timestep = timestep[None].to(dtype=self.dtype, device=self.device) # Classifier-free guidance if cfg_start_step <= i < cfg_end_step: noise_pred = self.predict_noise_with_cfg( From ae38b57c1a99d14669f194aa7bd1595f06134613 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Fri, 12 Sep 2025 17:47:43 +0800 Subject: [PATCH 6/8] runnable, but output is shit --- diffsynth_engine/models/ace_step/ace_dit.py | 13 +++++-------- diffsynth_engine/pipelines/ace_step.py | 6 +++--- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py index d9a91bcd..09a40926 100644 --- a/diffsynth_engine/models/ace_step/ace_dit.py +++ b/diffsynth_engine/models/ace_step/ace_dit.py @@ -17,7 +17,8 @@ class Qwen2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cuda:0"): - super().__init__() + super().__init__() # TODO: how to deal with meta device issue? 
+        device = "cuda:2"
         self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim))
         self._set_cos_sin_cache(seq_len=max_position_embeddings)

@@ -25,19 +26,14 @@ def _set_cos_sin_cache(self, seq_len):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.int64).float()
         freqs = torch.outer(t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()
-        self.sin_cached = emb.sin()
+        self.freqs_cis_cached = torch.polar(torch.ones_like(freqs), freqs)

     def forward(self, x: torch.Tensor):
         seq_len = x.shape[1]
         if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len)

-        return (
-            self.cos_cached[:seq_len].to(x.dtype),
-            self.sin_cached[:seq_len].to(x.dtype),
-        )
+        return self.freqs_cis_cached[:seq_len][None, :, None, :].to(x.device)


 class SelfAttention(nn.Module):
@@ -365,6 +361,7 @@ def __init__(
             dim=head_dim,
             max_position_embeddings=max_position,
             base=rope_theta,
+            device=device,
         )

         inner_dim = num_heads * head_dim
diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py
index 91730405..4a8de838 100644
--- a/diffsynth_engine/pipelines/ace_step.py
+++ b/diffsynth_engine/pipelines/ace_step.py
@@ -228,7 +228,7 @@ def predict_noise_with_cfg(
     ):
         if cfg_scale <= 1.0:
             return self.predict_noise(
-                model=model,
+                model=model.decode,
                 latents=latents,
                 timestep=timestep,
                 context=context,
@@ -237,7 +237,7 @@
         )
         # cfg by predict noise one by one
         positive_noise_pred = self.predict_noise(
-            model=model,
+            model=model.decode,
             latents=latents,
             timestep=timestep,
             context=context,
@@ -246,7 +246,7 @@
         )
         negative_noise_pred = fwd_with_temperature(
             inputs={
-                "model": model,
+                "model": model.decode,
                 "latents": latents,
                 "timestep": timestep,
                 "context": context_null,

From 90e7e0a703ff2d35be936d9a76b824c532210497 Mon Sep 17 00:00:00 2001
From: continue revolution
Date: Sun, 14 Sep 2025 04:52:30 -0500
Subject: [PATCH 7/8] Fix majority of problems (#166)

* ace step can generate normal stuff, but the code is shitty and needs massive change
* revert t5, rand generator, etc
* revert more
* rope problem partially resolved
* I think there should be no problem
---
 diffsynth_engine/models/ace_step/ace_dit.py | 39 +++++++++++++--------
 diffsynth_engine/pipelines/ace_step.py      | 15 +++++---
 diffsynth_engine/pipelines/base.py          |  2 +-
 examples/ace_text_to_music.py               | 18 +++++++---
 pyproject.toml                              |  2 +-
 5 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py
index 09a40926..c319e4e7 100644
--- a/diffsynth_engine/models/ace_step/ace_dit.py
+++ b/diffsynth_engine/models/ace_step/ace_dit.py
@@ -10,30 +10,42 @@
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
 from diffsynth_engine.models.basic.attention import attention
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
-from diffsynth_engine.models.wan.wan_dit import rope_apply, modulate
+from diffsynth_engine.models.wan.wan_dit import modulate
 from diffsynth_engine.models.ace_step.ace_lyric_encoder import ConformerEncoder
 from diffsynth_engine.utils.constants import ACE_DIT_CONFIG_FILE


-class Qwen2RotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cuda:0"):
-        super().__init__()  # TODO: how to deal with meta device issue?
- device = "cuda:2" - self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)) +def rope_apply(x, freqs): # TODO: edit this into complex calculation + cos, sin = freqs # [1, S, 1, D] + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + x_out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return x_out.to(x.dtype).flatten(3) + + +class Qwen2RotaryEmbedding: + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) self._set_cos_sin_cache(seq_len=max_position_embeddings) def _set_cos_sin_cache(self, seq_len): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.int64).float() + t = torch.arange(self.max_seq_len_cached, dtype=torch.float32) freqs = torch.outer(t, self.inv_freq) - self.freqs_cis_cached = torch.polar(torch.ones_like(freqs), freqs) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() - def forward(self, x: torch.Tensor): + def __call__(self, x: torch.Tensor): seq_len = x.shape[1] if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len) - return self.freqs_cis_cached[:seq_len][None, :, None, :].to(x.device) + return ( + self.cos_cached[None, :seq_len, None, :].to(x), + self.sin_cached[None, :seq_len, None, :].to(x), + ) class SelfAttention(nn.Module): @@ -175,7 +187,7 @@ def __init__( kernel_size=3, groups=hidden_features * 2, use_bias=True, - act="silu", + act=None, device=device, dtype=dtype, ) @@ -194,7 +206,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.inverted_conv(x) x = self.depth_conv(x) x, gate = torch.chunk(x, 2, dim=1) - x *= self.glu_act(gate) + x = x * self.glu_act(gate) x = self.point_conv(x) x = x.transpose(1, 2) return x @@ -285,7 +297,7 @@ def forward(self, x, t): class ACEStepDiTStateDictConverter(StateDictConverter): - def convert(self, state_dict): + def convert(self, state_dict): # TODO: can this be more elegant? 
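# One way the "can this be more elegant?" TODO above could go: drive the renames from a
# table instead of branching per substring. A sketch under the same conventions as the
# loop below; the pairs shown are illustrative, not the converter's full mapping:
RENAMES = (("linear_q", "q"), ("linear_k", "k"), ("linear_v", "v"), ("linear_p", "p"))

def rename_keys(state_dict: dict) -> dict:
    converted = {}
    for key, value in state_dict.items():
        for old, new in RENAMES:
            key = key.replace(old, new)  # no-op when the substring is absent
        converted[key] = value
    return converted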
for key in list(state_dict.keys()): # change all linear_q / linear_k / linear_v / linear_p to q / k / v / p if "linear_q" in key: @@ -361,7 +373,6 @@ def __init__( dim=head_dim, max_position_embeddings=max_position, base=rope_theta, - device=device, ) inner_dim = num_heads * head_dim diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py index 4a8de838..9e94f708 100644 --- a/diffsynth_engine/pipelines/ace_step.py +++ b/diffsynth_engine/pipelines/ace_step.py @@ -1,5 +1,6 @@ from typing import Tuple +import math import torch import torch.nn.functional as F import torch.distributed as dist @@ -51,11 +52,11 @@ def hook(module, input, output): handlers.append(handler) with torch.no_grad(): - prompt_emb = model_fwd_func(**inputs) + output = model_fwd_func(**inputs) for handler in handlers: handler.remove() - return prompt_emb + return output class MomentumBuffer: @@ -124,7 +125,7 @@ def __init__( ) self.config = config # sampler - self.noise_scheduler = RecifitedFlowScheduler(shift=3.0) + self.noise_scheduler = RecifitedFlowScheduler(shift=config.shift) self.sampler = FlowMatchEulerSampler() # models self.lyric_tokenizer = VoiceBpeTokenizer() @@ -300,6 +301,10 @@ def text2audio( guidance_interval: float = 0.5, progress_callback: Optional[Callable[[int, int, str], None]] = None, ): + def logistic(x, L=0.9, U=1.1, x_0=0.0, k=0.1): + return L + (U - L) / (1 + math.exp(-k * (x - x_0))) + omega = logistic(omega_scale) + prompt_emb, prompt_attn_mask = self.encode_prompt(prompt) prompt_emb_null, prompt_attn_mask_null = self.encode_prompt_null(prompt) if len(lyrics.strip()) > 0: @@ -319,7 +324,7 @@ def text2audio( ) # Initialize sampler self.sampler.initialize(sigmas=sigmas) - # guidance interval + # Guidance interval cfg_start_step = int(num_inference_steps * ((1 - guidance_interval) / 2)) cfg_end_step = int(num_inference_steps * (guidance_interval / 2 + 0.5)) momentum_buffer = MomentumBuffer() @@ -369,7 +374,7 @@ def text2audio( dx: torch.Tensor = noise_pred * (self.sampler.sigmas[i + 1] - self.sampler.sigmas[i]) dx_mean = dx.mean(dim=(1, 2, 3), keepdim=True) latents = latents.to(dtype=torch.float32) - latents += (dx - dx_mean) * omega_scale + dx + latents += (dx - dx_mean) * omega + dx_mean latents = latents.to(dtype=noise_pred.dtype) if progress_callback is not None: progress_callback(i + 1, len(timesteps), "DENOISING") diff --git a/diffsynth_engine/pipelines/base.py b/diffsynth_engine/pipelines/base.py index 37fbccbf..5c2e2ae3 100644 --- a/diffsynth_engine/pipelines/base.py +++ b/diffsynth_engine/pipelines/base.py @@ -210,7 +210,7 @@ def prepare_latents( sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps) # k-diffusion # if you have any questions about this, please ask @dizhipeng.dzp for more details - latents = latents * sigmas[0] / ((sigmas[0] ** 2 + 1) ** 0.5) + # latents = latents * sigmas[0] / ((sigmas[0] ** 2 + 1) ** 0.5) init_latents = latents.clone() sigmas, timesteps = ( sigmas.to(device=self.device, dtype=self.dtype), diff --git a/examples/ace_text_to_music.py b/examples/ace_text_to_music.py index f817485a..a6eac329 100644 --- a/examples/ace_text_to_music.py +++ b/examples/ace_text_to_music.py @@ -1,4 +1,5 @@ -# import random +import random +import argparse from diffsynth_engine.configs import ACEStepPipelineConfig from diffsynth_engine.pipelines.ace_step import ACEStepMusicPipeline @@ -7,21 +8,30 @@ if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--seed", + type=int, + 
default=random.randint(0, 2**32 - 1), + ) + parser.add_argument("--device", type=str, default="cuda") + args = parser.parse_args() + config = ACEStepPipelineConfig( model_path=fetch_model( model_uri="ACE-Step/ACE-Step-v1-3.5B", path="ace_step_transformer/diffusion_pytorch_model.safetensors", ), - device="cuda:2", + device=args.device, ) - seed = 3299954530 pipe = ACEStepMusicPipeline.from_pretrained(config) audio = pipe.text2audio( prompt="pop, rap, electronic, blues, hip-house, rhythm and blues", lyrics="[verse]\n我走过深夜的街道\n冷风吹乱思念的漂亮外套\n你的微笑像星光很炫耀\n照亮了我孤独的每分每秒\n\n[chorus]\n愿你是风吹过我的脸\n带我飞过最远最遥远的山间\n愿你是风轻触我的梦\n停在心头不再飘散无迹无踪\n\n[verse]\n一起在喧哗避开世俗的骚动\n独自在天台探望月色的朦胧\n你说爱像音乐带点重节奏\n一拍一跳让我忘了心的温度多空洞\n\n[bridge]\n唱起对你的想念不隐藏\n像诗又像画写满藏不了的渴望\n你的影子挥不掉像风的倔强\n追着你飞扬穿越云海一样泛光\n\n[chorus]\n愿你是风吹过我的手\n暖暖的触碰像春日细雨温柔\n愿你是风盘绕我的身\n深情万万重不会有一天走远走\n\n[verse]\n深夜的钢琴弹起动人的旋律\n低音鼓砸进心底的每一次呼吸\n要是能将爱化作歌声传递\n你是否会听见我心里的真心实意", audio_duration=170.63997916666668, + seed=args.seed, ) - save_audio(audio, f"tmp/ace_t2m_{seed}") + save_audio(audio, f"tmp/ace_t2m_{args.seed}") del pipe diff --git a/pyproject.toml b/pyproject.toml index 8b4c1197..2c5dde92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "moviepy", "librosa", "scikit-image", - "trimesh" + "trimesh", "py3langid", "pypinyin", "hangul_romanize", From bff89eb8b8db5e850653275134f733561b6a84c0 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Mon, 15 Sep 2025 11:00:59 +0800 Subject: [PATCH 8/8] indentation --- diffsynth_engine/pipelines/ace_step.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py index 9e94f708..b0f030ad 100644 --- a/diffsynth_engine/pipelines/ace_step.py +++ b/diffsynth_engine/pipelines/ace_step.py @@ -302,7 +302,7 @@ def text2audio( progress_callback: Optional[Callable[[int, int, str], None]] = None, ): def logistic(x, L=0.9, U=1.1, x_0=0.0, k=0.1): - return L + (U - L) / (1 + math.exp(-k * (x - x_0))) + return L + (U - L) / (1 + math.exp(-k * (x - x_0))) omega = logistic(omega_scale) prompt_emb, prompt_attn_mask = self.encode_prompt(prompt)
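Taken together with the omega fix in PATCH 7, the function this last hunk re-indents is small enough to check by hand: logistic squashes omega_scale into the open interval (0.9, 1.1), with logistic(0) equal to exactly 1.0, and the sampling loop applies omega only to the zero-mean component of the Euler update so the per-sample mean drift passes through unscaled. A minimal numeric sketch of that update, with shapes and sigma values chosen arbitrarily:

import math
import torch

def logistic(x, L=0.9, U=1.1, x_0=0.0, k=0.1):
    return L + (U - L) / (1 + math.exp(-k * (x - x_0)))

assert abs(logistic(0.0) - 1.0) < 1e-9         # omega_scale = 0 recovers plain Euler

noise_pred = torch.randn(1, 8, 16, 64)         # illustrative latent shape
latents = torch.randn(1, 8, 16, 64)
sigma_cur, sigma_next = 0.8, 0.6               # illustrative adjacent sigmas
omega = logistic(10.0)                         # omega_scale = 10 -> omega ~= 1.05

dx = noise_pred * (sigma_next - sigma_cur)     # plain Euler step in sigma space
dx_mean = dx.mean(dim=(1, 2, 3), keepdim=True)
latents = latents + (dx - dx_mean) * omega + dx_mean  # mean-preserving rescale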