diff --git a/src/modalities/inference/text/config.py b/src/modalities/inference/text/config.py index 3f5f2e5a2..8b1310fe5 100644 --- a/src/modalities/inference/text/config.py +++ b/src/modalities/inference/text/config.py @@ -18,6 +18,9 @@ class TextInferenceComponentConfig(BaseModel): temperature: Optional[float] = 1.0 eod_token: Optional[str] = "" device: PydanticPytorchDeviceType + system_prompt_path: Optional[str] = "" + chat_template: str + prompt_template: str @field_validator("device", mode="before") def parse_device(cls, device) -> PydanticPytorchDeviceType: diff --git a/src/modalities/inference/text/inference_component.py b/src/modalities/inference/text/inference_component.py index 939ccadc0..1bfd69531 100644 --- a/src/modalities/inference/text/inference_component.py +++ b/src/modalities/inference/text/inference_component.py @@ -13,6 +13,8 @@ def __init__( self, model: nn.Module, tokenizer: TokenizerWrapper, + system_prompt_path: str, + chat_template: str, prompt_template: str, sequence_length: int, temperature: float, @@ -24,10 +26,28 @@ def __init__( self.model.eval() self.tokenizer = tokenizer self.eod_token = eod_token + self.chat_template = chat_template self.prompt_template = prompt_template self.temperature = temperature self.sequence_length = sequence_length self.device = device + self.system_prompt = self._load_system_prompt(system_prompt_path) + + def _load_system_prompt(self, system_prompt_path: str) -> str: + if not system_prompt_path: + print("ℹ️ No system prompt file specified") + return "" + try: + with open(system_prompt_path, "r", encoding="utf-8") as f: + content = f.read().strip() + print(f"✅ Loaded system prompt from: {system_prompt_path}") + return content + except FileNotFoundError: + print(f"⚠️ System prompt file not found: {system_prompt_path}, using empty prompt") + return "" + except Exception as e: + print(f"❌ Error loading system prompt: {e}, using empty prompt") + return "" def generate_tokens( self, @@ -38,11 +58,15 @@ def generate_tokens( input_token_ids = torch.IntTensor(token_ids_list).to(self.device).unsqueeze(0) input_dict = {"input_ids": input_token_ids} - print("--------------------PROMPT--------------------") + print("\n" + "=" * 60) + print("🤖 PROMPT") + print("=" * 60) context_decoded = self.tokenizer.decode(token_ids_list) - print("Prompt: ", context_decoded, end="") + print(context_decoded) - print("\n\n--------------------OUTPUT--------------------\n") + print("\n" + "=" * 60) + print("💬 RESPONSE") + print("=" * 60) generated_token_ids = [] generated_text_old = "" for _ in range(max_new_tokens): @@ -61,7 +85,8 @@ def generate_tokens( generated_text_new = self.tokenizer.decode(generated_token_ids) if idx_next_str == self.eod_token: - print("\n", end="") + print("\n\n" + "─" * 40) + print("✅ Reached end of document token") break else: diff_text = generated_text_new[len(generated_text_old) :] @@ -71,14 +96,51 @@ def generate_tokens( token_ids_list.append(token_id) input_token_ids = torch.IntTensor(token_ids_list).to(self.device).unsqueeze(0) input_dict = {"input_ids": input_token_ids} - print("\n max tokens reached", end="") + else: + print("\n\n" + "─" * 40) + print("⚠️ Maximum tokens reached") def run(self): - prompt = TextInferenceComponent._get_prompt(self.prompt_template) - try: - self.generate_tokens(context=prompt) - except KeyboardInterrupt: - print("closing app...") + print("\n" + "🚀 Modalities Chat Interface ".center(60, "=")) + print("=" * 60) + + while True: + try: + user_prompt = self._get_prompt(self.prompt_template) + 
full_prompt = self.chat_template.format(system_prompt=self.system_prompt, user_prompt=user_prompt)
+
+                temp_input = input("\n🌡️ Enter temperatures (comma-separated) or press Enter for default [0.8]: ")
+
+                if not temp_input.strip():
+                    temperatures = [0.8]
+                    print("Using default temperature: 0.8")
+                else:
+                    try:
+                        temperatures = [float(t.strip()) for t in temp_input.split(",")]
+                        if not temperatures:
+                            raise ValueError("No temperatures provided.")
+                    except ValueError:
+                        print("\n❌ Invalid input. Please enter comma-separated numbers or press Enter for default.\n")
+                        continue
+
+                for i, temp in enumerate(temperatures):
+                    if len(temperatures) > 1:
+                        print("\n\n" + f"🎯 GENERATION {i + 1} (Temperature: {temp})".center(60, "="))
+                    else:
+                        print("\n\n" + f"🎯 GENERATING (Temperature: {temp})".center(60, "="))
+                    self.temperature = temp
+                    try:
+                        self.generate_tokens(context=full_prompt)
+                    except KeyboardInterrupt:
+                        print("\n\n👋 Generation interrupted by user.")
+                        continue
+
+                print("\n\n" + "🏁 ALL GENERATIONS COMPLETE".center(60, "="))
+                print("=" * 60)
+
+            except KeyboardInterrupt:
+                print("\n\n👋 Closing app... Goodbye!")
+                break
 
     @staticmethod
     def _get_prompt(template: str) -> str:
diff --git a/tests/test_inference.py b/tests/test_inference.py
new file mode 100644
index 000000000..d6f0079e0
--- /dev/null
+++ b/tests/test_inference.py
@@ -0,0 +1,144 @@
+import tempfile
+from io import StringIO
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from modalities.config.config import load_app_config_dict
+from modalities.inference.text.inference_component import TextInferenceComponent
+from modalities.models.utils import ModelTypeEnum, get_model_from_config
+from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer
+
+
+@pytest.fixture
+def gpt2_model_and_tokenizer():
+    config_file_path = Path("tests/test_yaml_configs/gpt2_config_optimizer.yaml")
+    config_dict = load_app_config_dict(config_file_path=config_file_path)
+    model = get_model_from_config(config=config_dict, model_type=ModelTypeEnum.MODEL)
+    tokenizer_path = Path("data/tokenizer/hf_gpt2")
+    tokenizer = PreTrainedHFTokenizer(
+        pretrained_model_name_or_path=tokenizer_path, max_length=None, truncation=None, padding=False
+    )
+    return model, tokenizer
+
+
+@pytest.fixture
+def temp_system_prompt_file():
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
+        f.write("You are a helpful AI assistant.")
+        temp_path = f.name
+    yield temp_path
+    Path(temp_path).unlink()
+
+
+class TestTextInferenceComponent:
+    # Test actual inference with real model and tokenizer
+    def test_actual_inference_greedy_decoding(self, gpt2_model_and_tokenizer):
+        """Test greedy decoding with real model produces deterministic output."""
+        model, tokenizer = gpt2_model_and_tokenizer
+
+        component = TextInferenceComponent(
+            model=model,
+            tokenizer=tokenizer,
+            system_prompt_path="",
+            chat_template="{user_prompt}",
+            prompt_template="{text}",
+            sequence_length=20,
+            temperature=0.0,  # Greedy decoding
+            eod_token="<|endoftext|>",
+            device=torch.device("cpu"),
+        )
+
+        # Run inference twice with same input
+        outputs = []
+        for _ in range(2):
+            with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
+                component.generate_tokens("The weather today is")
+            outputs.append(mock_stdout.getvalue())
+        assert outputs[0] == outputs[1], "Greedy decoding should produce deterministic outputs"
+
+    def test_actual_inference_with_different_temperatures(self, gpt2_model_and_tokenizer):
+ """Test that different temperatures produce different outputs.""" + model, tokenizer = gpt2_model_and_tokenizer + + def run_inference_with_temperature(temp): + component = TextInferenceComponent( + model=model, + tokenizer=tokenizer, + system_prompt_path="", + chat_template="{user_prompt}", + prompt_template="{text}", + sequence_length=15, + temperature=temp, + eod_token="<|endoftext|>", + device=torch.device("cpu"), + ) + + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + component.generate_tokens("Hello world") + return mock_stdout.getvalue() + + # Should produce different outputs + torch.manual_seed(42) + output_greedy = run_inference_with_temperature(0.0) + torch.manual_seed(42) + output_sampling = run_inference_with_temperature(1.0) + + assert output_greedy != output_sampling, "Greedy and sampling outputs should be different" + + def test_run_method_multiple_temperatures(self, gpt2_model_and_tokenizer, temp_system_prompt_file): + """Test the run() method with multiple temperature inputs.""" + model, tokenizer = gpt2_model_and_tokenizer + + component = TextInferenceComponent( + model=model, + tokenizer=tokenizer, + system_prompt_path=temp_system_prompt_file, + chat_template="System: {system_prompt}\nUser: {user_prompt}\nAssistant:", + prompt_template="What is {topic}?", + sequence_length=20, + temperature=0.7, + eod_token="<|endoftext|>", + device=torch.device("cpu"), + ) + mock_inputs = ["science", "0.3, 0.8, 1.2", KeyboardInterrupt()] + with patch("builtins.input", side_effect=mock_inputs): + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + component.run() + + output = mock_stdout.getvalue() + + # Verify multiple generations occurred + assert "(Temperature: 0.3)" in output + assert "(Temperature: 0.8)" in output + assert "(Temperature: 1.2)" in output + assert "🏁 ALL GENERATIONS COMPLETE" in output + + def test_run_method_default_temperature(self, gpt2_model_and_tokenizer, temp_system_prompt_file): + """Test the run() method with default temperature (empty input).""" + model, tokenizer = gpt2_model_and_tokenizer + + component = TextInferenceComponent( + model=model, + tokenizer=tokenizer, + system_prompt_path=temp_system_prompt_file, + chat_template="System: {system_prompt}\nUser: {user_prompt}\nAssistant:", + prompt_template="What is {topic}?", + sequence_length=20, + temperature=0.7, + eod_token="<|endoftext|>", + device=torch.device("cpu"), + ) + + mock_inputs = ["machine learning", "", KeyboardInterrupt()] + + with patch("builtins.input", side_effect=mock_inputs): + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + component.run() + + output = mock_stdout.getvalue() + + # Verify default temperature was used + assert "Using default temperature: 0.8" in output
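
Note for reviewers: below is a minimal driver sketch for trying the new interactive flow outside of pytest. It is assembled from the test fixtures above (same config path, tokenizer path, chat/prompt templates, and constructor arguments); the sequence length and temperature values are illustrative, and the script is not part of this diff.

```python
from pathlib import Path

import torch

from modalities.config.config import load_app_config_dict
from modalities.inference.text.inference_component import TextInferenceComponent
from modalities.models.utils import ModelTypeEnum, get_model_from_config
from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer

# Model and tokenizer loading mirrors the gpt2_model_and_tokenizer fixture;
# the paths are repo-relative and assumed to exist in a modalities checkout.
config_dict = load_app_config_dict(config_file_path=Path("tests/test_yaml_configs/gpt2_config_optimizer.yaml"))
model = get_model_from_config(config=config_dict, model_type=ModelTypeEnum.MODEL)
tokenizer = PreTrainedHFTokenizer(
    pretrained_model_name_or_path=Path("data/tokenizer/hf_gpt2"), max_length=None, truncation=None, padding=False
)

component = TextInferenceComponent(
    model=model,
    tokenizer=tokenizer,
    system_prompt_path="",  # empty string -> _load_system_prompt() falls back to an empty system prompt
    chat_template="System: {system_prompt}\nUser: {user_prompt}\nAssistant:",
    prompt_template="What is {topic}?",
    sequence_length=20,
    temperature=0.7,
    eod_token="<|endoftext|>",
    device=torch.device("cpu"),
)

# run() loops: it fills the prompt template from stdin, asks for one or more
# temperatures, formats the chat template, and generates until Ctrl+C.
component.run()
```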