Skip to content

Commit 553e07c

Browse files
committed
fix: streaming payload
1 parent aba30b4 commit 553e07c

File tree

4 files changed

+17
-9
lines changed

4 files changed

+17
-9
lines changed

workflowai/core/client/api.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
from typing import Any, AsyncIterator, Literal, Optional, TypeVar, overload
32

43
import httpx
@@ -71,5 +70,4 @@ async def stream(
7170
) as response:
7271
async for chunk in response.aiter_bytes():
7372
stripped = chunk.removeprefix(b"data: ").removesuffix(b"\n\n")
74-
parsed = json.loads(stripped)
75-
yield returns.model_construct(None, **parsed)
73+
yield returns.model_validate_json(stripped)

workflowai/core/client/client.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ImportExampleRequest,
1212
ImportRunRequest,
1313
RunRequest,
14+
RunTaskStreamChunk,
1415
TaskRunResponse,
1516
)
1617
from workflowai.core.domain.cache_usage import CacheUsage
@@ -113,9 +114,13 @@ async def run(
113114

114115
return res.to_domain(task)
115116

116-
return self.api.stream(
117-
method="POST", path=route, data=request, returns=task.output_class
118-
)
117+
async def _stream():
118+
async for chunk in self.api.stream(
119+
method="POST", path=route, data=request, returns=RunTaskStreamChunk
120+
):
121+
yield task.output_class.model_construct(None, **chunk.task_output)
122+
123+
return _stream()
119124

120125
async def import_run(
121126
self, run: TaskRun[TaskInput, TaskOutput]

workflowai/core/client/client_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ async def test_stream(self, httpx_mock: HTTPXMock, client: Client):
6565
httpx_mock.add_response(
6666
stream=IteratorStream(
6767
[
68-
b'data: {"message": ""}',
69-
b'data: {"message": "hel"}',
70-
b'data: {"message": "hello"}',
68+
b'data: {"run_id":"1","task_output":{"message":""}}',
69+
b'data: {"run_id":"1","task_output":{"message":"hel"}}',
70+
b'data: {"run_id":"1","task_output":{"message":"hello"}}',
7171
]
7272
)
7373
)

workflowai/core/client/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ def from_domain(cls, task_run: TaskRun[TaskInput, TaskOutput]):
8080
)
8181

8282

83+
class RunTaskStreamChunk(BaseModel):
84+
run_id: str
85+
task_output: dict[str, Any]
86+
87+
8388
class TaskRunResponse(BaseModel):
8489
id: str
8590
task_id: str

0 commit comments

Comments
 (0)