13 changes: 12 additions & 1 deletion .github/workflows/e2e-multi-language.yml
@@ -11,7 +11,7 @@ jobs:
e2e-tests:
name: E2E (${{ matrix.language }})
runs-on: ubuntu-latest
timeout-minutes: 30
timeout-minutes: 40
strategy:
fail-fast: false
matrix:
@@ -97,6 +97,12 @@ jobs:
python-version: '3.11'
cache: 'pip'

- name: Cache embedding model
uses: actions/cache@v4
with:
path: ~/.cache/huggingface/hub/models--BAAI--bge-small-en-v1.5
key: hf-bge-small-en-v1.5

- name: Install sia-code with dev dependencies
run: |
pip install -e ".[dev]"
@@ -128,6 +134,11 @@ jobs:
E2E_LANGUAGE: ${{ matrix.language }}
E2E_KEYWORD: ${{ matrix.keyword }}
E2E_SYMBOL: ${{ matrix.symbol }}

- name: Embedding daemon status
if: always()
run: |
sia-code embed status -v

- name: Upload test results
uses: actions/upload-artifact@v4
1 change: 1 addition & 0 deletions sia_code/config.py
@@ -96,6 +96,7 @@ class IndexingConfig(BaseModel):
)
include_patterns: list[str] = Field(default_factory=lambda: ["**/*"])
max_file_size_mb: int = 5
chunk_batch_size: int = 500

def get_effective_exclude_patterns(self, root: Path) -> list[str]:
"""Get combined exclude patterns from config and .gitignore files.
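
A minimal usage sketch of the new field (illustrative only, not from the PR; it assumes the IndexingConfig fields outside this hunk also have defaults):

    # Illustrative sketch -- not part of the PR. Assumes the IndexingConfig
    # fields not shown in this hunk also have defaults.
    from sia_code.config import IndexingConfig

    cfg = IndexingConfig(chunk_batch_size=1000)  # override the new default of 500
    assert cfg.chunk_batch_size == 1000
    assert cfg.max_file_size_mb == 5             # existing default, unchanged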
4 changes: 2 additions & 2 deletions sia_code/embed_server/client.py
@@ -82,8 +82,8 @@ def _send_request(self, request: dict) -> dict:
# Send request
sock.sendall(Message.encode(request))

# Receive response (up to 100MB for large batch embeddings)
response_data = sock.recv(100_000_000)
# Receive response using length-prefixed framing
response_data = Message.read_from_socket(sock)
sock.close()

# Parse response
21 changes: 19 additions & 2 deletions sia_code/embed_server/daemon.py
@@ -209,8 +209,8 @@ def _handle_connection(self, conn: socket.socket):
conn: Client socket connection
"""
try:
# Read request (up to 10MB)
data = conn.recv(10_000_000)
# Read request using length-prefixed framing
data = Message.read_from_socket(conn)
if not data:
return

@@ -344,6 +344,23 @@ def start_daemon(
foreground: Run in foreground (don't daemonize)
idle_timeout_seconds: Unload model after this many seconds of inactivity
"""
status = daemon_status(socket_path=socket_path, pid_path=pid_path)
if status.get("running"):
print("Daemon already running")
return

reason = status.get("reason", "")
pid_file = Path(pid_path)
socket_file = Path(socket_path)

if pid_file.exists() and reason in {"Stale PID file", "Error checking PID"}:
pid_file.unlink(missing_ok=True)
if socket_file.exists() and (
reason in {"No PID file", "Stale PID file", "No socket file"}
or reason.startswith("Health check failed")
):
socket_file.unlink(missing_ok=True)

# Setup logging
logging.basicConfig(
level=logging.INFO,
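
The cleanup above keys off the reason strings reported by daemon_status. A hedged restatement of those rules as a pure function, for illustration only (the reason strings are copied from the hunk; the helper itself is hypothetical and not part of the PR):

    # Hypothetical helper restating the stale-state cleanup rules above;
    # not part of the PR.
    def stale_cleanup_actions(reason: str) -> tuple[bool, bool]:
        """Return (remove_pid_file, remove_socket_file) when the daemon is not running."""
        remove_pid = reason in {"Stale PID file", "Error checking PID"}
        remove_socket = (
            reason in {"No PID file", "Stale PID file", "No socket file"}
            or reason.startswith("Health check failed")
        )
        return remove_pid, remove_socket

    # Examples: a stale PID file removes both leftovers; a failed health
    # check only removes the leftover socket.
    assert stale_cleanup_actions("Stale PID file") == (True, True)
    assert stale_cleanup_actions("Health check failed: timeout") == (False, True)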
53 changes: 49 additions & 4 deletions sia_code/embed_server/protocol.py
@@ -1,20 +1,65 @@
"""Protocol for embedding server communication."""

import json
import struct


class Message:
"""Base message class for socket communication."""
"""Base message class for socket communication with length-prefixed framing."""

HEADER_SIZE = 4 # 4 bytes for uint32 big-endian length

@staticmethod
def encode(data: dict) -> bytes:
"""Encode message to JSON bytes with newline delimiter."""
return (json.dumps(data) + "\n").encode("utf-8")
"""Encode message with 4-byte length prefix.

Format: [4-byte length header (big-endian uint32)][JSON payload]
"""
payload = json.dumps(data).encode("utf-8")
header = struct.pack(">I", len(payload))
return header + payload

@staticmethod
def decode(data: bytes) -> dict:
"""Decode JSON bytes to message dict."""
return json.loads(data.decode("utf-8").strip())
return json.loads(data.decode("utf-8"))

@staticmethod
def read_from_socket(sock, max_bytes: int = 50_000_000) -> bytes:
"""Read a length-prefixed message from socket.

Args:
sock: Socket to read from
max_bytes: Maximum message size (default 50MB)

Returns:
Message payload bytes (without the length prefix)

Raises:
ConnectionError: If connection closes unexpectedly
ValueError: If message exceeds max_bytes
"""
# Read 4-byte header
header = b""
while len(header) < Message.HEADER_SIZE:
chunk = sock.recv(Message.HEADER_SIZE - len(header))
if not chunk:
raise ConnectionError("Connection closed while reading header")
header += chunk

msg_len = struct.unpack(">I", header)[0]
if msg_len > max_bytes:
raise ValueError(f"Message size {msg_len} exceeds {max_bytes} limit")

# Read exactly msg_len bytes
data = b""
while len(data) < msg_len:
chunk = sock.recv(min(64 * 1024, msg_len - len(data)))
if not chunk:
raise ConnectionError("Connection closed while reading payload")
data += chunk

return data


class EmbedRequest:
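
For context, a minimal round trip with the new framing helpers (a sketch, not from the PR; the dict payload is a placeholder, since the daemon's real request schema comes from EmbedRequest, which this hunk truncates):

    # Illustrative round trip over the length-prefixed framing -- not part of
    # the PR. A socketpair stands in for the daemon's Unix socket.
    import socket

    from sia_code.embed_server.protocol import Message

    client, server = socket.socketpair()
    try:
        # Sender: 4-byte big-endian length header followed by the JSON payload.
        client.sendall(Message.encode({"texts": ["hello world"]}))  # placeholder payload

        # Receiver: reads the header, then exactly that many payload bytes,
        # no matter how the kernel split the message across recv() calls.
        payload = Message.read_from_socket(server)
        assert Message.decode(payload) == {"texts": ["hello world"]}
    finally:
        client.close()
        server.close()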
54 changes: 50 additions & 4 deletions sia_code/indexer/coordinator.py
@@ -163,6 +163,18 @@ def index_directory(

stats = self._create_index_stats(len(files))

# Buffer chunks to reduce write overhead
pending_chunks: list = []
batch_size = max(1, self.config.indexing.chunk_batch_size)
if self.backend.embedding_enabled and hasattr(self.backend, "_get_embed_batch_size"):
embed_batch = self.backend._get_embed_batch_size()
batch_size = min(batch_size, max(1, embed_batch * 8))

def flush_chunks() -> None:
if pending_chunks:
self.backend.store_chunks_batch(pending_chunks)
pending_chunks.clear()

# Process each file
for idx, file_path in enumerate(files, 1):
# Update progress
@@ -193,8 +205,10 @@ def index_directory(
except OSError:
pass

# Store chunks
self.backend.store_chunks_batch(chunks)
# Buffer chunks and flush when threshold reached
pending_chunks.extend(chunks)
if len(pending_chunks) >= batch_size:
flush_chunks()
stats["indexed_files"] += 1
stats["total_chunks"] += len(chunks)
metrics.files_processed += 1
@@ -211,6 +225,15 @@ def index_directory(
metrics.errors_count += 1
logger.exception(f"Unexpected error indexing {file_path}")

# Flush any remaining chunks
try:
flush_chunks()
except Exception as e:
error_msg = f"Error flushing final chunk batch: {str(e)}"
stats["errors"].append(error_msg)
metrics.errors_count += 1
logger.exception("Error flushing final chunk batch")

# Finalize metrics
metrics.finish()
stats["metrics"] = metrics.to_dict()
@@ -271,6 +294,18 @@ def index_directory_parallel(
greedy_merge=self.config.chunking.greedy_merge,
)

# Buffer chunks to reduce write overhead
pending_chunks: list = []
batch_size = max(1, self.config.indexing.chunk_batch_size)
if self.backend.embedding_enabled and hasattr(self.backend, "_get_embed_batch_size"):
embed_batch = self.backend._get_embed_batch_size()
batch_size = min(batch_size, max(1, embed_batch * 8))

def flush_chunks() -> None:
if pending_chunks:
self.backend.store_chunks_batch(pending_chunks)
pending_chunks.clear()

# Process files in parallel
with ProcessPoolExecutor(max_workers=max_workers) as executor:
# Submit all tasks
@@ -300,8 +335,10 @@ def index_directory_parallel(
# Track metrics
metrics.bytes_processed += file_size

# Store chunks
self.backend.store_chunks_batch(chunks)
# Buffer chunks and flush when threshold reached
pending_chunks.extend(chunks)
if len(pending_chunks) >= batch_size:
flush_chunks()
stats["indexed_files"] += 1
stats["total_chunks"] += len(chunks)
metrics.files_processed += 1
@@ -319,6 +356,15 @@ def index_directory_parallel(
metrics.errors_count += 1
logger.exception(f"Unexpected error processing {file_path}")

# Flush any remaining chunks
try:
flush_chunks()
except Exception as e:
error_msg = f"Error flushing final chunk batch: {str(e)}"
stats["errors"].append(error_msg)
metrics.errors_count += 1
logger.exception("Error flushing final chunk batch")

# Finalize metrics
metrics.finish()
stats["metrics"] = metrics.to_dict()
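
Both indexing paths now share the same buffer-and-flush pattern. A stand-alone sketch of that pattern with a stub backend (only store_chunks_batch and the 500-chunk default mirror the PR; the rest is illustrative):

    # Stand-alone sketch of the chunk-buffering pattern added above -- not the
    # PR's code. Only store_chunks_batch and the 500-chunk default mirror it.
    class FakeBackend:
        def __init__(self) -> None:
            self.writes: list[int] = []

        def store_chunks_batch(self, chunks: list) -> None:
            self.writes.append(len(chunks))      # one write per flushed batch

    backend = FakeBackend()
    pending_chunks: list = []
    batch_size = 500                             # config.indexing.chunk_batch_size default

    def flush_chunks() -> None:
        if pending_chunks:
            backend.store_chunks_batch(pending_chunks)
            pending_chunks.clear()

    for _file in range(10):                      # pretend each file yields 120 chunks
        pending_chunks.extend(object() for _ in range(120))
        if len(pending_chunks) >= batch_size:
            flush_chunks()

    flush_chunks()                               # flush the final partial batch
    print(backend.writes)                        # [600, 600] -- two writes instead of ten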