diff --git a/src/aiperf/dataset/dataset_manager.py b/src/aiperf/dataset/dataset_manager.py
index cfa3880a3..749358a87 100644
--- a/src/aiperf/dataset/dataset_manager.py
+++ b/src/aiperf/dataset/dataset_manager.py
@@ -347,8 +347,10 @@ async def _handle_dataset_timing_request(
 
         timing_dataset = []
         for conversation_id, conversation in self.dataset.items():
-            for turn in conversation.turns:
-                timing_dataset.append((turn.timestamp, conversation_id))
+            if conversation.turns:
+                timing_dataset.append(
+                    (conversation.turns[0].timestamp, conversation_id)
+                )
 
         return DatasetTimingResponse(
             service_id=self.service_id,
diff --git a/tests/unit/dataset/test_dataset_manager.py b/tests/unit/dataset/test_dataset_manager.py
index 5ed46e0a6..d0f80f2d5 100644
--- a/tests/unit/dataset/test_dataset_manager.py
+++ b/tests/unit/dataset/test_dataset_manager.py
@@ -220,3 +220,109 @@ async def test_sequential_iterator_wraparound(
 
         finally:
             Path(filename).unlink(missing_ok=True)
+
+    @pytest.mark.asyncio
+    @patch("aiperf.common.tokenizer.Tokenizer.from_pretrained")
+    async def test_dataset_timing_request_for_multi_turn_conversations(
+        self,
+        mock_tokenizer_from_pretrained,
+        create_mooncake_trace_file,
+        mock_tokenizer_cls,
+    ):
+        """Test that dataset timing request returns first turn timestamp for each conversation.
+
+        When a dataset has multiple turns per conversation, the timing dataset should:
+        - Return one entry per conversation (not one per turn)
+        - Use the first turn's timestamp for scheduling
+        - All turns within a conversation are sent sequentially after the conversation is scheduled
+        """
+        # Mock the tokenizer to avoid HTTP requests
+        mock_tokenizer_from_pretrained.return_value = (
+            mock_tokenizer_cls.from_pretrained("test-model")
+        )
+
+        # Create a file with multi-turn conversations
+        entries = [
+            '{"session_id": "sess-1", "timestamp": 0, "input_length": 50, "output_length": 10}',
+            '{"session_id": "sess-1", "delay": 10000, "input_length": 50, "output_length": 10}',
+            '{"session_id": "sess-1", "delay": 10000, "input_length": 100, "output_length": 10}',
+            '{"session_id": "sess-2", "timestamp": 20000, "input_length": 25, "output_length": 20}',
+            '{"session_id": "sess-2", "delay": 10000, "input_length": 10000, "output_length": 20}',
+        ]
+        filename = create_mooncake_trace_file(entries)
+
+        try:
+            user_config = UserConfig(
+                endpoint=EndpointConfig(model_names=["test-model"]),
+                input=InputConfig(
+                    file=filename, custom_dataset_type=CustomDatasetType.MOONCAKE_TRACE
+                ),
+            )
+
+            service_config = ServiceConfig()
+            dataset_manager = DatasetManager(service_config, user_config)
+
+            await dataset_manager.initialize()
+
+            # Configure the dataset to load conversations
+            await dataset_manager._profile_configure_command(
+                ProfileConfigureCommand(config=user_config, service_id="test_service")
+            )
+
+            # Request timing data
+            from aiperf.common.messages import DatasetTimingRequest
+
+            timing_response = await dataset_manager._handle_dataset_timing_request(
+                DatasetTimingRequest(service_id="test_service")
+            )
+
+            # Verify timing dataset structure
+            assert len(timing_response.timing_data) == 2  # 2 conversations, not 5 turns
+
+            # Extract timing data for easier testing
+            timing_dict = {
+                conv_id: timestamp for timestamp, conv_id in timing_response.timing_data
+            }
+
+            # Verify session 1 is scheduled at its first turn's timestamp (0)
+            assert "sess-1" in timing_dict
+            assert timing_dict["sess-1"] == 0
+
+            # Verify session 2 is scheduled at its first turn's timestamp (20000)
+            assert "sess-2" in timing_dict
+            assert timing_dict["sess-2"] == 20000
+
+            # Verify no duplicate session IDs (one per conversation, not per turn)
+            session_ids = [conv_id for _, conv_id in timing_response.timing_data]
+            assert len(session_ids) == len(set(session_ids))
+
+            # Test with conversations containing empty turns (should be skipped)
+            # Manually add a conversation with no turns to dataset
+            from aiperf.common.models import Conversation
+
+            empty_conversation = Conversation(
+                session_id="empty-session",
+                turns=[],  # Empty turns list
+            )
+            dataset_manager.dataset["empty-session"] = empty_conversation
+
+            # Request timing data again
+            timing_response_with_empty = (
+                await dataset_manager._handle_dataset_timing_request(
+                    DatasetTimingRequest(service_id="test_service")
+                )
+            )
+
+            # Verify empty conversation is skipped - should still have 2 entries, not 3
+            assert len(timing_response_with_empty.timing_data) == 2
+            timing_dict_with_empty = {
+                conv_id: timestamp
+                for timestamp, conv_id in timing_response_with_empty.timing_data
+            }
+            # Empty session should not be in timing data
+            assert "empty-session" not in timing_dict_with_empty
+            assert "sess-1" in timing_dict_with_empty
+            assert "sess-2" in timing_dict_with_empty
+
+        finally:
+            Path(filename).unlink(missing_ok=True)