3 changes: 3 additions & 0 deletions src/modalities/inference/text/config.py
@@ -18,6 +18,9 @@ class TextInferenceComponentConfig(BaseModel):
temperature: Optional[float] = 1.0
eod_token: Optional[str] = "<eod>"
device: PydanticPytorchDeviceType
system_prompt_path: Optional[str] = ""
chat_template: str
prompt_template: str

@field_validator("device", mode="before")
def parse_device(cls, device) -> PydanticPytorchDeviceType:
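The three new fields are the only additions in this hunk: `system_prompt_path` is optional and defaults to an empty string (meaning the component runs without a system prompt), while `chat_template` and `prompt_template` are required. A minimal standalone sketch of the field behaviour, mirroring only what is visible in this hunk (the real `TextInferenceComponentConfig` has further fields such as `device` and its validator, omitted here; the concrete template strings are the ones used in tests/test_inference.py):

```python
# Minimal sketch mirroring only the fields visible in this hunk; the real
# TextInferenceComponentConfig has additional fields (e.g. device) not shown.
from typing import Optional

from pydantic import BaseModel


class SketchConfig(BaseModel):
    temperature: Optional[float] = 1.0
    eod_token: Optional[str] = "<eod>"
    system_prompt_path: Optional[str] = ""  # "" -> no system prompt file is loaded
    chat_template: str    # formatted in run() with {system_prompt} and {user_prompt}
    prompt_template: str  # placeholders are filled interactively via _get_prompt()


cfg = SketchConfig(
    chat_template="System: {system_prompt}\nUser: {user_prompt}\nAssistant:",
    prompt_template="What is {topic}?",
)
print(repr(cfg.system_prompt_path))  # '' -> TextInferenceComponent skips the file load
```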
82 changes: 72 additions & 10 deletions src/modalities/inference/text/inference_component.py
@@ -13,6 +13,8 @@ def __init__(
self,
model: nn.Module,
tokenizer: TokenizerWrapper,
system_prompt_path: str,
chat_template: str,
prompt_template: str,
sequence_length: int,
temperature: float,
@@ -24,10 +26,28 @@ def __init__(
self.model.eval()
self.tokenizer = tokenizer
self.eod_token = eod_token
self.chat_template = chat_template
self.prompt_template = prompt_template
self.temperature = temperature
self.sequence_length = sequence_length
self.device = device
self.system_prompt = self._load_system_prompt(system_prompt_path)

def _load_system_prompt(self, system_prompt_path: str) -> str:
if not system_prompt_path:
print("ℹ️ No system prompt file specified")
return ""
try:
with open(system_prompt_path, "r", encoding="utf-8") as f:
content = f.read().strip()
print(f"✅ Loaded system prompt from: {system_prompt_path}")
return content
except FileNotFoundError:
print(f"⚠️ System prompt file not found: {system_prompt_path}, using empty prompt")
return ""
except Exception as e:
print(f"❌ Error loading system prompt: {e}, using empty prompt")
return ""

def generate_tokens(
self,
@@ -38,11 +58,15 @@ def generate_tokens(
input_token_ids = torch.IntTensor(token_ids_list).to(self.device).unsqueeze(0)
input_dict = {"input_ids": input_token_ids}

print("--------------------PROMPT--------------------")
print("\n" + "=" * 60)
print("🤖 PROMPT")
print("=" * 60)
context_decoded = self.tokenizer.decode(token_ids_list)
print("Prompt: ", context_decoded, end="")
print(context_decoded)

print("\n\n--------------------OUTPUT--------------------\n")
print("\n" + "=" * 60)
print("💬 RESPONSE")
print("=" * 60)
generated_token_ids = []
generated_text_old = ""
for _ in range(max_new_tokens):
@@ -61,7 +85,8 @@ def generate_tokens(
generated_text_new = self.tokenizer.decode(generated_token_ids)

if idx_next_str == self.eod_token:
print("\n<reached end of document token>", end="")
print("\n\n" + "─" * 40)
print("✅ Reached end of document token")
break
else:
diff_text = generated_text_new[len(generated_text_old) :]
@@ -71,14 +96,51 @@ def generate_tokens(
token_ids_list.append(token_id)
input_token_ids = torch.IntTensor(token_ids_list).to(self.device).unsqueeze(0)
input_dict = {"input_ids": input_token_ids}
print("\n max tokens reached", end="")
else:
print("\n\n" + "─" * 40)
print("⚠️ Maximum tokens reached")

def run(self):
prompt = TextInferenceComponent._get_prompt(self.prompt_template)
try:
self.generate_tokens(context=prompt)
except KeyboardInterrupt:
print("closing app...")
print("\n" + "🚀 Modalities Chat Interface ".center(60, "="))
print("=" * 60)

while True:
try:
user_prompt = self._get_prompt(self.prompt_template)
full_prompt = self.chat_template.format(system_prompt=self.system_prompt, user_prompt=user_prompt)

temp_input = input("\n🌡️ Enter temperatures (comma-separated) or press Enter for default [0.8]: ")

if not temp_input.strip():
temperatures = [0.8]
print("Using default temperature: 0.8")
else:
try:
temperatures = [float(t.strip()) for t in temp_input.split(",")]
if not temperatures:
raise ValueError("No temperatures provided.")
except ValueError:
print("\n❌ Invalid input. Please enter comma-separated numbers or press Enter for default.\n")
continue

for i, temp in enumerate(temperatures):
if len(temperatures) > 1:
print(f"\n\n{'🎯 GENERATION ' + str(i+1) + f' (Temperature: {temp})'.center(60, '=')}")
else:
print(f"\n\n{'🎯 GENERATING (Temperature: ' + str(temp) + ')'.center(60, '=')}")
self.temperature = temp
try:
self.generate_tokens(context=full_prompt)
except KeyboardInterrupt:
print("\n\n👋 Generation interrupted by user.")
continue

print("\n\n" + "🏁 ALL GENERATIONS COMPLETE".center(60, "="))
print("=" * 60)

except KeyboardInterrupt:
print("\n\n👋 Closing app... Goodbye!")
break

@staticmethod
def _get_prompt(template: str) -> str:
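The reworked `run()` turns the old one-shot generation into an interactive loop: it builds a user prompt from `prompt_template`, wraps it together with the loaded system prompt via `chat_template`, asks for one or more temperatures, and calls `generate_tokens()` once per temperature. A minimal sketch of that prompt composition with the interactive `input()` calls replaced by fixed values (the template strings are the ones used in the tests, not mandated by the component):

```python
# Sketch of the prompt composition performed in run(); interactive input()
# calls are replaced by hard-coded values for illustration.
chat_template = "System: {system_prompt}\nUser: {user_prompt}\nAssistant:"
prompt_template = "What is {topic}?"
system_prompt = "You are a helpful AI assistant."  # normally read from system_prompt_path

# _get_prompt() fills the prompt_template placeholders from user input;
# here {topic} is filled directly.
user_prompt = prompt_template.format(topic="science")
full_prompt = chat_template.format(system_prompt=system_prompt, user_prompt=user_prompt)

# run() then sweeps the requested temperatures, setting self.temperature
# before each call.
for temperature in [0.3, 0.8, 1.2]:
    print(f"--- temperature {temperature} ---")
    print(full_prompt)  # stand-in for generate_tokens(context=full_prompt)
```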
144 changes: 144 additions & 0 deletions tests/test_inference.py
@@ -0,0 +1,144 @@
import tempfile
from io import StringIO
from pathlib import Path
from unittest.mock import patch

import pytest
import torch

from modalities.config.config import load_app_config_dict
from modalities.inference.text.inference_component import TextInferenceComponent
from modalities.models.utils import ModelTypeEnum, get_model_from_config
from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer


@pytest.fixture
def gpt2_model_and_tokenizer():
config_file_path = Path("tests/test_yaml_configs/gpt2_config_optimizer.yaml")
config_dict = load_app_config_dict(config_file_path=config_file_path)
model = get_model_from_config(config=config_dict, model_type=ModelTypeEnum.MODEL)
tokenizer_path = Path("data/tokenizer/hf_gpt2")
tokenizer = PreTrainedHFTokenizer(
pretrained_model_name_or_path=tokenizer_path, max_length=None, truncation=None, padding=False
)
return model, tokenizer


@pytest.fixture
def temp_system_prompt_file():
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
f.write("You are a helpful AI assistant.")
temp_path = f.name
yield temp_path
Path(temp_path).unlink()


class TestTextInferenceComponent:
# Test actual inference with real model and tokenizer
def test_actual_inference_greedy_decoding(self, gpt2_model_and_tokenizer):
"""Test greedy decoding with real model produces deterministic output."""
model, tokenizer = gpt2_model_and_tokenizer

component = TextInferenceComponent(
model=model,
tokenizer=tokenizer,
system_prompt_path="",
chat_template="{user_prompt}",
prompt_template="{text}",
sequence_length=20,
temperature=0.0, # Greedy decoding
eod_token="<|endoftext|>",
device=torch.device("cpu"),
)

# Run inference twice with same input
outputs = []
for _ in range(2):
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
component.generate_tokens("The weather today is")
outputs.append(mock_stdout.getvalue())
assert outputs[0] == outputs[1], "Greedy decoding should produce deterministic outputs"

def test_actual_inference_with_different_temperatures(self, gpt2_model_and_tokenizer):
"""Test that different temperatures produce different outputs."""
model, tokenizer = gpt2_model_and_tokenizer

def run_inference_with_temperature(temp):
component = TextInferenceComponent(
model=model,
tokenizer=tokenizer,
system_prompt_path="",
chat_template="{user_prompt}",
prompt_template="{text}",
sequence_length=15,
temperature=temp,
eod_token="<|endoftext|>",
device=torch.device("cpu"),
)

with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
component.generate_tokens("Hello world")
return mock_stdout.getvalue()

# Should produce different outputs
torch.manual_seed(42)
output_greedy = run_inference_with_temperature(0.0)
torch.manual_seed(42)
output_sampling = run_inference_with_temperature(1.0)

assert output_greedy != output_sampling, "Greedy and sampling outputs should be different"

def test_run_method_multiple_temperatures(self, gpt2_model_and_tokenizer, temp_system_prompt_file):
"""Test the run() method with multiple temperature inputs."""
model, tokenizer = gpt2_model_and_tokenizer

component = TextInferenceComponent(
model=model,
tokenizer=tokenizer,
system_prompt_path=temp_system_prompt_file,
chat_template="System: {system_prompt}\nUser: {user_prompt}\nAssistant:",
prompt_template="What is {topic}?",
sequence_length=20,
temperature=0.7,
eod_token="<|endoftext|>",
device=torch.device("cpu"),
)
mock_inputs = ["science", "0.3, 0.8, 1.2", KeyboardInterrupt()]
with patch("builtins.input", side_effect=mock_inputs):
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
component.run()

output = mock_stdout.getvalue()

# Verify multiple generations occurred
assert "(Temperature: 0.3)" in output
assert "(Temperature: 0.8)" in output
assert "(Temperature: 1.2)" in output
assert "🏁 ALL GENERATIONS COMPLETE" in output

def test_run_method_default_temperature(self, gpt2_model_and_tokenizer, temp_system_prompt_file):
"""Test the run() method with default temperature (empty input)."""
model, tokenizer = gpt2_model_and_tokenizer

component = TextInferenceComponent(
model=model,
tokenizer=tokenizer,
system_prompt_path=temp_system_prompt_file,
chat_template="System: {system_prompt}\nUser: {user_prompt}\nAssistant:",
prompt_template="What is {topic}?",
sequence_length=20,
temperature=0.7,
eod_token="<|endoftext|>",
device=torch.device("cpu"),
)

mock_inputs = ["machine learning", "", KeyboardInterrupt()]

with patch("builtins.input", side_effect=mock_inputs):
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
component.run()

output = mock_stdout.getvalue()

# Verify default temperature was used
assert "Using default temperature: 0.8" in output