Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions bindu/server/scheduler/memory_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations as _annotations

import math
from collections.abc import AsyncIterator
from contextlib import AsyncExitStack
from typing import Any
Expand All @@ -24,6 +23,10 @@

logger = get_logger("bindu.server.scheduler.memory_scheduler")

# Capacity handed to anyio.create_memory_object_stream when the scheduler
# opens its in-memory task stream.
# Bounded buffer prevents unbounded memory growth while allowing the API
# handler to enqueue a task before the worker loop is ready to receive.
_TASK_QUEUE_BUFFER_SIZE = 100


def _get_trace_context() -> tuple[str | None, str | None]:
"""Extract primitive trace context from the live OpenTelemetry span."""
Expand All @@ -46,12 +49,11 @@ async def __aenter__(self):
self.aexit_stack = AsyncExitStack()
await self.aexit_stack.__aenter__()

# FIX: Added math.inf to create a buffered stream.
# Without this, the stream defaults to 0 (unbuffered), which causes
# the API server to deadlock/hang waiting for a worker to receive the task.
# Bounded buffer allows the API handler to enqueue tasks before the
# worker loop is ready while preventing unbounded memory growth.
self._write_stream, self._read_stream = anyio.create_memory_object_stream[
TaskOperation
](math.inf)
](_TASK_QUEUE_BUFFER_SIZE)
await self.aexit_stack.enter_async_context(self._read_stream)
await self.aexit_stack.enter_async_context(self._write_stream)

Expand Down
48 changes: 45 additions & 3 deletions bindu/server/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@

import anyio
from opentelemetry.trace import get_tracer, use_span
from opentelemetry.trace.span import (
INVALID_SPAN_CONTEXT,
NonRecordingSpan,
SpanContext,
TraceFlags,
)

from bindu.common.protocol.types import Artifact, Message, TaskIdParams, TaskSendParams
from bindu.server.scheduler.base import Scheduler
Expand All @@ -37,6 +43,38 @@
logger = get_logger(__name__)


def _reconstruct_span(
    trace_id: str | None, span_id: str | None
) -> NonRecordingSpan:
    """Rebuild a NonRecordingSpan from hex-encoded trace and span identifiers.

    The scheduler serializes the active span into primitive strings (required
    for Redis JSON serialization); this helper turns those strings back into a
    span object that carries the trace context for correlation.

    Args:
        trace_id: Hex-encoded trace ID (32 chars) or None
        span_id: Hex-encoded span ID (16 chars) or None

    Returns:
        A NonRecordingSpan wrapping the restored context, or one wrapping
        INVALID_SPAN_CONTEXT when either identifier is missing or cannot be
        parsed as hexadecimal.
    """
    if not trace_id or not span_id:
        # Nothing to restore: hand back a no-op span with an invalid context.
        return NonRecordingSpan(INVALID_SPAN_CONTEXT)

    try:
        return NonRecordingSpan(
            SpanContext(
                trace_id=int(trace_id, 16),
                span_id=int(span_id, 16),
                is_remote=True,
                trace_flags=TraceFlags(TraceFlags.SAMPLED),
            )
        )
    except (ValueError, TypeError):
        # Malformed identifiers are logged, then the fallback below is used.
        logger.warning(
            f"Invalid trace context: trace_id={trace_id}, span_id={span_id}"
        )

    return NonRecordingSpan(INVALID_SPAN_CONTEXT)


@dataclass
class Worker(ABC):
"""Abstract base worker for A2A protocol task execution.
Expand Down Expand Up @@ -104,7 +142,7 @@ async def _handle_task_operation(self, task_operation: dict[str, Any]) -> None:
"""Dispatch task operation to appropriate handler.

Args:
task_operation: Operation dict with 'operation', 'params', and '_current_span'
task_operation: Operation dict with 'operation', 'params', 'trace_id', 'span_id'

Supported Operations:
- run: Execute a task
Expand All @@ -124,8 +162,12 @@ async def _handle_task_operation(self, task_operation: dict[str, Any]) -> None:
}

try:
# Preserve trace context from scheduler
with use_span(task_operation["_current_span"]):
# Reconstruct trace context from serialized trace_id/span_id
span = _reconstruct_span(
task_operation.get("trace_id"),
task_operation.get("span_id"),
)
with use_span(span):
with tracer.start_as_current_span(
f"{task_operation['operation']} task",
attributes={"logfire.tags": ["bindu"]},
Expand Down
39 changes: 39 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,27 @@
ot_trace = ModuleType("opentelemetry.trace")


class _SpanContext:
    """Lightweight test double for OpenTelemetry's SpanContext.

    Stores the values it is given and derives ``is_valid`` once, at
    construction time: the context is valid only when both the trace ID and
    the span ID are non-zero.
    """

    def __init__(self, trace_id=0, span_id=0, is_remote=False, trace_flags=None):
        self.trace_id, self.span_id = trace_id, span_id
        self.is_remote = is_remote
        # Any falsy flags value (None, 0) is replaced by empty mock flags.
        self.trace_flags = trace_flags if trace_flags else _TraceFlags(0)
        both_ids_set = trace_id != 0 and span_id != 0
        self.is_valid = both_ids_set


class _TraceFlags(int):
    """Integer-backed stand-in for OpenTelemetry's TraceFlags in tests."""

    # Flag value marking a trace as sampled.
    SAMPLED = 0x01


class _Span:
def __init__(self, context=None):
self._context = context or _SpanContext()

def is_recording(self):
return True

Expand All @@ -24,6 +44,16 @@ def set_attribute(self, *args, **kwargs): # noqa: D401
def set_status(self, *args, **kwargs): # noqa: D401
return None

def get_span_context(self):
return self._context


# The stub does not distinguish recording from non-recording spans, so
# NonRecordingSpan is simply an alias of the mock _Span class.
_NonRecordingSpan = _Span

# Stand-in for the INVALID_SPAN_CONTEXT constant: both IDs are zero, so the
# mock context's is_valid attribute is False.
_INVALID_SPAN_CONTEXT = _SpanContext(trace_id=0, span_id=0, is_remote=False)


def get_current_span(): # noqa: D401
"""Return a mock span for testing without OpenTelemetry."""
Expand Down Expand Up @@ -63,6 +93,14 @@ def __init__(self, *args, **kwargs): # noqa: D401, ARG002
ot_trace.Span = _Span # type: ignore[attr-defined]
ot_trace.use_span = lambda span: _SpanCtx() # type: ignore[attr-defined]

# --- OpenTelemetry trace.span submodule stub ---
# Publish the mock classes under the public names that code imports from
# opentelemetry.trace.span.
ot_trace_span = ModuleType("opentelemetry.trace.span")
ot_trace_span.NonRecordingSpan = _NonRecordingSpan  # type: ignore[attr-defined]
ot_trace_span.SpanContext = _SpanContext  # type: ignore[attr-defined]
ot_trace_span.TraceFlags = _TraceFlags  # type: ignore[attr-defined]
ot_trace_span.INVALID_SPAN_CONTEXT = _INVALID_SPAN_CONTEXT  # type: ignore[attr-defined]
# Also expose the submodule as an attribute of the parent trace stub.
ot_trace.span = ot_trace_span  # type: ignore[attr-defined]

# Build minimal opentelemetry root and metrics stub
op_root = ModuleType("opentelemetry")

Expand Down Expand Up @@ -107,6 +145,7 @@ def get_meter(name: str): # noqa: D401, ARG001

# Register the stub modules in the import cache so that subsequent
# `import opentelemetry...` statements resolve to them.
sys.modules["opentelemetry"] = op_root
sys.modules["opentelemetry.trace"] = ot_trace
sys.modules["opentelemetry.trace.span"] = ot_trace_span
sys.modules["opentelemetry.metrics"] = metrics_mod


Expand Down