Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/aiperf/dataset/dataset_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,10 @@ async def _handle_dataset_timing_request(

timing_dataset = []
for conversation_id, conversation in self.dataset.items():
for turn in conversation.turns:
timing_dataset.append((turn.timestamp, conversation_id))
if conversation.turns:
timing_dataset.append(
(conversation.turns[0].timestamp, conversation_id)
)

return DatasetTimingResponse(
service_id=self.service_id,
Expand Down
106 changes: 106 additions & 0 deletions tests/unit/dataset/test_dataset_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,109 @@ async def test_sequential_iterator_wraparound(

finally:
Path(filename).unlink(missing_ok=True)

@pytest.mark.asyncio
@patch("aiperf.common.tokenizer.Tokenizer.from_pretrained")
async def test_dataset_timing_request_for_multi_turn_conversations(
    self,
    mock_tokenizer_from_pretrained,
    create_mooncake_trace_file,
    mock_tokenizer_cls,
):
    """Check that the timing dataset holds one entry per conversation.

    A mooncake trace with two sessions (three turns and two turns) is
    loaded, then the timing request handler is invoked. The response must:
    - contain exactly one (timestamp, conversation_id) pair per session,
      rather than one pair per turn;
    - carry the timestamp of each session's first turn;
    - omit any conversation whose turn list is empty.
    """
    # Route tokenizer loading through the mock so no HTTP request is made.
    stub_tokenizer = mock_tokenizer_cls.from_pretrained("test-model")
    mock_tokenizer_from_pretrained.return_value = stub_tokenizer

    # Two sessions: "sess-1" has three turns, "sess-2" has two. Only the
    # first turn of each session carries an absolute timestamp.
    trace_lines = [
        '{"session_id": "sess-1", "timestamp": 0, "input_length": 50, "output_length": 10}',
        '{"session_id": "sess-1", "delay": 10000, "input_length": 50, "output_length": 10}',
        '{"session_id": "sess-1", "delay": 10000, "input_length": 100, "output_length": 10}',
        '{"session_id": "sess-2", "timestamp": 20000, "input_length": 25, "output_length": 20}',
        '{"session_id": "sess-2", "delay": 10000, "input_length": 10000, "output_length": 20}',
    ]
    trace_path = create_mooncake_trace_file(trace_lines)

    try:
        from aiperf.common.messages import DatasetTimingRequest
        from aiperf.common.models import Conversation

        endpoint_cfg = EndpointConfig(model_names=["test-model"])
        input_cfg = InputConfig(
            file=trace_path,
            custom_dataset_type=CustomDatasetType.MOONCAKE_TRACE,
        )
        user_config = UserConfig(endpoint=endpoint_cfg, input=input_cfg)

        manager = DatasetManager(ServiceConfig(), user_config)
        await manager.initialize()

        # The configure command is what makes the manager load conversations.
        configure_cmd = ProfileConfigureCommand(
            config=user_config, service_id="test_service"
        )
        await manager._profile_configure_command(configure_cmd)

        async def fetch_timing_data():
            """Issue a timing request and return the response's timing data."""
            request = DatasetTimingRequest(service_id="test_service")
            response = await manager._handle_dataset_timing_request(request)
            return response.timing_data

        # --- Part 1: only the multi-turn sessions are present -------------
        first_pass = await fetch_timing_data()

        # 2 conversations, not 5 turns
        assert len(first_pass) == 2

        timing_by_session = {}
        for timestamp, conv_id in first_pass:
            timing_by_session[conv_id] = timestamp

        # Each session must be scheduled at its first turn's timestamp.
        expected_start = {"sess-1": 0, "sess-2": 20000}
        for session_id, first_turn_ts in expected_start.items():
            assert session_id in timing_by_session
            assert timing_by_session[session_id] == first_turn_ts

        # No session may appear twice (one entry per conversation, not per turn).
        session_ids = [entry_id for _, entry_id in first_pass]
        assert len(set(session_ids)) == len(session_ids)

        # --- Part 2: a conversation without turns must be skipped ---------
        manager.dataset["empty-session"] = Conversation(
            session_id="empty-session",
            turns=[],
        )

        second_pass = await fetch_timing_data()

        # Still 2 entries, not 3: the empty conversation contributes nothing.
        assert len(second_pass) == 2
        sessions_after = {conv_id for _, conv_id in second_pass}
        assert "empty-session" not in sessions_after
        assert "sess-1" in sessions_after
        assert "sess-2" in sessions_after

    finally:
        # Always remove the temporary trace file, even on assertion failure.
        Path(trace_path).unlink(missing_ok=True)
Loading