Skip to content

Commit 367da6d

Browse files
committed
format
1 parent 6403dff commit 367da6d

File tree

2 files changed

+31
-19
lines changed

2 files changed

+31
-19
lines changed

chatkit/actions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class ActionConfig(BaseModel):
2424

2525
class Action(BaseModel, Generic[TType, TPayload]):
2626
type: TType = Field(default=TType, frozen=True) # pyright: ignore
27-
payload: TPayload = None # pyright: ignore - default to None to allow no-payload actions
27+
payload: TPayload = None # pyright: ignore - default to None to allow no-payload actions
2828

2929
@classmethod
3030
def create(

chatkit/agents.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
2-
from collections import defaultdict
32
import json
3+
from collections import defaultdict
44
from collections.abc import AsyncIterator
55
from datetime import datetime
66
from inspect import cleandoc
@@ -215,7 +215,9 @@ def _complete(self):
215215

216216
def _convert_content(content: Content) -> AssistantMessageContent:
217217
if content.type == "output_text":
218-
annotations = [_convert_annotation(annotation) for annotation in content.annotations]
218+
annotations = [
219+
_convert_annotation(annotation) for annotation in content.annotations
220+
]
219221
annotations = [a for a in annotations if a is not None]
220222
return AssistantMessageContent(
221223
text=content.text,
@@ -228,27 +230,31 @@ def _convert_content(content: Content) -> AssistantMessageContent:
228230
)
229231

230232

231-
def _convert_annotation(
232-
raw_annotation: object
233-
) -> Annotation | None:
233+
def _convert_annotation(raw_annotation: object) -> Annotation | None:
234234
# There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class
235235
# resulting into annotation being a dict.
236236
match raw_annotation:
237-
case AnnotationFileCitation() | AnnotationURLCitation() | AnnotationContainerFileCitation() | AnnotationFilePath():
237+
case (
238+
AnnotationFileCitation()
239+
| AnnotationURLCitation()
240+
| AnnotationContainerFileCitation()
241+
| AnnotationFilePath()
242+
):
238243
annotation = raw_annotation
239244
case _:
240-
annotation = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation).validate_python(raw_annotation)
241-
245+
annotation = TypeAdapter[ResponsesAnnotation](
246+
ResponsesAnnotation
247+
).validate_python(raw_annotation)
242248

243249
if annotation.type == "file_citation":
244250
filename = annotation.filename
245251
if not filename:
246252
return None
247253

248254
return Annotation(
249-
source=FileSource(filename=filename, title=filename),
250-
index=annotation.index,
251-
)
255+
source=FileSource(filename=filename, title=filename),
256+
index=annotation.index,
257+
)
252258

253259
if annotation.type == "url_citation":
254260
return Annotation(
@@ -265,9 +271,9 @@ def _convert_annotation(
265271
return None
266272

267273
return Annotation(
268-
source=FileSource(filename=filename, title=filename),
269-
index=annotation.end_index,
270-
)
274+
source=FileSource(filename=filename, title=filename),
275+
index=annotation.end_index,
276+
)
271277

272278
return None
273279

@@ -368,7 +374,9 @@ async def stream_agent_response(
368374
produced_items = set()
369375
streaming_thought: None | StreamingThoughtTracker = None
370376
# item_id -> content_index -> annotation count
371-
item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(lambda: defaultdict(int))
377+
item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
378+
lambda: defaultdict(int)
379+
)
372380

373381
# check if the last item in the thread was a workflow or a client tool call
374382
# if it was a client tool call, check if the second last item was a workflow
@@ -486,15 +494,19 @@ def end_workflow(item: WorkflowItem):
486494
if annotation:
487495
# Manually track annotation indices per content part in case we drop an annotation that
488496
# we can't convert to our internal representation (e.g. missing filename).
489-
annotation_index = item_annotation_count[event.item_id][event.content_index]
490-
item_annotation_count[event.item_id][event.content_index] = annotation_index + 1
497+
annotation_index = item_annotation_count[event.item_id][
498+
event.content_index
499+
]
500+
item_annotation_count[event.item_id][event.content_index] = (
501+
annotation_index + 1
502+
)
491503
yield ThreadItemUpdated(
492504
item_id=event.item_id,
493505
update=AssistantMessageContentPartAnnotationAdded(
494506
content_index=event.content_index,
495507
annotation_index=annotation_index,
496508
annotation=annotation,
497-
)
509+
),
498510
)
499511
continue
500512
elif event.type == "response.output_item.added":

0 commit comments

Comments
 (0)