Skip to content

Commit e4093be

Browse files
Merge pull request #55 from openai/stream-annotations
[feat] stream annotations as they are added to model response
2 parents 5c528d0 + 103b9e6 commit e4093be

File tree

5 files changed

+175
-31
lines changed

5 files changed

+175
-31
lines changed

chatkit/actions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class ActionConfig(BaseModel):
2424

2525
class Action(BaseModel, Generic[TType, TPayload]):
2626
type: TType = Field(default=TType, frozen=True) # pyright: ignore
27-
payload: TPayload = None # pyright: ignore - default to None to allow no-payload actions
27+
payload: TPayload = None # pyright: ignore - default to None to allow no-payload actions
2828

2929
@classmethod
3030
def create(

chatkit/agents.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
from collections import defaultdict
34
from collections.abc import AsyncIterator
45
from datetime import datetime
56
from inspect import cleandoc
@@ -45,6 +46,7 @@
4546
Annotation,
4647
AssistantMessageContent,
4748
AssistantMessageContentPartAdded,
49+
AssistantMessageContentPartAnnotationAdded,
4850
AssistantMessageContentPartDone,
4951
AssistantMessageContentPartTextDelta,
5052
AssistantMessageItem,
@@ -207,9 +209,10 @@ def _complete(self):
207209

208210
def _convert_content(content: Content) -> AssistantMessageContent:
209211
if content.type == "output_text":
210-
annotations = []
211-
for annotation in content.annotations:
212-
annotations.extend(_convert_annotation(annotation))
212+
annotations = [
213+
_convert_annotation(annotation) for annotation in content.annotations
214+
]
215+
annotations = [a for a in annotations if a is not None]
213216
return AssistantMessageContent(
214217
text=content.text,
215218
annotations=annotations,
@@ -221,37 +224,43 @@ def _convert_content(content: Content) -> AssistantMessageContent:
221224
)
222225

223226

224-
def _convert_annotation(
225-
annotation: ResponsesAnnotation,
226-
) -> list[Annotation]:
227+
def _convert_annotation(raw_annotation: object) -> Annotation | None:
227228
# There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class
228-
# resulting into annotation being a dict instead of a ResponsesAnnotation
229-
if isinstance(annotation, dict):
230-
annotation = TypeAdapter(ResponsesAnnotation).validate_python(annotation)
229+
# resulting into annotation being a dict or untyped object instead instead of a ResponsesAnnotation
230+
annotation = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation).validate_python(
231+
raw_annotation
232+
)
231233

232-
result: list[Annotation] = []
233234
if annotation.type == "file_citation":
234235
filename = annotation.filename
235236
if not filename:
236-
return []
237-
result.append(
238-
Annotation(
239-
source=FileSource(filename=filename, title=filename),
240-
index=annotation.index,
241-
)
237+
return None
238+
239+
return Annotation(
240+
source=FileSource(filename=filename, title=filename),
241+
index=annotation.index,
242242
)
243-
elif annotation.type == "url_citation":
244-
result.append(
245-
Annotation(
246-
source=URLSource(
247-
url=annotation.url,
248-
title=annotation.title,
249-
),
250-
index=annotation.end_index,
251-
)
243+
244+
if annotation.type == "url_citation":
245+
return Annotation(
246+
source=URLSource(
247+
url=annotation.url,
248+
title=annotation.title,
249+
),
250+
index=annotation.end_index,
252251
)
253252

254-
return result
253+
if annotation.type == "container_file_citation":
254+
filename = annotation.filename
255+
if not filename:
256+
return None
257+
258+
return Annotation(
259+
source=FileSource(filename=filename, title=filename),
260+
index=annotation.end_index,
261+
)
262+
263+
return None
255264

256265

257266
T1 = TypeVar("T1")
@@ -349,6 +358,10 @@ async def stream_agent_response(
349358
queue_iterator = _AsyncQueueIterator(context._events)
350359
produced_items = set()
351360
streaming_thought: None | StreamingThoughtTracker = None
361+
# item_id -> content_index -> annotation count
362+
item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
363+
lambda: defaultdict(int)
364+
)
352365

353366
# check if the last item in the thread was a workflow or a client tool call
354367
# if it was a client tool call, check if the second last item was a workflow
@@ -462,7 +475,24 @@ def end_workflow(item: WorkflowItem):
462475
),
463476
)
464477
elif event.type == "response.output_text.annotation.added":
465-
# Ignore annotation-added events; annotations are reflected in the final item content.
478+
annotation = _convert_annotation(event.annotation)
479+
if annotation:
480+
# Manually track annotation indices per content part in case we drop an annotation that
481+
# we can't convert to our internal representation (e.g. missing filename).
482+
annotation_index = item_annotation_count[event.item_id][
483+
event.content_index
484+
]
485+
item_annotation_count[event.item_id][event.content_index] = (
486+
annotation_index + 1
487+
)
488+
yield ThreadItemUpdated(
489+
item_id=event.item_id,
490+
update=AssistantMessageContentPartAnnotationAdded(
491+
content_index=event.content_index,
492+
annotation_index=annotation_index,
493+
annotation=annotation,
494+
),
495+
)
466496
continue
467497
elif event.type == "response.output_item.added":
468498
item = event.item

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "openai-chatkit"
3-
version = "1.1.2"
3+
version = "1.2.2"
44
description = "A ChatKit backend SDK."
55
readme = "README.md"
66
requires-python = ">=3.10"

tests/test_agents.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
ResponseContentPartAddedEvent,
3939
)
4040
from openai.types.responses.response_file_search_tool_call import Result
41+
from openai.types.responses.response_output_text import (
42+
AnnotationContainerFileCitation as ResponsesAnnotationContainerFileCitation,
43+
)
4144
from openai.types.responses.response_output_text import (
4245
AnnotationFileCitation as ResponsesAnnotationFileCitation,
4346
)
@@ -64,6 +67,7 @@
6467
Annotation,
6568
AssistantMessageContent,
6669
AssistantMessageContentPartAdded,
70+
AssistantMessageContentPartAnnotationAdded,
6771
AssistantMessageContentPartDone,
6872
AssistantMessageContentPartTextDelta,
6973
AssistantMessageItem,
@@ -790,7 +794,17 @@ async def test_stream_agent_response_maps_events():
790794
sequence_number=3,
791795
),
792796
),
793-
None,
797+
ThreadItemUpdated(
798+
item_id="123",
799+
update=AssistantMessageContentPartAnnotationAdded(
800+
content_index=0,
801+
annotation_index=0,
802+
annotation=Annotation(
803+
source=FileSource(filename="file.txt", title="file.txt"),
804+
index=5,
805+
),
806+
),
807+
),
794808
),
795809
],
796810
)
@@ -810,6 +824,91 @@ async def test_event_mapping(raw_event, expected_event):
810824
assert events == []
811825

812826

827+
async def test_stream_agent_response_emits_annotation_added_events():
828+
context = AgentContext(
829+
previous_response_id=None, thread=thread, store=mock_store, request_context=None
830+
)
831+
result = make_result()
832+
item_id = "item_123"
833+
834+
def add_annotation_event(annotation, sequence_number):
835+
result.add_event(
836+
RawResponsesStreamEvent(
837+
type="raw_response_event",
838+
data=Mock(
839+
type="response.output_text.annotation.added",
840+
annotation=annotation,
841+
content_index=0,
842+
item_id=item_id,
843+
annotation_index=sequence_number,
844+
output_index=0,
845+
sequence_number=sequence_number,
846+
),
847+
)
848+
)
849+
850+
add_annotation_event(
851+
ResponsesAnnotationFileCitation(
852+
type="file_citation",
853+
file_id="file_invalid",
854+
filename="",
855+
index=0,
856+
),
857+
sequence_number=0,
858+
)
859+
add_annotation_event(
860+
ResponsesAnnotationContainerFileCitation(
861+
type="container_file_citation",
862+
container_id="container_1",
863+
file_id="file_123",
864+
filename="container.txt",
865+
start_index=0,
866+
end_index=3,
867+
),
868+
sequence_number=1,
869+
)
870+
add_annotation_event(
871+
ResponsesAnnotationURLCitation(
872+
type="url_citation",
873+
url="https://example.com",
874+
title="Example",
875+
start_index=1,
876+
end_index=5,
877+
),
878+
sequence_number=2,
879+
)
880+
result.done()
881+
882+
events = await all_events(stream_agent_response(context, result))
883+
assert events == [
884+
ThreadItemUpdated(
885+
item_id=item_id,
886+
update=AssistantMessageContentPartAnnotationAdded(
887+
content_index=0,
888+
annotation_index=0,
889+
annotation=Annotation(
890+
source=FileSource(filename="container.txt", title="container.txt"),
891+
index=3,
892+
),
893+
),
894+
),
895+
ThreadItemUpdated(
896+
item_id=item_id,
897+
update=AssistantMessageContentPartAnnotationAdded(
898+
content_index=0,
899+
annotation_index=1,
900+
annotation=Annotation(
901+
source=URLSource(
902+
url="https://example.com",
903+
title="Example",
904+
),
905+
index=5,
906+
),
907+
),
908+
),
909+
]
910+
911+
813912
@pytest.mark.parametrize("throw_guardrail", ["input", "output"])
814913
async def test_stream_agent_response_yields_item_removed_event(throw_guardrail):
815914
context = AgentContext(
@@ -942,6 +1041,14 @@ async def test_stream_agent_response_assistant_message_content_types():
9421041
index=0,
9431042
filename="test.txt",
9441043
),
1044+
ResponsesAnnotationContainerFileCitation(
1045+
type="container_file_citation",
1046+
container_id="container_1",
1047+
file_id="f_456",
1048+
filename="container.txt",
1049+
start_index=0,
1050+
end_index=3,
1051+
),
9451052
ResponsesAnnotationURLCitation(
9461053
type="url_citation",
9471054
url="https://www.google.com",
@@ -994,6 +1101,13 @@ async def test_stream_agent_response_assistant_message_content_types():
9941101
),
9951102
index=0,
9961103
),
1104+
Annotation(
1105+
source=FileSource(
1106+
filename="container.txt",
1107+
title="container.txt",
1108+
),
1109+
index=3,
1110+
),
9971111
Annotation(
9981112
source=URLSource(
9991113
url="https://www.google.com",

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)