From fe5060ebabd5242e894c0d9772befc247926d197 Mon Sep 17 00:00:00 2001
From: Rohit-Andhavarapu
Date: Fri, 30 Jan 2026 16:12:43 +0000
Subject: [PATCH 1/3] feat: streaming generation

---
 pyproject.toml           |  11 ++
 simply/model_lib.py      | 358 +++++++++++++++++++++++++++++++++++++++
 simply/model_lib_test.py |  48 ++++++
 3 files changed, 417 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 74d65d8..82eddeb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,10 +7,21 @@ requires-python = ">=3.12"
 dependencies = [
     "absl-py>=2.3.1",
     "clu>=0.0.12",
+    "datasets>=4.5.0",
+    "deprecated>=1.3.1",
     "einops>=0.8.1",
     "grain>=0.2.15",
     "orbax>=0.1.9",
+    "orbax-checkpoint>=0.11.32",
+    "pylatexenc>=2.10",
+    "pytest>=9.0.2",
+    "sentencepiece>=0.2.1",
+    "sympy>=1.14.0",
     "tensorboard>=2.20.0",
+    "tensorboardx>=2.6.4",
+    "tensorflow-cpu>=2.20.0",
+    "tensorflow-datasets>=4.9.9",
+    "tokenizers>=0.22.2",
 ]

diff --git a/simply/model_lib.py b/simply/model_lib.py
index 30f5f05..7e0e92c 100644
--- a/simply/model_lib.py
+++ b/simply/model_lib.py
@@ -3235,6 +3235,35 @@ def avg_output_score(self) -> float:
     return np.mean(self.output_token_scores).item()
 
 
+@sampling_lib.SamplingRegistry.register
+@dataclasses.dataclass(frozen=True)
+class StreamingToken:
+  """A single token emitted during streaming generation.
+
+  This class represents one token in a streaming generation sequence,
+  yielded incrementally as the model generates output.
+
+  Attributes:
+    token_id: The integer ID of the generated token.
+    token_text: The decoded text representation of the token.
+    token_logprob: The log probability of this token under the sampling
+      distribution.
+    token_score: The log probability score (unaffected by sampling params).
+    position: The position of this token in the full sequence (input + output).
+    is_input: Whether this token is part of the input (vs generated output).
+    is_eos: Whether this token is an end-of-sequence token.
+    is_final: Whether this is the last token in the generation.
+  """
+  token_id: int
+  token_text: str
+  token_logprob: float
+  token_score: float
+  position: int
+  is_input: bool
+  is_eos: bool
+  is_final: bool
+
+
 @dataclasses.dataclass(frozen=True)
 class ScoringParams:
   temperature: float = 1.0
@@ -3657,6 +3686,230 @@ def generate(
     return sample_outputs
 
+  def generate_stream(
+      self,
+      input_text: sampling_lib.SamplingInput,
+      prng_key: int | PRNGKey | None = None,
+      params: PyTree = None,
+      prefill_size: int = -1,
+      sampling_params: SamplingParams | None = None,
+      scoring_params: ScoringParams | None = None,
+      include_eos_in_output: bool = True,
+  ) -> 'collections.abc.Iterator[StreamingToken]':
+    """Generate samples from input text, yielding tokens as they are generated.
+
+    This is the streaming version of `generate()`. Instead of waiting for the
+    full generation to complete, this method yields `StreamingToken` instances
+    one at a time as the model generates output. This enables real-time
+    streaming for interactive applications.
+
+    Note: This method only supports single input (not batched), and
+    num_samples=1. For batched generation, use `generate()` instead.
+
+    Args:
+      input_text: Single input to generate samples for. Can be either a string
+        or a sequence of Chunks.
+      prng_key: A PRNGKey or seed for controlling randomness.
+      params: Parameters of the model. If None, uses default parameters.
+      prefill_size: Prefill size for generation. If non-positive, inferred
+        from sampling params.
+      sampling_params: Sampling params to use for the generation.
+      scoring_params: Scoring params to score the generated output.
+      include_eos_in_output: Whether to yield the EOS token when encountered.
+
+    Yields:
+      StreamingToken instances containing token information as generated.
+
+    Example:
+      ```python
+      lm = LMInterface(model, params, vocab=vocab)
+      for token in lm.generate_stream("Once upon a time"):
+        print(token.token_text, end="", flush=True)
+        if token.is_final:
+          print()  # newline at end
+      ```
+    """
+    if params is None:
+      params = self.model_params
+
+    if prng_key is None:
+      seed = int(time.time() * 1000)
+      seed = jax.experimental.multihost_utils.broadcast_one_to_all(seed)
+      prng_key = jax.random.key(seed=seed)
+    elif isinstance(prng_key, int):
+      prng_key = jax.random.key(seed=prng_key)
+
+    if sampling_params is None:
+      sampling_params = self.default_sampling_params
+    if prefill_size > 0:
+      sampling_params = dataclasses.replace(
+          sampling_params, prefill_size=prefill_size
+      )
+
+    if scoring_params is None:
+      scoring_params = ScoringParams.from_sampling_params(sampling_params)
+
+    # Process input - streaming only supports single input
+    raw_input = sampling_lib.input_as_chunks(input_text)
+    unpadded_input = self.input_processor.encode(
+        raw_input, max_input_len=sampling_params.max_input_len
+    )
+    processed_input = sampling_lib.ProcessedInputBatch.from_unpadded_inputs(
+        [unpadded_input], pad_id=self.input_processor.pad_id
+    )
+
+    decoding_schedule = sampling_params.get_decoding_schedule(
+        min_input_length=processed_input.min_length,
+        max_input_length=processed_input.max_length,
+    )
+
+    processed_input = processed_input.pad_to(
+        1
+        + max(
+            decoding_schedule.get_next_length(processed_input.max_length - 1),
+            decoding_schedule.prefill_size,
+        )
+    )
+
+    # Prefill phase
+    position = decoding_schedule.begin_position
+    logits, extra_output = self.prefill_fn(
+        params,
+        processed_input.token_slice(0, decoding_schedule.prefill_size),
+        extra_inputs=processed_input.extra_inputs,
+        position=position,
+        return_logits=True,
+    )
+
+    logits = sharding_lib.with_sharding_constraint(
+        logits, (('replica', 'data'), 'model', None)
+    )
+    token_scores = sampling_lib.compute_log_likelihood(
+        logits,
+        processed_input.token_slice(1, decoding_schedule.prefill_size + 1),
+        temperature=scoring_params.temperature,
+        top_k=scoring_params.top_k,
+        top_p=scoring_params.top_p,
+    )
+    del logits
+
+    # Add dummy score for BOS token
+    token_scores = pad_along_axis(token_scores, (1, 0), axis=1)
+    token_logprobs = jnp.zeros_like(token_scores)
+
+    # Initialize sampling state
+    sampling_state = SamplingState(
+        prng_key=jnp.copy(prng_key),
+        position=jnp.array(position),
+        decode_state=extra_output['decode_state'],
+        tokens=processed_input.tokens,
+        token_logprobs=token_logprobs,
+        token_scores=token_scores,
+        input_lens=jnp.reshape(processed_input.lengths, [-1, 1]),
+        max_decode_steps=einops.repeat(
+            jnp.array(sampling_params.max_decode_steps),
+            '-> b 1',
+            b=processed_input.batch_size,
+        ),
+        eos_ids=jnp.array(self.input_processor.eos_ids, dtype=jnp.int32),
+    )
+
+    input_length = int(processed_input.lengths[0])
+    eos_ids_set = set(self.input_processor.eos_ids)
+
+    # JIT compile the single-step decode function
+    decode_step_fn = jax.jit(
+        common.named_partial_fn(
+            decode_one_step,
+            'decode_step_fn',
+            apply_fn=self.model.apply,
+        ),
+    )
+
+    # Decode loop - yield tokens as we go
+    num_output_tokens = 0
+    max_output_tokens = min(
+        sampling_params.max_decode_steps,
+        sampling_params.max_seq_len - input_length,
+    )
+
+    while position < decoding_schedule.end_position:
+      # Pad state if needed
+      sampling_state = self.pad_state_to_fn(
+          sampling_state, length=decoding_schedule.get_next_length(position)
+      )
+
+      # Single decode step
+      sampling_state, output_token, output_logprob, output_score = (
+          decode_step_fn(
+              params=params,
+              sampling_state=sampling_state,
+              extra_inputs=processed_input.extra_inputs,
+              temperature=sampling_params.temperature,
+              top_k=sampling_params.top_k,
+              top_p=sampling_params.top_p,
+              scoring_temperature=scoring_params.temperature,
+              scoring_top_k=scoring_params.top_k,
+              scoring_top_p=scoring_params.top_p,
+          )
+      )
+
+      position = int(jax.device_get(sampling_state.position))
+      current_position = position - 1  # position was incremented
+
+      # Check if we're in output region (past input)
+      if current_position >= input_length:
+        token_id = int(jax.device_get(output_token[0, 0]))
+        token_logprob = float(jax.device_get(output_logprob[0, 0]))
+        token_score_val = float(jax.device_get(output_score[0, 0]))
+
+        is_eos = token_id in eos_ids_set
+        num_output_tokens += 1
+        is_final = is_eos or num_output_tokens >= max_output_tokens
+
+        # Decode token to text
+        try:
+          token_text = self.input_processor.decode([token_id])
+          if token_text:
+            token_text = sampling_lib.chunks_as_text(token_text)
+          else:
+            token_text = ''
+        except Exception:
+          token_text = ''
+
+        # Skip EOS if not including it
+        if is_eos and not include_eos_in_output:
+          # Still yield final marker but with empty text
+          yield StreamingToken(
+              token_id=token_id,
+              token_text='',
+              token_logprob=token_logprob,
+              token_score=token_score_val,
+              position=current_position,
+              is_input=False,
+              is_eos=True,
+              is_final=True,
+          )
+          return
+
+        yield StreamingToken(
+            token_id=token_id,
+            token_text=token_text,
+            token_logprob=token_logprob,
+            token_score=token_score_val,
+            position=current_position,
+            is_input=False,
+            is_eos=is_eos,
+            is_final=is_final,
+        )
+
+        if is_final:
+          return
+
+      # Check if all sequences have ended
+      if jax.device_get(sampling_state.all_has_ended):
+        return
+
   def score(
       self,
       input_text: sampling_lib.SamplingInput,
@@ -3905,6 +4158,111 @@ def cond_fn(sampling_state: SamplingState) -> jax.typing.ArrayLike:
   return final_sampling_state
 
 
+
+def decode_one_step(
+    apply_fn: Callable[..., Array],
+    params: PyTree,
+    sampling_state: SamplingState,
+    extra_inputs: Mapping[str, PyTree] | None = None,
+    temperature: float = 1.0,
+    top_k: int = -1,
+    top_p: float = 1.0,
+    scoring_temperature: float = 1.0,
+    scoring_top_k: int = -1,
+    scoring_top_p: float = 1.0,
+) -> tuple[SamplingState, Array, Array, Array]:
+  """Performs a single decode step for streaming generation.
+
+  This function executes exactly one forward pass of the model and samples
+  a single token. Unlike `continue_decode`, this function does not use
+  `jax.lax.while_loop`, allowing Python-level control between decode steps
+  for streaming output.
+
+  Args:
+    apply_fn: The model's apply function.
+    params: The model parameters.
+    sampling_state: The current sampling state.
+    extra_inputs: Optional extra inputs for the model.
+    temperature: Sampling temperature.
+    top_k: Top-k sampling parameter (-1 for disabled).
+    top_p: Top-p (nucleus) sampling parameter.
+    scoring_temperature: Temperature for scoring.
+    scoring_top_k: Top-k for scoring.
+    scoring_top_p: Top-p for scoring.
+
+  Returns:
+    A tuple of (new_sampling_state, output_token, output_logprob, output_score)
+    where output_token is the sampled token ID [batch, 1], output_logprob is
+    the log probability under the sampling distribution, and output_score is
+    the raw log probability.
+  """
+  # Forward pass: logits has shape [batch_size, 1, vocab_size]
+  logits, extra_output = apply_fn(
+      params,
+      sampling_state.input_tokens,
+      segment_positions=einops.repeat(
+          sampling_state.position, '-> b 1', b=sampling_state.batch_size
+      ),
+      extra_inputs=extra_inputs,
+      decode_state=sampling_state.decode_state,
+  )
+
+  # Sample from logits
+  prng_key, key = jax.random.split(sampling_state.prng_key, 2)
+  output_tokens, output_logprobs = sampling_lib.sample_from_logits(
+      key, logits, temperature=temperature, top_k=top_k, top_p=top_p
+  )
+
+  # Handle three cases: ended, output position, or input position
+  final_output_tokens = jnp.select(
+      [
+          sampling_state.has_ended,
+          sampling_state.next_position_is_output,
+      ],
+      [
+          sampling_state.input_tokens,
+          output_tokens,
+      ],
+      default=sampling_state.next_tokens,
+  )
+
+  # Compute scores
+  def _score_fn(logits: Array, tokens: Array) -> Array:
+    return sampling_lib.compute_log_likelihood(
+        logits,
+        tokens,
+        temperature=scoring_temperature,
+        top_k=scoring_top_k,
+        top_p=scoring_top_p,
+    )
+
+  scoring_follows_sampling = (
+      (scoring_temperature == temperature)
+      & (scoring_top_k == top_k)
+      & (scoring_top_p == top_p)
+  )
+  output_scores = jax.lax.cond(
+      scoring_follows_sampling
+      & jnp.all(sampling_state.next_position_is_output),
+      lambda *_: output_logprobs,
+      _score_fn,
+      logits,
+      final_output_tokens,
+  )
+
+  # Build new sampling state
+  new_sampling_state = dataclasses.replace(
+      sampling_state,
+      prng_key=prng_key,
+      position=sampling_state.position + 1,
+      decode_state=extra_output['decode_state'],
+      tokens=sampling_state.updated_tokens(final_output_tokens),
+      token_logprobs=sampling_state.updated_token_logprobs(output_logprobs),
+      token_scores=sampling_state.updated_token_scores(output_scores),
+  )
+
+  return new_sampling_state, final_output_tokens, output_logprobs, output_scores
+
+
 ################################################################################
 # Utilities
 
diff --git a/simply/model_lib_test.py b/simply/model_lib_test.py
index ebeee60..d99c6ba 100644
--- a/simply/model_lib_test.py
+++ b/simply/model_lib_test.py
@@ -634,6 +634,54 @@ def test_lm_interface_generate(self):
       self.assertLen(so.input_token_ids, 4)
       self.assertLen(so.input_token_scores, 3)
 
+  def test_lm_interface_generate_stream(self):
+    """Test streaming generation produces consistent results with generate()."""
+    vocab = tokenization.TestVocab([str(i) for i in range(60)])
+    prng_key = jax.random.key(0)
+    params = self.tfm_lm.init(prng_key)
+    max_decode_steps = 10
+    sampling_params = model_lib.SamplingParams(
+        top_k=-1, top_p=1.0, temperature=0.0,  # Use temp=0 for determinism
+        max_decode_steps=max_decode_steps)
+    lm_interface = model_lib.LMInterface(self.tfm_lm, params, vocab)
+
+    # Get streaming output
+    streaming_tokens = []
+    for token in lm_interface.generate_stream(
+        input_text='1 2 3',
+        prng_key=jax.random.key(seed=25),
+        sampling_params=sampling_params,
+    ):
+      streaming_tokens.append(token)
+      # Verify token structure
+      self.assertIsInstance(token.token_id, int)
+      self.assertIsInstance(token.token_text, str)
+      self.assertIsInstance(token.token_logprob, float)
+      self.assertIsInstance(token.token_score, float)
+      self.assertIsInstance(token.position, int)
+      self.assertFalse(token.is_input)  # All yielded tokens are output
+
+    # Should have output tokens
+    self.assertGreater(len(streaming_tokens), 0)
+
+    # Check is_final is set correctly on last token
+    if streaming_tokens:
+      self.assertTrue(streaming_tokens[-1].is_final)
+
+    # Verify token IDs match non-streaming generation
+    outputs = lm_interface.generate(
+        input_text='1 2 3',
+        prng_key=jax.random.key(seed=25),
+        sampling_params=sampling_params,
+    )
+    outputs = cast(list[model_lib.SamplingOutput], outputs)
+    streaming_token_ids = [t.token_id for t in streaming_tokens]
+
+    self.assertEqual(
+        streaming_token_ids[:len(outputs[0].output_token_ids)],
+        outputs[0].output_token_ids[:len(streaming_token_ids)],
+    )
+
   def test_lm_interface_batch(self):
     vocab = tokenization.TestVocab([str(i) for i in range(20)])
     prng_key = jax.random.key(0)

From 758997a0bb7737a3f800a57372f1583d6dd8c9d2 Mon Sep 17 00:00:00 2001
From: Rohit-Andhavarapu
Date: Fri, 30 Jan 2026 16:20:33 +0000
Subject: [PATCH 2/3] fix: edge case for end condition of EOS

---
 simply/model_lib.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/simply/model_lib.py b/simply/model_lib.py
index 7e0e92c..66efdfd 100644
--- a/simply/model_lib.py
+++ b/simply/model_lib.py
@@ -3865,7 +3865,9 @@ def generate_stream(
 
         is_eos = token_id in eos_ids_set
         num_output_tokens += 1
-        is_final = is_eos or num_output_tokens >= max_output_tokens
+        # Mark as final if: eos, max tokens reached, or loop will exit
+        will_loop_exit = position >= decoding_schedule.end_position
+        is_final = is_eos or num_output_tokens >= max_output_tokens or will_loop_exit
 
         # Decode token to text
         try:

From c610cf2c854d2b7348c44d9e858fe773916ee4ca Mon Sep 17 00:00:00 2001
From: Rohit-Andhavarapu
Date: Thu, 5 Feb 2026 15:54:52 +0530
Subject: [PATCH 3/3] cla rechecking

---
 simply/data_lib.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/simply/data_lib.py b/simply/data_lib.py
index b650c4e..2e60b9c 100644
--- a/simply/data_lib.py
+++ b/simply/data_lib.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 """Data pipeline for LLM training based on grain.
 
 Usage: