diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 308321717c..03fe5c0418 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -605,6 +605,15 @@ from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
     RoformerV2Tokenizer as RoformerV2Tokenizer,
 )
+from keras_hub.src.models.rwkv7.rwkv7_backbone import (
+    RWKV7Backbone as RWKV7Backbone,
+)
+from keras_hub.src.models.rwkv7.rwkv7_causal_lm import (
+    RWKV7CausalLM as RWKV7CausalLM,
+)
+from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import (
+    RWKV7CausalLMPreprocessor as RWKV7CausalLMPreprocessor,
+)
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone
 from keras_hub.src.models.sam.sam_image_segmenter import (
     SAMImageSegmenter as SAMImageSegmenter,
diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py
index b155d0e6e1..264bc8bdd4 100644
--- a/keras_hub/api/tokenizers/__init__.py
+++ b/keras_hub/api/tokenizers/__init__.py
@@ -90,6 +90,9 @@ from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
     RoformerV2Tokenizer as RoformerV2Tokenizer,
 )
+from keras_hub.src.models.rwkv7.rwkv7_tokenizer import (
+    RWKVTokenizer as RWKVTokenizer,
+)
 from keras_hub.src.models.siglip.siglip_tokenizer import (
     SigLIPTokenizer as SigLIPTokenizer,
 )
diff --git a/keras_hub/src/models/rwkv7/rwkv7_backbone.py b/keras_hub/src/models/rwkv7/rwkv7_backbone.py
new file mode 100644
index 0000000000..de460d95e9
--- /dev/null
+++ b/keras_hub/src/models/rwkv7/rwkv7_backbone.py
@@ -0,0 +1,185 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.rwkv7.rwkv7_layer import RWKV7_Block
+
+
+def rwkv7_kernel_initializer(stddev=0.02):
+    return keras.initializers.TruncatedNormal(stddev=stddev)
+
+
+@keras_hub_export("keras_hub.models.RWKV7Backbone")
+class RWKV7Backbone(Backbone):
+    """The [RWKV-7](https://arxiv.org/abs/2503.14456) core architecture.
+
+    This network implements a modern RNN architecture based on a linear
+    attention mechanism with recurrent processing, as described in the
+    RWKV papers. It includes the embedding lookups and RWKV-7 blocks.
+
+    The default constructor gives a fully customizable, randomly initialized
+    RWKV-7 model with any number of layers, heads, and embedding dimensions.
+    To load preset architectures and weights, use the `from_preset`
+    constructor.
+
+    Args:
+        hidden_size: int. The size of the token embeddings and of the hidden
+            representation in each RWKV-7 block.
+        head_size: int. The size of each head in the time-mix layers.
+        num_layers: int. The number of RWKV-7 blocks.
+        vocabulary_size: int. The size of the token vocabulary.
+        intermediate_dim: int. The output dimension of the first Dense layer
+            in the two-layer channel-mix (feedforward) network of each block.
+        gate_lora: int. LoRA dimension for gating.
+        mv_lora: int. LoRA dimension for value mixing.
+        aaa_lora: int. LoRA dimension for alpha parameters.
+        decay_lora: int. LoRA dimension for decay parameters.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+        dropout_rate: float. Dropout rate applied to the output of the final
+            RWKV-7 block. Defaults to 0.
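+
+    Token id 0 is treated as padding: positions equal to 0 in the input
+    `token_ids` are masked out of the recurrence inside the backbone.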
+ + Examples: + + ```python + input_data = np.ones(shape=(1, 12), dtype="int32") + + + # Randomly initialized RWKV-7 decoder with custom config. + model = keras_hub.models.RWKV7Backbone( + vocabulary_size=10, + hidden_size=512, + num_layers=2, + head_size=64, + intermediate_dim=1024, + dtype="float32" + ) + model(input_data) + ``` + """ + + def __init__( + self, + hidden_size, + head_size, + num_layers, + vocabulary_size, + intermediate_dim, + gate_lora=128, + mv_lora=32, + aaa_lora=64, + decay_lora=64, + dtype=None, + dropout_rate=0, + **kwargs, + ): + """Initialize RWKV7 backbone. + + Args: + hidden_size: Hidden dimension size. + head_size: Attention head size. + num_layers: Number of RWKV blocks. + vocabulary_size: Size of vocabulary. + intermediate_dim: Intermediate dimension for FFN. + gate_lora: LoRA dimension for gating. + mv_lora: LoRA dimension for value mixing. + aaa_lora: LoRA dimension for alpha parameters. + decay_lora: LoRA dimension for decay parameters. + dtype: Data type for the layer. + dropout_rate: Dropout rate for regularization. + **kwargs: Additional arguments. + """ + # === Layers === + self.token_embedding = keras.layers.Embedding( + input_dim=vocabulary_size, + output_dim=hidden_size, + embeddings_initializer=rwkv7_kernel_initializer(), + dtype=dtype, + name="token_embedding", + ) + self.token_embedding.build([None, None]) + + self.output_layer_norm = keras.layers.LayerNormalization( + epsilon=1e-5, name="output_norm" + ) + self.output_layer_norm.build([None, None, hidden_size]) + self.dropout = keras.layers.Dropout( + dropout_rate, + dtype=dtype, + name="dropout", + ) + self.rwkv_layers = [] + for i in range(num_layers): + layer = RWKV7_Block( + hidden_size, + head_size, + intermediate_dim, + gate_lora, + mv_lora, + aaa_lora, + decay_lora, + use_initial_norm=i == 0, + kernel_initializer=rwkv7_kernel_initializer(), + dtype=dtype, + name=f"rwkv_layer_{i}", + ) + + self.rwkv_layers.append(layer) + self.head = keras.layers.Dense( + units=vocabulary_size, + kernel_initializer=rwkv7_kernel_initializer(), + use_bias=False, + name="head", + ) + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + + padding_mask = ops.not_equal(token_id_input, 0) + + x = self.token_embedding(token_id_input) + padding_mask = ops.cast(padding_mask, dtype=x.dtype) + v_first = None + for rwkv_layer in self.rwkv_layers: + x, v_first = rwkv_layer(x, v_first, padding_mask) + x = self.dropout(x) + sequence_output = self.output_layer_norm(x) + sequence_output = self.head(sequence_output) + super().__init__( + inputs=token_id_input, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + # Initialize the graph to avoid potential errors in some cases + self.call(ops.ones([1, 16], "int32")) + + self.num_layers = num_layers + self.head_size = head_size + self.hidden_size = hidden_size + self.gate_lora = gate_lora + self.mv_lora = mv_lora + self.aaa_lora = aaa_lora + self.decay_lora = decay_lora + self.vocabulary_size = vocabulary_size + self.dropout_rate = dropout_rate + self.intermediate_dim = intermediate_dim + + def get_config(self): + config = { + "hidden_size": self.hidden_size, + "head_size": self.head_size, + "gate_lora": self.gate_lora, + "mv_lora": self.mv_lora, + "aaa_lora": self.aaa_lora, + "decay_lora": self.decay_lora, + "vocabulary_size": self.vocabulary_size, + "dropout_rate": self.dropout_rate, + "intermediate_dim": self.intermediate_dim, + "num_layers": self.num_layers, + } + base_config = super().get_config() + return 
dict(list(base_config.items()) + list(config.items())) diff --git a/keras_hub/src/models/rwkv7/rwkv7_backbone_test.py b/keras_hub/src/models/rwkv7/rwkv7_backbone_test.py new file mode 100644 index 0000000000..e061c0e3e6 --- /dev/null +++ b/keras_hub/src/models/rwkv7/rwkv7_backbone_test.py @@ -0,0 +1,37 @@ +from keras import ops + +from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone +from keras_hub.src.tests.test_case import TestCase + + +class RWKV7BackboneTest(TestCase): + def setUp(self): + """ + Set up the test case with default arguments and input data. + """ + self.init_kwargs = { + "vocabulary_size": 10, + "hidden_size": 16, + "num_layers": 2, + "head_size": 4, + "intermediate_dim": 32, + "gate_lora": 32, + "mv_lora": 16, + "aaa_lora": 16, + "decay_lora": 16, + } + self.input_data = ops.ones((2, 5), dtype="int32") + self.backbone = RWKV7Backbone(**self.init_kwargs) + + def test_backbone_basics(self): + """ + Test basic functionality of the RWKV7 backbone. + """ + y = self.backbone(self.input_data) + self.assertEqual(y.shape, (2, 5, 10)) + + def test_num_parameters(self): + """ + Test that the model has the expected number of parameters. + """ + self.assertEqual(self.backbone.count_params(), 10208) diff --git a/keras_hub/src/models/rwkv7/rwkv7_causal_lm.py b/keras_hub/src/models/rwkv7/rwkv7_causal_lm.py new file mode 100644 index 0000000000..c7a33c1fb9 --- /dev/null +++ b/keras_hub/src/models/rwkv7/rwkv7_causal_lm.py @@ -0,0 +1,252 @@ +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone +from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import ( + RWKV7CausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.RWKV7CausalLM") +class RWKV7CausalLM(CausalLM): + """An end-to-end RWKV-7 model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a RWKV-7 model, simply by calling `fit()`. + + This model has a generate() method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + sampler argument on `compile()`. You can recompile the model with + different `keras_hub.samplers` objects to control the generation. By + default, `"greedy"` sampling will be used. + + Args: + backbone: A `keras_hub.models.RWKV7Backbone` instance. + preprocessor: A `keras_hub.models.RWKV7CausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + + Examples: + ```python + # Initialize the tokenizer and load assets from a local path. + tokenizer = RWKVTokenizer() + tokenizer.load_assets(rwkv_path) + + # Create a preprocessor with a sequence length of 8. + preprocessor = RWKV7CausalLMPreprocessor(tokenizer, sequence_length=8) + + # Initialize the model with a backbone and preprocessor. 
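+    # The backbone below is a small, randomly initialized example config;
+    # the sizes are illustrative only.
+    backbone = RWKV7Backbone(
+        vocabulary_size=tokenizer.vocabulary_size(),
+        hidden_size=512,
+        num_layers=2,
+        head_size=64,
+        intermediate_dim=1024,
+    )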
+ causal_lm = RWKV7CausalLM(backbone, preprocessor) + + prompts = ["Bubble sort\n```python", "Hello World\n```python\n"] + + causal_lm.compile(sampler="greedy") + + outputs = causal_lm.generate(prompts, max_length=128) + for out in outputs: + print(out) + print("-" * 100) + ``` + """ + + backbone_cls = RWKV7Backbone + preprocessor_cls = RWKV7CausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + """Initialize the RWKV-7 causal language model. + + Args: + backbone: The backbone model. + preprocessor: The preprocessor for tokenization. + **kwargs: Additional keyword arguments. + """ + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + super().__init__( + inputs=backbone.inputs, + outputs=backbone.outputs, + **kwargs, + ) + self.call(ops.ones([1, 16], "int32")) + + def call_with_cache( + self, + token_ids, + cache, + compute_head=True, + padding_mask=None, + rnn_mode=True, + ): + """Forward pass of `RWKV7CausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous state Tensors in RWKV layers, and avoids + recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of state and token values. + compute_head: bool, whether to compute the output head. + padding_mask: a dense bool Tensor, the padding mask. + rnn_mode: bool, whether to use RNN mode. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + state_cachce, last_token_cache = cache + x = self.backbone.token_embedding(token_ids) + if padding_mask is None: + padding_mask = ops.not_equal(token_ids, 0) + v_first = None + updated_state_cachce = [] + updated_last_token_cache = [] + + for i in range(self.backbone.num_layers): + current_state_cache = state_cachce[:, i, ...] + current_token_cache = last_token_cache[:, i, ...] + x, v_first, new_cache_state, cache_tmix_x, cache_cmix_x = ( + self.backbone.rwkv_layers[i].call( + x, + v_first=v_first, + padding_mask=padding_mask, + cache_state=current_state_cache, + cache_tmix_x=current_token_cache[:, 0], + cache_cmix_x=current_token_cache[:, 1], + rnn_mode=rnn_mode, + train_mode=False, + ) + ) + new_token_cache = ops.stack([cache_tmix_x, cache_cmix_x], axis=1) + updated_state_cachce.append(new_cache_state) + updated_last_token_cache.append(new_token_cache) + cache = [ + ops.stack(updated_state_cachce, axis=1), + ops.stack(updated_last_token_cache, axis=1), + ] + hidden_states = x = self.backbone.output_layer_norm(x) + if compute_head: + logits = self.backbone.head(x) + else: + logits = None + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + num_layers = self.backbone.num_layers + head_dim = self.backbone.head_size + hidden_size = self.backbone.hidden_size + num_heads = hidden_size // head_dim + + state_cachce = ops.zeros( + [batch_size, num_layers, num_heads, head_dim, head_dim], + dtype=self.compute_dtype, + ) + last_token_cache = ops.zeros( + [batch_size, num_layers, 2, 1, hidden_size], + dtype=self.compute_dtype, + ) + cache = [state_cachce, last_token_cache] + + # Seed the cache. 
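+        # The prefill pass below runs the full prompt through every layer
+        # once, filling each layer's per-head WKV state matrix and the
+        # last-token activations used by the time-shift layers.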
+ # Prefill stage can use kernel for better performance + _, hidden_states, cache = self.call_with_cache( + token_ids, + cache, + rnn_mode=False, + compute_head=False, + ) + + return hidden_states, cache + + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with keys `"token_ids"`, `"padding_mask"`, and + `"predict_token_ids"` with batched tensor values. + stop_token_ids: Tuple of id's of the end token to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids, padding_mask, predict_token_ids = ( + inputs["token_ids"], + inputs["padding_mask"], + inputs["predict_token_ids"], + ) + # Create and seed cache with a single forward pass. + + hidden_states, cache = self._build_cache(token_ids) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + output_ids = self.sampler( + next=next, + prompt=predict_token_ids, + cache=cache, + index=1, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + padding_mask = ops.concatenate( + [ + ops.cast(ops.not_equal(token_ids, 0), padding_mask.dtype), + padding_mask, + ], + axis=1, + ) + token_ids = ops.concatenate([token_ids, output_ids], axis=1) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. 
+ padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } diff --git a/keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py b/keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py new file mode 100644 index 0000000000..0071cda60b --- /dev/null +++ b/keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py @@ -0,0 +1,241 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone +from keras_hub.src.models.rwkv7.rwkv7_tokenizer import RWKVTokenizer + + +@keras_hub_export("keras_hub.models.RWKV7CausalLMPreprocessor") +class RWKV7CausalLMPreprocessor(CausalLMPreprocessor): + """RWKV-7 Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.RWKV7CausalLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_hub.models.RWKV7CausalLM` instance, these methods + will be called implicitly in generate(). They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_hub.models.RWKVTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Default is `False`. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured sequence_length of + the layer. + + + Examples: + ```python + # Initialize the tokenizer and load assets from a local path. + tokenizer = RWKVTokenizer() + tokenizer.load_assets(rwkv_path) + + # Create a preprocessor with a sequence length of 8. + preprocessor = RWKV7CausalLMPreprocessor(tokenizer, sequence_length=8) + + # Tokenize and pack a batch of sentences. + preprocessor(["Bubble sort\n```python", "Hello World\n```python\n"]) + + # Preprocess inputs for generation with a maximum generation length of 16. 
+ preprocessor.generate_preprocess( + ["Bubble sort\n```python", "Hello World\n```python\n"], 16 + ) + ``` + Outputs (torch Backend) : + tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 893, + 1760, 2011, 32082, 11, 6884], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 33155, 37576, 11, 6884, 42114]], dtype=torch.int32), + tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 893, 1760, + 2011, 32082, 11, 6884, 42114], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 33155, + 37576, 11, 6884, 42114, 11]], dtype=torch.int32), + tensor([[False, False, False, False, False, False, False, False, True, + True, True, True, True, True, True], + [False, False, False, False, False, False, False, False, False, + True, True, True, True, True, True]]) + + {'token_ids': tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 893, 1760, 2011, 32082, 11, 6884], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 33155, 37576, 11, 6884, 42114]], dtype=torch.int32), + 'padding_mask': tensor([[ True, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False], + [True, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, + False, False, False, False, False]]), + 'predict_token_ids': tensor([[42114, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0], + [ 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0]], dtype=torch.int32)} + """ + + backbone_cls = RWKV7Backbone + tokenizer_cls = RWKVTokenizer + + def __init__( + self, + tokenizer, + add_start_token=False, + **kwargs, + ): + """Initialize the preprocessor. + + Args: + tokenizer: The tokenizer to use. + add_start_token: Whether to add start token. + **kwargs: Additional arguments. + """ + super().__init__( + tokenizer=tokenizer, add_start_token=add_start_token, **kwargs + ) + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + """Preprocess the input for training. + + Args: + x: Input text data. + y: Target data (optional). + sample_weight: Sample weights (optional). + sequence_length: Desired sequence length. + + Returns: + Preprocessed data tuple (x, y, sample_weight). + """ + if isinstance(x, str): + x = [x] + sequence_length = sequence_length or self.sequence_length + # Pad length to multiples of 16 to meet kernel requirements + if sequence_length is None: + raise (ValueError("`sequence_length` must be specified.")) + if (sequence_length - 1) % 16 != 0: + sequence_length = sequence_length + ( + 16 - (sequence_length - 1) % 16 + ) + x = self.tokenizer(x) + + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + + # The last token does not have a next token, so we truncate it out. + x = token_ids[..., :-1] + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + + def build(self, input_shape): + self.packer = StartEndPacker( + start_value=None, + end_value=None, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + padding_side="left", # RWKV uses left-padding exclusively + ) + self.built = True + + def generate_preprocess( + self, + x, + sequence_length, + ): + """Preprocess input for generation. 
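+
+        Note that `sequence_length` here is the maximum generation length.
+        The prompt itself is tokenized and left-padded to the preprocessor's
+        own `sequence_length` (the prefill length), and the last prompt token
+        is placed at the start of `predict_token_ids` to seed decoding.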
+ + Args: + x: Input text data. + sequence_length: Maximum generation length. + + Returns: + Dictionary with preprocessed inputs for generation. + """ + if isinstance(x, str): + x = [x] + + if not self.built: + self.build(None) + # Align with Keras API + # Input sequence_length is the maximum generation length + # While self.sequence_length corresponds to the prefill max length + generate_length = sequence_length + if sequence_length is None: + raise (ValueError("`sequence_length` must be specified.")) + sequence_length = self.sequence_length + + # Pad length to multiples of 16 to meet kernel requirements + if sequence_length % 16 != 0: + sequence_length = sequence_length + (16 - sequence_length % 16) + if generate_length % 16 != 0: + generate_length = generate_length + (16 - generate_length % 16) + + x = [t[-sequence_length:] for t in self.tokenizer(x)] + y = ops.zeros((len(x), generate_length), "int32") + # Utilize RNN characteristics where prefill and decode are two sequences + # But the first token of decode should be the last token of prefill + start_token = [[t[-1]] for t in x] + x = [t[:-1] if len(t) > 1 else [0] for t in x] + + token_ids, __ = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + start_token = ops.convert_to_tensor(start_token, "int32") + y = ops.slice_update(y, [0, 0], start_token) + padding_mask = ops.not_equal(y, 0) + + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + "predict_token_ids": y, + } + + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + + Args: + x: Dictionary containing token_ids and padding_mask. + + Returns: + Detokenized string output. 
+ """ + if not self.built: + self.build(None) + + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + token_ids = ops.convert_to_numpy(token_ids) + padding_mask = ops.convert_to_numpy(padding_mask) + return self.tokenizer.detokenize(token_ids * padding_mask) diff --git a/keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor_test.py b/keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..a2648b9c4a --- /dev/null +++ b/keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor_test.py @@ -0,0 +1,98 @@ +import numpy as np + +from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import ( + RWKV7CausalLMPreprocessor, +) +from keras_hub.src.models.rwkv7.rwkv7_tokenizer import RWKVTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class RWKV7CausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = RWKVTokenizer( + ["1 ' ' 1", "2 '\\n' 1", "3 'the' 3", "4 'hello' 5", "5 'world' 5"] + ) + self.preprocessor = RWKV7CausalLMPreprocessor( + tokenizer=self.tokenizer, + sequence_length=15, + ) + + def test_preprocessor_basics(self): + result = self.preprocessor(x=["hello world hello world hello world"]) + self.assertAllEqual( + result[0], [[0, 0, 0, 0, 0, 0, 4, 1, 5, 1, 4, 1, 5, 1, 4, 1]] + ) + self.assertAllEqual( + result[1], [[0, 0, 0, 0, 0, 4, 1, 5, 1, 4, 1, 5, 1, 4, 1, 5]] + ) + self.assertAllEqual( + result[2], + [ + [ + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + ] + ], + ) + + def test_generate_preprocess(self): + result = self.preprocessor.generate_preprocess( + ["hello world hello world hello world"], 16 + ) + self.assertAllEqual( + result["token_ids"], + [[0, 0, 0, 0, 0, 0, 4, 1, 5, 1, 4, 1, 5, 1, 4, 1]], + ) + self.assertAllEqual( + result["padding_mask"], + [ + [ + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ] + ], + ) + self.assertAllEqual( + result["predict_token_ids"], + [[5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + ) + + def test_generate_postprocess(self): + input_data = { + "token_ids": np.array( + [[3, 2, 4, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + ), + "padding_mask": np.array( + [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + ), + } + result = self.preprocessor.generate_postprocess(input_data) + self.assertEqual(result, ["the\nhellothe"]) diff --git a/keras_hub/src/models/rwkv7/rwkv7_causal_lm_test.py b/keras_hub/src/models/rwkv7/rwkv7_causal_lm_test.py new file mode 100644 index 0000000000..215fda095d --- /dev/null +++ b/keras_hub/src/models/rwkv7/rwkv7_causal_lm_test.py @@ -0,0 +1,90 @@ +from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone +from keras_hub.src.models.rwkv7.rwkv7_causal_lm import RWKV7CausalLM +from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import ( + RWKV7CausalLMPreprocessor, +) +from keras_hub.src.models.rwkv7.rwkv7_tokenizer import RWKVTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class RWKV7CausalLMTest(TestCase): + def setUp(self): + """ + Set up the test case with vocabulary, merges, preprocessor, backbone, + and other initialization parameters. 
+ """ + # Create a small vocabulary for testing + self.vocab = [ + "0 ' ' 1", + "1 '\\n' 1", + "2 'the' 3", + "3 'hello' 5", + "4 'world' 5", + "5 'python' 6", + ] + + # Initialize tokenizer with test vocabulary + self.tokenizer = RWKVTokenizer(vocabulary=self.vocab) + + # Create preprocessor with sequence length of 8 + self.preprocessor = RWKV7CausalLMPreprocessor( + tokenizer=self.tokenizer, + sequence_length=16, + ) + + # Create a small backbone for testing + self.backbone = RWKV7Backbone( + vocabulary_size=5, + hidden_size=16, + num_layers=2, + head_size=4, + intermediate_dim=32, + gate_lora=8, + mv_lora=4, + aaa_lora=4, + decay_lora=4, + ) + + # Initialize parameters for the causal LM + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + + def test_generate(self): + """ + Test text generation functionality. + """ + causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor) + prompt = ["hello world"] + output = causal_lm.generate(prompt, 16) + self.assertTrue(isinstance(output[0], str)) + self.assertTrue(isinstance(output, list)) + + prompt = "hello world" + output = causal_lm.generate(prompt, 16) + self.assertTrue(isinstance(output, str)) + + def test_generate_strip_prompt(self): + """ + Test that generated text can strip the prompt from output. + """ + prompt = ["hello world"] + causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor) + output = causal_lm.generate(prompt, 16, strip_prompt=True) + self.assertFalse(output[0].startswith(prompt[0])) + + def test_generate_compilation(self): + """ + Test that the generate function compiles correctly and + reuses compiled functions. + """ + causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor) + causal_lm.generate(["hello world"], 16) + first_fn = causal_lm.generate_function + causal_lm.generate(["hello world"], 16) + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) diff --git a/keras_hub/src/models/rwkv7/rwkv7_layer.py b/keras_hub/src/models/rwkv7/rwkv7_layer.py new file mode 100644 index 0000000000..309767eb80 --- /dev/null +++ b/keras_hub/src/models/rwkv7/rwkv7_layer.py @@ -0,0 +1,676 @@ +import warnings + +import keras +from keras import initializers +from keras import ops +from keras.layers import Layer + + +def transpose_head(x, head_first): + x = ops.cast(x, dtype="float32") + if head_first: + return ops.transpose(x, (0, 2, 1, 3)) + else: + return x + + +def rnn_generalized_delta_rule( + r, + w, + k, + v, + a, + b, + initial_state=None, + output_final_state: bool = True, + head_first: bool = False, +): + """Implements the generalized delta rule for RWKV.""" + DTYPE = r.dtype + B, T, H, N = ops.shape(r) + r = transpose_head(r, head_first) + + k = transpose_head(k, head_first) + + v = transpose_head(v, head_first) + a = transpose_head(a, head_first) + b = transpose_head(b, head_first) + w = transpose_head(w, head_first) + w = ops.exp(-ops.exp(w)) + + if initial_state is not None: + state = initial_state + if ops.shape(state)[0] == 1: + state = ops.broadcast_to(state, (B, H, N, N)) + else: + state = ops.zeros((B, H, N, N)) + state = ops.cast(state, "float32") + out = ops.zeros((B, T, H, N), DTYPE) + + def step(t, inputs): + state, out = inputs + kk = ops.reshape(k[:, t, :], (B, H, 1, N)) + rr = ops.reshape(r[:, t, :], (B, H, N, 1)) + vv = ops.reshape(v[:, t, :], (B, H, N, 1)) + aa = ops.reshape(a[:, t, :], (B, H, N, 1)) + bb = ops.reshape(b[:, t, :], (B, H, 1, N)) + 
state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk + o = ops.cast((state @ rr), out.dtype) + out = ops.slice_update(out, [0, t, 0, 0], ops.reshape(o, (B, 1, H, N))) + return [state, out] + + state, out = ops.fori_loop(0, T, step, [state, out]) + + if output_final_state: + return ops.cast(out, DTYPE), state + return ops.cast(out, DTYPE) + + +class TimeShift(Layer): + """Time shift layer that shifts input sequence by one step. + It also be called short conv + """ + + def __init__(self, name="time_shift"): + super(TimeShift, self).__init__(name=name) + + def call(self, inputs, cache_x=None): + if cache_x is not None: + x = ops.concatenate([cache_x, inputs], axis=1) + else: + x = ops.pad(inputs, [[0, 0], [1, 0], [0, 0]], constant_values=0.0) + return x[:, :-1, :] + + def compute_output_shape(self, input_shape): + return input_shape + + +class RWKV7_ChannelMix(Layer): + """RWKV-7 channel mixing layer.""" + + def __init__(self, dim_ffn, kernel_initializer="glorot_uniform", **kwargs): + """Initialize RWKV7 channel mixer. + + Args: + dim_ffn: Feed-forward dimension. + kernel_initializer: Weight initializer. + **kwargs: Additional layer arguments. + """ + super().__init__(**kwargs) + self.dim_ffn = dim_ffn + self.kernel_initializer = initializers.get(kernel_initializer) + + def call(self, x, last_cache_x=None, train_mode=True): + """Process input through channel mixer. + + Args: + x: Input tensor. + last_cache_x: Cached previous values. + train_mode: Whether in training mode. + + Returns: + Mixed output tensor. + """ + xx = self.time_shift(x, last_cache_x) - x + if last_cache_x is not None or not train_mode: + last_cache_x = x[:, -1:] + k = x + xx * self.x_k + k = ops.relu(self.key(k)) ** 2 + output = self.value(k) + if train_mode: + return output + return output, last_cache_x + + def compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + return input_shape[0] + return input_shape + + def build(self, input_shape): + super().build(input_shape) + if isinstance(input_shape, list): + input_shape = input_shape[0] + self.x_k = self.add_weight( + shape=(1, 1, input_shape[-1]), + name="time_mix_k", + initializer=self.kernel_initializer, + ) + self.time_shift = TimeShift() + self.key = keras.layers.Dense( + self.dim_ffn, + use_bias=False, + name="dense_k", + kernel_initializer=self.kernel_initializer, + ) + self.value = keras.layers.Dense( + input_shape[-1], + use_bias=False, + name="dense_v", + kernel_initializer=self.kernel_initializer, + ) + self.key.build(input_shape) + self.value.build([None, None, self.dim_ffn]) + + def get_config(self): + config = { + "dim_ffn": self.dim_ffn, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class GroupNorm(keras.layers.GroupNormalization): + """Group normalization with backend-specific handling. + + Extends Keras GroupNormalization with PyTorch backend support. + """ + + def call(self, inputs): + if keras.config.backend() == "torch": + import torch.nn.functional as F + + return F.group_norm( + inputs, self.groups, self.gamma, self.beta, self.epsilon + ) + return super().call(inputs) + + +class RWKV7_TimeMix(Layer): + """RWKV-7 time mixing layer.""" + + def __init__( + self, + hidden_size, + head_size, + gate_lora=128, + mv_lora=32, + aaa_lora=64, + decay_lora=64, + kernel_initializer="glorot_uniform", + **kwargs, + ): + """Initialize RWKV7 time mixer. 
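+
+        If the optional `rwkv_ops` package is installed, its
+        `generalized_delta_rule` kernel is used whenever `rnn_mode=False`;
+        otherwise (and when the package is missing) the pure-Python
+        `rnn_generalized_delta_rule` fallback defined in this module is used,
+        which is much slower.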
+ + Args: + hidden_size: Hidden dimension size. + head_size: Attention head size. + gate_lora: LoRA dimension for gating. + mv_lora: LoRA dimension for value mixing. + aaa_lora: LoRA dimension for alpha parameters. + decay_lora: LoRA dimension for decay parameters. + kernel_initializer: Weight initializer. + **kwargs: Additional layer arguments. + """ + super().__init__(**kwargs) + self.head_size = head_size + self.hidden_size = hidden_size + self.n_head = hidden_size // self.head_size + self.gate_lora = gate_lora + self.mv_lora = mv_lora + self.aaa_lora = aaa_lora + self.decay_lora = decay_lora + self.kernel_initializer = initializers.get(kernel_initializer) + self.initial_state = None + try: + from rwkv_ops import generalized_delta_rule + + self.RWKV7_OP = generalized_delta_rule + except ImportError: + warnings.warn( + "The 'rwkv_ops' package is not installed. " + "Falling back to the default (pure-Python) operators" + "pure-Python which will be very slow. " + "Please 'pip install rwkv_ops' to enable the optimized kernels", + UserWarning, + stacklevel=2, + ) + self.RWKV7_OP = rnn_generalized_delta_rule + + assert self.hidden_size % self.n_head == 0 + + def build(self, input_shape): + super().build(input_shape) + if isinstance(input_shape[0], list): + input_shape = input_shape[0] + H = self.n_head + N = self.head_size + B, T, C = input_shape + + self.x_r = self.add_weight( + shape=(1, 1, C), name="x_r", initializer=self.kernel_initializer + ) + self.x_w = self.add_weight( + shape=(1, 1, C), name="x_w", initializer=self.kernel_initializer + ) + self.x_k = self.add_weight( + shape=(1, 1, C), name="x_k", initializer=self.kernel_initializer + ) + self.x_v = self.add_weight( + shape=(1, 1, C), name="x_v", initializer=self.kernel_initializer + ) + self.x_a = self.add_weight( + shape=(1, 1, C), name="x_a", initializer=self.kernel_initializer + ) + self.x_g = self.add_weight( + shape=(1, 1, C), name="x_g", initializer=self.kernel_initializer + ) + + self.w0 = self.add_weight( + shape=(1, 1, C), name="w0", initializer=self.kernel_initializer + ) + self.w1 = self.add_weight( + shape=(C, self.decay_lora), + name="w1", + initializer=self.kernel_initializer, + ) + self.w2 = self.add_weight( + shape=(self.decay_lora, C), + name="w2", + initializer=self.kernel_initializer, + ) + + self.a0 = self.add_weight( + shape=(1, 1, C), name="a0", initializer=self.kernel_initializer + ) + self.a1 = self.add_weight( + shape=(C, self.aaa_lora), + name="a1", + initializer=self.kernel_initializer, + ) + self.a2 = self.add_weight( + shape=(self.aaa_lora, C), + name="a2", + initializer=self.kernel_initializer, + ) + + self.v0 = self.add_weight( + shape=(1, 1, C), name="v0", initializer=self.kernel_initializer + ) + self.v1 = self.add_weight( + shape=(C, self.mv_lora), + name="v1", + initializer=self.kernel_initializer, + ) + self.v2 = self.add_weight( + shape=(self.mv_lora, C), + name="v2", + initializer=self.kernel_initializer, + ) + + self.g1 = self.add_weight( + shape=(C, self.gate_lora), + name="g1", + initializer=self.kernel_initializer, + ) + self.g2 = self.add_weight( + shape=(self.gate_lora, C), + name="g2", + initializer=self.kernel_initializer, + ) + + self.k_k = self.add_weight( + shape=(1, 1, C), name="k_k", initializer=self.kernel_initializer + ) + self.k_a = self.add_weight( + shape=(1, 1, C), name="k_a", initializer=self.kernel_initializer + ) + self.r_k = self.add_weight( + shape=(H, N), name="r_k", initializer=self.kernel_initializer + ) + + self.time_shift = TimeShift() + self.receptance = 
keras.layers.Dense( + C, + use_bias=False, + kernel_initializer=self.kernel_initializer, + name="receptance", + ) + self.key = keras.layers.Dense( + C, + use_bias=False, + kernel_initializer=self.kernel_initializer, + name="key", + ) + self.value = keras.layers.Dense( + C, + use_bias=False, + kernel_initializer=self.kernel_initializer, + name="value", + ) + self.output_layer = keras.layers.Dense( + C, + use_bias=False, + kernel_initializer=self.kernel_initializer, + name="output_layer", + ) + self.ln_x = GroupNorm(groups=H, epsilon=64e-5) + + self.receptance.build(input_shape) + self.value.build(input_shape) + self.key.build(input_shape) + self.output_layer.build(input_shape) + self.ln_x.build((None, C)) + + def call( + self, + x, + v_first=None, + padding_mask=None, + last_cache_x=None, + cache_state=None, + rnn_mode=False, + train_mode=True, + ): + """Process input through time mixer. + + Args: + x: Input tensor. + v_first: First value for mixing. + padding_mask: Mask for padding tokens. + last_cache_x: Cached previous values. + cache_state: Cached recurrent state. + rnn_mode: Whether to use RNN mode. + train_mode: Whether in training mode. + + Returns: + Mixed output tensor and state information. + """ + if cache_state is None: + initial_state = self.initial_state + else: + initial_state = cache_state + if padding_mask is not None: + if ops.ndim(padding_mask) == 2: + padding_mask = padding_mask[..., None] + padding_mask = ops.cast(padding_mask, x.dtype) + x *= padding_mask + B, T, C = ops.shape(x) + H = self.n_head + xx = self.time_shift(x, last_cache_x) - x + if last_cache_x is not None or not train_mode: + last_cache_x = x[:, -1:] + if padding_mask is not None: + xx *= padding_mask + + xr = x + xx * self.x_r + xw = x + xx * self.x_w + xk = x + xx * self.x_k + xv = x + xx * self.x_v + xa = x + xx * self.x_a + xg = x + xx * self.x_g + + r = self.receptance(xr) + w = ( + -ops.softplus( + -( + self.w0 + + ops.matmul(ops.tanh(ops.matmul(xw, self.w1)), self.w2) + ) + ) + - 0.5 + ) # soft-clamp to (-inf, -0.5) + k = self.key(xk) + v = self.value(xv) + if v_first is None: + v_first = v + else: + v = v + (v_first - v) * ops.sigmoid( + self.v0 + ops.matmul(ops.matmul(xv, self.v1), self.v2) + ) + + a = ops.sigmoid( + self.a0 + ops.matmul(ops.matmul(xa, self.a1), self.a2) + ) # a is "in-context learning rate" + g = ops.matmul(ops.sigmoid(ops.matmul(xg, self.g1)), self.g2) + + kk = k * self.k_k + + kk = self.normalize(ops.reshape(kk, (B, T, H, -1))) + kk = ops.reshape(kk, (B, T, C)) + + k = k * (1 + (a - 1) * self.k_a) + if padding_mask is not None: + w = ops.where(padding_mask, w, -1e9) + if rnn_mode: + rwkv7_op = rnn_generalized_delta_rule + else: + rwkv7_op = self.RWKV7_OP + + def reshape_and_cast(x, new_shape, dtype="float32"): + x = ops.reshape(x, new_shape) + if rnn_mode: + return x + return ops.cast(x, dtype) + + x, finnal_state = rwkv7_op( + reshape_and_cast(r, (B, T, self.n_head, self.head_size)), + reshape_and_cast(w, (B, T, self.n_head, self.head_size)), + reshape_and_cast(k, (B, T, self.n_head, self.head_size)), + reshape_and_cast(v, (B, T, self.n_head, self.head_size)), + reshape_and_cast(-kk, (B, T, self.n_head, self.head_size)), + reshape_and_cast(kk * a, (B, T, self.n_head, self.head_size)), + initial_state=ops.cast(initial_state, "float32") + if initial_state is not None + else None, + ) + x = reshape_and_cast(x, (B, T, C), self.compute_dtype) + + x = ops.reshape(self.ln_x(ops.reshape(x, (B * T, C))), ops.shape(x)) + + x = ops.reshape(x, (B, T, C)) + r = ops.reshape(r, (B, T, H, 
-1)) + k = ops.reshape(k, (B, T, H, -1)) + v = ops.reshape(v, (B, T, C)) + + rwkv = ops.sum(r * k * self.r_k, axis=-1, keepdims=True) * ops.reshape( + v, (B, T, H, -1) + ) + + x = x + ops.reshape(rwkv, (B, T, C)) + x = self.output_layer(x * g) + if train_mode: + return x, v_first + return x, v_first, last_cache_x, finnal_state + + def compute_output_shape(self, input_shape): + output_shapes = [ + [None, None, self.hidden_size], + [None, None, self.hidden_size], + ] + return output_shapes + + def normalize( + self, + x, + eps: float = 1e-12, + ): + # F.normalize like api + if keras.config.backend() == "torch": + import torch.nn.functional as F + + return F.normalize(x, dim=-1, p=2.0) + square_sum = ops.sum(ops.square(x), axis=-1, keepdims=True) + inv_norm = ops.rsqrt(square_sum + eps) + inv_norm = ops.maximum(inv_norm, eps) + return x * inv_norm + + def get_config(self): + config = { + "hidden_size": self.hidden_size, + "head_size": self.head_size, + "gate_lora": self.gate_lora, + "mv_lora": self.mv_lora, + "aaa_lora": self.aaa_lora, + "decay_lora": self.decay_lora, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class RWKV7_Block(Layer): + def __init__( + self, + hidden_size, + head_size, + intermediate_dim, + gate_lora=128, + mv_lora=32, + aaa_lora=64, + decay_lora=64, + use_initial_norm=False, + kernel_initializer="glorot_uniform", + **kwargs, + ): + """Initialize RWKV7 block. + + Args: + hidden_size: Hidden dimension size. + head_size: Attention head size. + intermediate_dim: Intermediate dimension for FFN. + gate_lora: LoRA dimension for gating. + mv_lora: LoRA dimension for value mixing. + aaa_lora: LoRA dimension for alpha parameters. + decay_lora: LoRA dimension for decay parameters. + use_initial_norm: Whether to use initial normalization. + kernel_initializer: Weight initializer. + **kwargs: Additional layer arguments. + """ + super().__init__(**kwargs) + self.head_size = head_size + self.hidden_size = hidden_size + self.gate_lora = gate_lora + self.mv_lora = mv_lora + self.aaa_lora = aaa_lora + self.decay_lora = decay_lora + self.intermediate_dim = intermediate_dim + self.use_initial_norm = use_initial_norm + self.kernel_initializer = initializers.get(kernel_initializer) + + def build(self, input_shape): + super().build(input_shape) + if self.use_initial_norm: + self.ln0 = keras.layers.LayerNormalization( + epsilon=1e-5, name="init_norm" + ) + self.ln0.build(input_shape) + + self.ln1 = keras.layers.LayerNormalization( + epsilon=1e-5, name="att_norm" + ) + self.ln1.build(input_shape) + + self.ln2 = keras.layers.LayerNormalization( + epsilon=1e-5, name="ffn_norm" + ) + self.ln2.build(input_shape) + + self.att = RWKV7_TimeMix( + self.hidden_size, + self.head_size, + self.gate_lora, + self.mv_lora, + self.aaa_lora, + self.decay_lora, + name="RWKV_TIME_MIX", + kernel_initializer=self.kernel_initializer, + ) + self.att.build(input_shape) + + self.ffn = RWKV7_ChannelMix( + self.intermediate_dim, + name="RWKV_CMIX", + kernel_initializer=self.kernel_initializer, + ) + self.ffn.build(input_shape) + + def call( + self, + x, + v_first=None, + padding_mask=None, + cache_state=None, + cache_tmix_x=None, + cache_cmix_x=None, + rnn_mode=False, + train_mode=True, + ): + """Process input through RWKV block. + + Args: + x: Input tensor. + v_first: First value for mixing. + padding_mask: Mask for padding tokens. + cache_state: Cached recurrent state. 
+ cache_tmix_x: Cached time mixer values. + cache_cmix_x: Cached channel mixer values. + rnn_mode: Whether to use RNN mode. + train_mode: Whether in training mode. + + Returns: + Processed output tensor and cache information. + """ + if padding_mask is not None: + padding_mask = ops.cast(padding_mask, x.dtype) + padding_mask = ops.expand_dims(padding_mask, axis=-1) + if self.use_initial_norm: + x = self.ln0(x) + if train_mode: + xx, v_first = self.att( + self.ln1(x), + v_first=v_first, + padding_mask=padding_mask, + train_mode=train_mode, + ) + x = x + xx + xx = self.ln2(x) + if padding_mask is not None: + xx = xx * padding_mask + x = x + self.ffn(xx, train_mode=train_mode) + return x, v_first + else: + xx, v_first, cache_tmix_x, cache_state = self.att.call( + self.ln1(x), + v_first=v_first, + padding_mask=padding_mask, + last_cache_x=cache_tmix_x, + cache_state=cache_state, + rnn_mode=rnn_mode, + train_mode=train_mode, + ) + x = x + xx + xx = self.ln2(x) + if padding_mask is not None: + xx = xx * padding_mask + xx, cache_cmix_x = self.ffn(xx, cache_cmix_x, train_mode=train_mode) + x = x + xx + return x, v_first, cache_state, cache_tmix_x, cache_cmix_x + + def compute_output_shape(self, input_shape): + output_shapes = [ + [None, None, self.hidden_size], + [None, None, self.hidden_size], + ] + return output_shapes + + def get_config(self): + config = { + "hidden_size": self.hidden_size, + "head_size": self.head_size, + "gate_lora": self.gate_lora, + "mv_lora": self.mv_lora, + "aaa_lora": self.aaa_lora, + "decay_lora": self.decay_lora, + "intermediate_dim": self.intermediate_dim, + "use_initial_norm": self.use_initial_norm, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_hub/src/models/rwkv7/rwkv7_tokenizer.py b/keras_hub/src/models/rwkv7/rwkv7_tokenizer.py new file mode 100644 index 0000000000..ef11a059e8 --- /dev/null +++ b/keras_hub/src/models/rwkv7/rwkv7_tokenizer.py @@ -0,0 +1,405 @@ +import os + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.tokenizers import tokenizer +from keras_hub.src.utils.tensor_utils import is_int_dtype +from keras_hub.src.utils.tensor_utils import is_string_dtype +from keras_hub.src.utils.tensor_utils import tensor_to_list + +# Vocabulary file name constant +VOCAB_FILENAME = "vocab.txt" + + +class TRIE: + """Byte-level Trie structure for longest prefix matching. + + This class implements a trie data structure that stores byte + sequences and allows efficient longest prefix matching. + """ + + __slots__ = tuple("ch,to,values,front".split(",")) + to: list + values: set + + def __init__(self, front=None, ch=None): + """Initialize a TRIE node. + + Args: + front: Parent node reference. + ch: Byte value for this node. + """ + self.ch = ch + self.to = [None for ch in range(256)] + self.values = set() + self.front = front + + def __repr__(self): + """String representation of the TRIE node.""" + fr = self + ret = [] + while fr is not None: + if fr.ch is not None: + ret.append(fr.ch) + fr = fr.front + return "" % (ret[::-1], self.values) + + def add(self, key: bytes, idx: int = 0, val=None): + """Add a key-value pair to the trie. + + Args: + key: Byte sequence to add. + idx: Current index in key processing. + val: Value to store (defaults to key). + + Returns: + Final node where key was inserted. 
+ """ + if idx == len(key): + if val is None: + val = key + self.values.add(val) + return self + ch = key[idx] + if self.to[ch] is None: + self.to[ch] = TRIE(front=self, ch=ch) + return self.to[ch].add(key, idx=idx + 1, val=val) + + def find_longest(self, key: bytes, idx: int = 0): + """Find longest match in trie for given key. + + Args: + key: Byte sequence to search for. + idx: Starting index for search. + + Returns: + Tuple of (end_index, node, values) for match. + """ + u: TRIE = self + ch: int = key[idx] + + while u.to[ch] is not None: + u = u.to[ch] + idx += 1 + if u.values: + ret = idx, u, u.values + if idx == len(key): + break + ch = key[idx] + return ret + + +class RWKV_TOKENIZER: + """RWKV tokenizer implementation using byte-level trie. + + Implements tokenization using a fixed vocabulary and greedy + longest-match algorithm on byte sequences. + """ + + def __init__(self, vocabs): + """Initialize tokenizer with vocabulary. + + Args: + vocabs: List of vocabulary entries in format + " ". + """ + self.idx2token = {} + sorted = [] # must be already sorted + for l in vocabs: + idx = int(l[: l.index(" ")]) + x = eval(l[l.index(" ") : l.rindex(" ")]) + x = x.encode("utf-8") if isinstance(x, str) else x + assert isinstance(x, bytes) + assert len(x) == int(l[l.rindex(" ") :]) + sorted += [x] + self.idx2token[idx] = x + + self.token2idx = {} + for k, v in self.idx2token.items(): + self.token2idx[v] = int(k) + + self.root = TRIE() + for t, i in self.token2idx.items(): + _ = self.root.add(t, val=(t, i)) + + def encodeBytes(self, src: bytes): + """Encode byte sequence to token IDs. + + Args: + src: Byte sequence to encode. + + Returns: + List of token IDs. + """ + idx: int = 0 + tokens = [] + while idx < len(src): + _idx: int = idx + idx, _, values = self.root.find_longest(src, idx) + assert idx != _idx + _, token = next(iter(values)) + tokens.append(token) + return tokens + + def decodeBytes(self, tokens): + """Decode token IDs to byte sequence. + + Args: + tokens: List of token IDs. + + Returns: + Decoded byte sequence. + """ + return b"".join(map(lambda i: self.idx2token[i], tokens)) + + def encode(self, src): + """Encode text to token IDs. + + Args: + src: Text string or list of strings. + + Returns: + Token IDs or list of token ID lists. + """ + if isinstance(src, str): + return self.encodeBytes(src.encode("utf-8")) + else: + return [self.encodeBytes(s.encode("utf-8")) for s in src] + + def decode(self, tokens): + """Decode token IDs to text. + + Args: + tokens: Token IDs or list of token ID lists. + + Returns: + List of decoded text strings. + """ + return [self.decodeBytes(batch).decode("utf-8") for batch in tokens] + # try: + # return self.decodeBytes(tokens).decode('utf-8') + # except: + # return '\ufffd' # bad utf-8 + + def printTokens(self, tokens): + """Print tokens with their string representations. + + Args: + tokens: List of token IDs to print. + """ + for i in tokens: + s = self.idx2token[i] + try: + s = s.decode("utf-8") + except BaseException: + pass + print(f"{repr(s)}{i}", end=" ") + print() + + +@keras_hub_export("keras_hub.tokenizers.RWKVTokenizer") +class RWKVTokenizer(tokenizer.Tokenizer): + """RWKV byte-level tokenizer with longest-match trie search. + + This tokenizer maps raw text to a sequence of integer token ids + using a fixed vocabulary and a greedy longest-match algorithm. + + Args: + vocabulary: list of strings, each line formatted as + " ". + dtype: output dtype for tensor operations. Must be integer + or string type. 
+ + Examples: + ```python + vocab = ["0 ' ' 1", "1 '\\n' 1", "2 'the' 3", "3 'hello' 5"] + tok = RWKVTokenizer(vocabulary=vocab) + tok("hello the") + ``` + + Output: + [3, 0, 2] + """ + + def __init__( + self, + vocabulary=None, + dtype="int32", + **kwargs, + ) -> None: + """Initialize RWKV tokenizer. + + Args: + vocabulary: Vocabulary list. + dtype: Output data type. + **kwargs: Additional keyword arguments. + """ + if not is_int_dtype(dtype) and not is_string_dtype(dtype): + raise ValueError( + "Output dtype must be an integer type or a string. " + f"Received: dtype={dtype}" + ) + + super().__init__(dtype=dtype, **kwargs) + + self.vocabulary = None + if vocabulary is not None: + self.set_vocabulary(vocabulary) + self.file_assets = [VOCAB_FILENAME] + + def set_vocabulary(self, vocabulary): + """Set the tokenizer vocabulary. + + Args: + vocabulary: Vocabulary list to set. + """ + self.vocabulary = vocabulary + self._tokenizer = RWKV_TOKENIZER(vocabulary) + self.pad_token_id = 0 + self.start_token_id = None + self.end_token_id = self.tokenize(["\n\n"])[0][0] + + def save_assets(self, dir_path): + """Save vocabulary to directory. + + Args: + dir_path: Directory path to save to. + """ + path = os.path.join(dir_path, VOCAB_FILENAME) + with open(path, "wb") as file: + file.write("\n".join(self.vocabulary)) + + def load_assets(self, dir_path=""): + """Load vocabulary from directory. + + Args: + dir_path: Directory path to load from. + """ + path = os.path.join(dir_path, VOCAB_FILENAME) + with open(path, "r", encoding="utf-8") as f: + vocabulary = f.readlines() + self.set_vocabulary(vocabulary) + + def _check_vocabulary(self): + """Check if vocabulary is set, raise error if not.""" + if self.vocabulary is None: + raise ValueError( + "No vocabulary has been set for RWKVTokenizer. Make " + "sure to pass a `vocabulary` argument when creating the layer." + ) + + def vocabulary_size(self): + """Get the size of the vocabulary. + + Returns: + Number of tokens in vocabulary. + """ + self._check_vocabulary() + return int(len(self.vocabulary)) + + def get_vocabulary(self): + """Get the current vocabulary. + + Returns: + Current vocabulary list. + """ + self._check_vocabulary() + return tensor_to_list(self.vocabulary) + + def id_to_token(self, id): + """Convert token ID to string representation. + + Args: + id: Token ID to convert. + + Returns: + String representation of token. + """ + self._check_vocabulary() + if id >= self.vocabulary_size() or id < 0: + raise ValueError( + f"`id` must be in range [0, {self.vocabulary_size() - 1}]. " + f"Received: {id}" + ) + return self._tokenizer.idx2token[id] + + def token_to_id(self, token): + """Convert a string token to an integer id.""" + self._check_vocabulary() + return int(self._tokenizer.token2idx[token]) + + def get_config(self): + """Get tokenizer configuration. + + Returns: + Configuration dictionary. + """ + config = super().get_config() + config.update( + { + "vocabulary": None, # Save vocabulary via an asset! + } + ) + return config + + def tokenize(self, inputs): + """Tokenize input text. + + Args: + inputs: Text to tokenize. + + Returns: + Tokenized representation. + """ + self._check_vocabulary() + tokens = self._tokenizer.encode(inputs) + + def tokens2ids(x): + return [self.id_to_token(t) for t in x] + + if is_string_dtype(self.dtype): + if isinstance(inputs, str): + return tokens2ids(tokens) + return [tokens2ids(t) for t in tokens] + return tokens + + def detokenize(self, inputs): + """Convert tokens back to text. 
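+
+        Token id 0 (padding) is stripped from each sequence before the byte
+        sequences are decoded back to strings.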
+ + Args: + inputs: Tokens to convert. + + Returns: + Detokenized text. + """ + self._check_vocabulary() + strip_zero_inputs = [] + for t in inputs: + strip_zero_inputs.append([x for x in t if x != 0]) + + return self._tokenizer.decode(strip_zero_inputs) + + def compute_output_spec(self, input_spec): + """Compute output specification. + + Args: + input_spec: Input specification. + + Returns: + Output tensor specification. + """ + return keras.KerasTensor( + input_spec.shape + (None,), dtype=self.compute_dtype + ) + + def call(self, inputs): + """Call the tokenizer on inputs. + + Args: + inputs: Input text. + + Returns: + Tokenized output. + """ + return self.tokenize(inputs) diff --git a/keras_hub/src/models/rwkv7/rwkv7_tokenizer_test.py b/keras_hub/src/models/rwkv7/rwkv7_tokenizer_test.py new file mode 100644 index 0000000000..69f76a2366 --- /dev/null +++ b/keras_hub/src/models/rwkv7/rwkv7_tokenizer_test.py @@ -0,0 +1,25 @@ +from keras_hub.src.models.rwkv7.rwkv7_tokenizer import RWKVTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class RWKV7TokenizerTest(TestCase): + def setUp(self): + self.tokenizer = RWKVTokenizer( + ["1 ' ' 1", "2 '\\n' 1", "3 'the' 3", "4 'hello' 5", "5 'world' 5"] + ) + + def test_tokenizer_basics(self): + result = self.tokenizer("hello world") + self.assertAllEqual(result, [4, 1, 5]) + + def test_vocabulary_size(self): + self.assertEqual(self.tokenizer.vocabulary_size(), 5) + + def test_tokenize_and_detokenize(self): + # Test detokenization + text = self.tokenizer.detokenize([[4, 1, 5]]) + self.assertEqual(text[0], "hello world") + + def test_special_tokens(self): + self.assertEqual(self.tokenizer.pad_token_id, 0) + self.assertEqual(self.tokenizer.end_token_id, 2) diff --git a/tools/checkpoint_conversion/convert_rwkv7_checkpoints.py b/tools/checkpoint_conversion/convert_rwkv7_checkpoints.py new file mode 100644 index 0000000000..4504b09253 --- /dev/null +++ b/tools/checkpoint_conversion/convert_rwkv7_checkpoints.py @@ -0,0 +1,468 @@ +# ============================================================================== +# Environment & Dependency Setup +# ============================================================================== +import os + +import numpy as np +import requests +import torch +from absl import app +from absl import flags + +# Force CPU only (GPU index -1 disables CUDA) +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +# Use native kernel implementations +os.environ["KERNEL_TYPE"] = "native" + +# Keras-Ops is imported **after** environment variables are set +import types + +import torch.nn as nn +import torch.nn.functional as F +from keras import ops # noqa: E402 +from modelscope import snapshot_download + +from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone +from keras_hub.src.models.rwkv7.rwkv7_causal_lm import RWKV7CausalLM + +# Local modules +from keras_hub.src.models.rwkv7.rwkv7_tokenizer import RWKVTokenizer + +# ============================================================================== +# Model Preset Registry +# ============================================================================== +PRESET_MAP = { + "RWKV7_G1a_0.1B": "rwkv7-g1a-0.1b-20250728-ctx4096.pth", + "RWKV7_G1a_0.3B": "rwkv7-g1a-0.4b-20250905-ctx4096.pth", + "RWKV7_G1a_1.5B": "rwkv7-g1a-1.5b-20250922-ctx4096.pth", + "RWKV7_G1a_2.9B": "rwkv7-g1a-2.9b-20250924-ctx4096.pth", + "RWKV7_G0a_7.2B": "rwkv7-g0a-7.2b-20250829-ctx4096.pth", +} + +# ============================================================================== +# Command-line Interface +# 
============================================================================== +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) + +# ============================================================================== +# RWKV-v7 official PyTorch implementation +# From https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/rwkv_v7_demo.py +# ============================================================================== +HEAD_SIZE = 64 +D_DECAY_LORA = 64 +D_AAA_LORA = 64 +D_MV_LORA = 32 +D_GATE_LORA = 128 + + +def RWKV7_OP(r, w, k, v, a, b): + """ + Official RWKV-7 core operator. + Performs the time-mix recurrence with delta-rule based learning. + """ + DTYPE = r.dtype + B, T, C = r.size() + H = C // HEAD_SIZE + N = HEAD_SIZE + r = r.view(B, T, H, N).float() + k = k.view(B, T, H, N).float() + v = v.view(B, T, H, N).float() + a = a.view(B, T, H, N).float() + b = b.view(B, T, H, N).float() + + # Compute decay factor (log-space) + w = torch.exp(-torch.exp(w.view(B, T, H, N).float())) + out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float) + state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float) + + # Recurrent inference loop over time + for t in range(T): + kk = k[:, t, :].view(B, H, 1, N) + rr = r[:, t, :].view(B, H, N, 1) + vv = v[:, t, :].view(B, H, N, 1) + aa = a[:, t, :].view(B, H, N, 1) + bb = b[:, t, :].view(B, H, 1, N) + # State update: decay + delta-rule + residual + state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk + # Read-out for current position + out[:, t, :] = (state @ rr).view(B, H, N) + return out.view(B, T, C).to(DTYPE) + + +# ============================================================================== +# RWKV Time-Mix Layer (Attention) +# ============================================================================== +class RWKV_Tmix_x070(nn.Module): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + self.head_size = args.head_size_a + self.n_head = args.dim_att // self.head_size + assert args.dim_att % self.n_head == 0 + + H, N, C = self.n_head, self.head_size, args.n_embd + + # Low-rank adaptation & shift scalars + self.x_r = nn.Parameter(torch.empty(1, 1, C)) + self.x_w = nn.Parameter(torch.empty(1, 1, C)) + self.x_k = nn.Parameter(torch.empty(1, 1, C)) + self.x_v = nn.Parameter(torch.empty(1, 1, C)) + self.x_a = nn.Parameter(torch.empty(1, 1, C)) + self.x_g = nn.Parameter(torch.empty(1, 1, C)) + + # Decay (w) modulation + self.w0 = nn.Parameter(torch.empty(1, 1, C)) + self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA)) + self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C)) + + # In-context learning rate (a) modulation + self.a0 = nn.Parameter(torch.empty(1, 1, C)) + self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA)) + self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C)) + + # Value residual modulation + self.v0 = nn.Parameter(torch.empty(1, 1, C)) + self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA)) + self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C)) + + # Gate modulation + self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA)) + self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C)) + + # Normalization & positional factors + self.k_k = nn.Parameter(torch.empty(1, 1, C)) + self.k_a = nn.Parameter(torch.empty(1, 1, C)) + self.r_k = nn.Parameter(torch.empty(H, N)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.receptance = nn.Linear(C, C, bias=False) + self.key = nn.Linear(C, C, bias=False) + self.value = nn.Linear(C, 
C, bias=False) + self.output = nn.Linear(C, C, bias=False) + # GroupNorm with very small epsilon for numerical stability + self.ln_x = nn.GroupNorm(H, C, eps=64e-5) + + # -------------------------------------------------------------------------- + def forward(self, x, v_first=None): + B, T, C = x.size() + H = self.n_head + xx = self.time_shift(x) - x # Difference token shift + + # Apply token-shift to each branch + xr = x + xx * self.x_r + xw = x + xx * self.x_w + xk = x + xx * self.x_k + xv = x + xx * self.x_v + xa = x + xx * self.x_a + xg = x + xx * self.x_g + + r = self.receptance(xr) + w = ( + -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 + ) # Clamp + k = self.key(xk) + v = self.value(xv) + + # Value residual: only active on non-first layers + if self.layer_id == 0: + v_first = v + else: + v = v + (v_first - v) * torch.sigmoid( + self.v0 + (xv @ self.v1) @ self.v2 + ) + + a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # In-context LR + g = torch.sigmoid(xg @ self.g1) @ self.g2 # Gate + + # Normalize keys for stability + kk = k * self.k_k + kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C) + k = k * (1 + (a - 1) * self.k_a) + + # Core recurrence + x = RWKV7_OP(r, w, k, v, -kk, kk * a).to(r.dtype) + x = self.ln_x(x.view(B * T, C)).view(B, T, C) + + # Additional local mix (receptance * key * r_k) * value + x = x + ( + (r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum( + dim=-1, keepdim=True + ) + * v.view(B, T, H, -1) + ).view(B, T, C) + x = self.output(x * g) + return x, v_first + + +# ============================================================================== +# RWKV Channel-Mix Layer (Feed-Forward) +# ============================================================================== +class RWKV_CMix_x070(nn.Module): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + with torch.no_grad(): + self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd)) + + self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False) + self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) + + def forward(self, x): + xx = self.time_shift(x) - x + k = x + xx * self.x_k + k = torch.relu(self.key(k)) ** 2 # Squared ReLU + return self.value(k) + + +# ============================================================================== +# RWKV Building Block (Time-Mix + Channel-Mix + Norms) +# ============================================================================== +class Block(nn.Module): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + self.ln0 = nn.LayerNorm(args.n_embd) if layer_id == 0 else None + self.ln1 = nn.LayerNorm(args.n_embd) + self.ln2 = nn.LayerNorm(args.n_embd) + + self.att = RWKV_Tmix_x070(args, layer_id) + self.ffn = RWKV_CMix_x070(args, layer_id) + + def forward(self, x, v_first): + if self.layer_id == 0: + x = self.ln0(x) + xx, v_first = self.att(self.ln1(x), v_first) + x = x + xx + x = x + self.ffn(self.ln2(x)) + return x, v_first + + +# ============================================================================== +# Full RWKV Model +# ============================================================================== +class RWKV(nn.Module): + def __init__(self, args): + super().__init__() + args.dim_att = args.n_embd + args.dim_ffn = args.n_embd * 4 + self.emb = nn.Embedding(args.vocab_size, args.n_embd) + + self.blocks = nn.ModuleList( + [Block(args, i) for i in 
range(args.n_layer)] + ) + self.ln_out = nn.LayerNorm(args.n_embd) + self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False) + + def forward(self, idx): + x = self.emb(idx) + v_first = torch.empty_like(x) + for block in self.blocks: + x, v_first = block(x, v_first) + x = self.ln_out(x) + x = self.head(x) + return x + + +# ============================================================================== +# Weight Conversion Utilities (PyTorch ↔ Keras) +# ============================================================================== +def convert_cmix(my_chnnal_mix, weights, i): + my_chnnal_mix.set_weights( + [ + weights.pop("blocks.%d.ffn.x_k" % i), + weights.pop("blocks.%d.ffn.key.weight" % i).T, + weights.pop("blocks.%d.ffn.value.weight" % i).T, + ] + ) + + +def convert_tmix(my_time_mix, weights, i): + weights_list = [ + weights.pop("blocks.%d.att.x_r" % i), + weights.pop("blocks.%d.att.x_w" % i), + weights.pop("blocks.%d.att.x_k" % i), + weights.pop("blocks.%d.att.x_v" % i), + weights.pop("blocks.%d.att.x_a" % i), + weights.pop("blocks.%d.att.x_g" % i), + weights.pop("blocks.%d.att.w0" % i), + weights.pop("blocks.%d.att.w1" % i), + weights.pop("blocks.%d.att.w2" % i), + weights.pop("blocks.%d.att.a0" % i), + weights.pop("blocks.%d.att.a1" % i), + weights.pop("blocks.%d.att.a2" % i), + weights.pop("blocks.%d.att.v0" % i), + weights.pop("blocks.%d.att.v1" % i), + weights.pop("blocks.%d.att.v2" % i), + weights.pop("blocks.%d.att.g1" % i), + weights.pop("blocks.%d.att.g2" % i), + weights.pop("blocks.%d.att.k_k" % i), + weights.pop("blocks.%d.att.k_a" % i), + weights.pop("blocks.%d.att.r_k" % i), + weights.pop("blocks.%d.att.receptance.weight" % i).T, + weights.pop("blocks.%d.att.key.weight" % i).T, + weights.pop("blocks.%d.att.value.weight" % i).T, + weights.pop("blocks.%d.att.output.weight" % i).T, + weights.pop("blocks.%d.att.ln_x.weight" % i), + weights.pop("blocks.%d.att.ln_x.bias" % i), + ] + my_time_mix.set_weights(weights_list) + + +def convert_layernorm(myln, weights, ln_id, layer_id): + myln.set_weights( + [ + weights.pop("blocks.%d.ln%d.weight" % (layer_id, ln_id)), + weights.pop("blocks.%d.ln%d.bias" % (layer_id, ln_id)), + ] + ) + + +def convert_block(my_block, weights, i): + convert_cmix(my_block.ffn, weights, i) + convert_tmix(my_block.att, weights, i) + if my_block.use_initial_norm: + convert_layernorm(my_block.ln0, weights, 0, i) + convert_layernorm(my_block.ln1, weights, 1, i) + convert_layernorm(my_block.ln2, weights, 2, i) + + +def convert_backbone(my_backbone, standard_RWKV): + for i in range(my_backbone.num_layers): + convert_block(my_backbone.rwkv_layers[i], standard_RWKV.blocks[i]) + my_backbone.token_embedding.set_weights( + [standard_RWKV.emb.weight.detach().cpu()] + ) + convert_layernorm(my_backbone.output_layer_norm, standard_RWKV.ln_out) + + +# ============================================================================== +# Checkpoint Conversion Entry Point +# ============================================================================== +def convert_rwkv7_checkpoints(weights_path): + weights = torch.load(weights_path, map_location="cpu") + weights = {k: v.float().numpy() for k, v in weights.items()} + w = weights + n_layer = 0 + for k in w.keys(): + layer_id = int(k.split(".")[1]) if ("blocks." 
in k) else 0 + n_layer = max(n_layer, layer_id + 1) + + config = { + "hidden_size": w["emb.weight"].shape[1], + "num_layers": n_layer, + "intermediate_dim": w["blocks.0.ffn.key.weight"].shape[0], + "vocabulary_size": 65536, + "head_size": 64, + } + my_backbone = RWKV7Backbone(**config) + + # Copy layer-1 value-residual params to layer-0 (compatibility) + weights["blocks.0.att.v0"] = weights["blocks.1.att.v0"] + weights["blocks.0.att.v1"] = weights["blocks.1.att.v1"] + weights["blocks.0.att.v2"] = weights["blocks.1.att.v2"] + + my_backbone.get_layer("token_embedding").set_weights( + [weights.pop("emb.weight")] + ) + for i in range(config["num_layers"]): + my_block = my_backbone.get_layer(f"rwkv_layer_{i}") + convert_block(my_block, weights, i) + + my_backbone.output_layer_norm.set_weights( + [ + weights.pop("ln_out.weight"), + weights.pop("ln_out.bias"), + ] + ) + model = RWKV7CausalLM(my_backbone) + my_backbone.head.set_weights([weights.pop("head.weight").T]) + return model + + +# ============================================================================== +# Main Script +# ============================================================================== +url = "https://raw.githubusercontent.com/BlinkDL/RWKV-LM/main/RWKV-v7/rwkv_vocab_v20230424.txt" + + +def main(_): + if not os.path.exists(FLAGS.preset): + os.makedirs(FLAGS.preset) + + souce_model_name = PRESET_MAP[FLAGS.preset] + # Download vocabulary file + + vocabs = requests.get(url, timeout=30).text + with open( + os.path.join(FLAGS.preset, "vocab.txt"), "w", encoding="utf-8" + ) as f: + f.write(vocabs) + tokenizer = RWKVTokenizer() + tokenizer.load_assets(FLAGS.preset) + + # Download checkpoint + download_path = snapshot_download( + repo_id="RWKV/rwkv7-g1", + allow_patterns=souce_model_name, + ) + weights_path = os.path.join(download_path, souce_model_name) + + # Convert to Keras format + my_model = convert_rwkv7_checkpoints(weights_path) + + # Re-build PyTorch reference model + args = types.SimpleNamespace() + args.n_layer = my_model.backbone.num_layers + args.n_embd = my_model.backbone.hidden_size + args.vocab_size = my_model.backbone.vocabulary_size + args.head_size_a = 64 + args.dim_att = args.n_embd + args.dim_ffn = my_model.backbone.intermediate_dim + + if os.environ["CUDA_VISIBLE_DEVICES"] != "-1": + standard_model = RWKV(args).cuda() + else: + standard_model = RWKV(args) + + weights = torch.load(weights_path, map_location="cpu") + # Some parameters are not present in the weights, but this does not matter. + # This is because these parameters are not used + standard_model.load_state_dict(weights, strict=False) + + # Sanity check: tokenize & compare outputs + x = tokenizer(["i love u"]) + x = np.reshape(x, [1, -1]) + my_output = my_model(ops.convert_to_tensor(x, "int32")) + xx = torch.from_numpy(x).int() + if torch.cuda.is_available(): + xx = xx.cuda() + standard_output = standard_model(xx) + + standard_output = standard_output.cpu().float().detach().numpy() + my_output = ops.convert_to_numpy(ops.cast(my_output, "float32")) + + try: + np.testing.assert_allclose(my_output, standard_output, atol=1e-4) + print("Successfully passed the numerical verification! 
🎯✅📊") + except AssertionError as err: + print("\n") + print(err.args[0]) + print("\n") + + # Export final Keras model + my_model.backbone.save_to_preset(f"./{FLAGS.preset}") + + +# ============================================================================== +# Entry Guard +# ============================================================================== +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)
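
Editor's note, a minimal usage sketch (not part of the patch) for the new tokenizer, using the same `"id 'token' length"` vocabulary line format exercised in `rwkv7_tokenizer_test.py`. It assumes keras_hub is installed from this branch; the import path mirrors the one used in the test file.

```python
# Illustrative only: toy vocabulary in the "id 'token' byte_length" format.
from keras_hub.src.models.rwkv7.rwkv7_tokenizer import RWKVTokenizer

vocab = ["1 ' ' 1", "2 '\\n' 1", "3 'the' 3", "4 'hello' 5", "5 'world' 5"]
tokenizer = RWKVTokenizer(vocabulary=vocab)

token_ids = tokenizer("hello world")       # -> [4, 1, 5], per the unit test
text = tokenizer.detokenize([[4, 1, 5]])   # -> ["hello world"]; pad id 0 is stripped
print(token_ids, text)
```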
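
For reviewers: the reference `RWKV7_OP` loops over time, updating a per-head `(N, N)` state with a per-channel decay, a delta-rule correction term, and a fresh key/value outer product, then reads the state out with the receptance vector. The sketch below is my own single-head NumPy re-statement of that recurrence, for illustration only; in the layer, `a` is passed as `-kk` and `b` as `kk * a` (with `kk` the L2-normalized key), which is what gives the middle term its forget/update behaviour.

```python
import numpy as np

def rwkv7_op_single_head(r, w, k, v, a, b):
    """Single-head NumPy re-statement of the RWKV7_OP recurrence.

    All arguments have shape (T, N). `w` is the raw decay parameter; the
    operator maps it to a per-channel decay in (0, 1) via exp(-exp(w)).
    """
    T, N = r.shape
    decay = np.exp(-np.exp(w))           # (T, N) decay per channel per step
    state = np.zeros((N, N))
    out = np.zeros((T, N))
    for t in range(T):
        kk = k[t][None, :]               # (1, N)
        rr = r[t][:, None]               # (N, 1)
        vv = v[t][:, None]               # (N, 1)
        aa = a[t][:, None]               # (N, 1)
        bb = b[t][None, :]               # (1, N)
        # Column-wise decay, delta-rule correction, new key/value outer product.
        state = state * decay[t][None, :] + state @ aa @ bb + vv @ kk
        out[t] = (state @ rr)[:, 0]      # read-out with the receptance vector
    return out
```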
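
The channel-mix block (`RWKV_CMix_x070`) is a two-layer feed-forward with a one-step token shift and a squared ReLU; `nn.ZeroPad2d((0, 0, 1, -1))` simply shifts the sequence so position `t` mixes in position `t - 1`, with zeros at the first step. A minimal NumPy sketch of that forward pass, written with kernels in the `(in, out)` orientation rather than PyTorch's, purely for illustration:

```python
import numpy as np

def rwkv7_channel_mix(x, x_k, key_w, value_w):
    """Illustrative channel-mix forward for one sequence.

    x: (T, C) activations; x_k: (1, C) shift-mix vector;
    key_w: (C, F) and value_w: (F, C) kernels, (in, out) orientation.
    """
    # One-step token shift: position t sees position t - 1, zeros at t = 0.
    shifted = np.vstack([np.zeros((1, x.shape[1])), x[:-1]])
    k = x + (shifted - x) * x_k            # blend current and previous token
    k = np.maximum(k @ key_w, 0.0) ** 2    # squared ReLU
    return k @ value_w
```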
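
A note on the `.T` calls scattered through `convert_tmix`, `convert_cmix`, and the final `head.weight` assignment: `torch.nn.Linear` stores its kernel as `(out_features, in_features)`, while a Keras `Dense` kernel is `(in_features, out_features)`, so every linear weight must be transposed during conversion. A small sanity check under that assumption (hypothetical shapes, requires both torch and keras):

```python
import numpy as np
import torch
import keras
from keras import ops

# Hypothetical 4 -> 3 projection, just to illustrate the layout difference.
torch_linear = torch.nn.Linear(4, 3, bias=False)
dense = keras.layers.Dense(3, use_bias=False)
dense.build((None, 4))
dense.set_weights([torch_linear.weight.detach().numpy().T])  # transpose to (in, out)

x = np.random.rand(2, 4).astype("float32")
np.testing.assert_allclose(
    ops.convert_to_numpy(dense(x)),
    torch_linear(torch.from_numpy(x)).detach().numpy(),
    atol=1e-5,
)
```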
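
The converter is driven by the required `--preset` flag (one of the `PRESET_MAP` keys), e.g. `python tools/checkpoint_conversion/convert_rwkv7_checkpoints.py --preset RWKV7_G1a_0.1B`; it downloads the vocabulary and checkpoint, runs the numerical comparison against the PyTorch reference, and exports the backbone to `./<preset>` via `save_to_preset`. Assuming the standard keras_hub preset API accepts a local directory path, the exported weights should then reload with something like the sketch below (the directory name is whatever was passed as `--preset`):

```python
# Hedged sketch: reload the locally exported preset written by the script.
from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone

backbone = RWKV7Backbone.from_preset("./RWKV7_G1a_0.1B")
```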