Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions api/app_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import time

from opentelemetry.trace import get_current_span

from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp
Expand All @@ -26,8 +28,25 @@ def before_request():
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()

# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
@dify_app.after_request
def add_trace_id_header(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
return response

# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_id_header

return dify_app

Expand Down
5 changes: 4 additions & 1 deletion api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,10 @@ class LoggingConfig(BaseSettings):

LOG_FORMAT: str = Field(
description="Format string for log messages",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
default=(
"%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
"[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
),
)

LOG_DATEFORMAT: str | None = Field(
Expand Down
8 changes: 6 additions & 2 deletions api/extensions/ext_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")


def init_app(app: DifyApp):
Expand All @@ -25,6 +26,7 @@ def init_app(app: DifyApp):
service_api_bp,
allow_headers=list(SERVICE_API_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(service_api_bp)

Expand All @@ -34,7 +36,7 @@ def init_app(app: DifyApp):
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(web_bp)

Expand All @@ -44,14 +46,15 @@ def init_app(app: DifyApp):
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(console_app_bp)

CORS(
files_bp,
allow_headers=list(FILES_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(files_bp)

Expand All @@ -63,5 +66,6 @@ def init_app(app: DifyApp):
trigger_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(trigger_bp)
5 changes: 5 additions & 0 deletions api/extensions/ext_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import flask

from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp


Expand Down Expand Up @@ -76,14 +77,18 @@ class RequestIdFilter(logging.Filter):
# the logging format. Note that we're checking if we're in a request
# context, as we may want to log things before Flask is fully loaded.
def filter(self, record):
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
return True


class RequestIdFormatter(logging.Formatter):
def format(self, record):
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
record.trace_id = ""
return super().format(record)


Expand Down
42 changes: 39 additions & 3 deletions api/extensions/ext_request_logging.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json
import logging
import time

import flask
import werkzeug.http
from flask import Flask
from flask import Flask, g
from flask.signals import request_finished, request_started

from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context

logger = logging.getLogger(__name__)

Expand All @@ -20,6 +22,9 @@ def _is_content_type_json(content_type: str) -> bool:

def _log_request_started(_sender, **_extra):
"""Log the start of a request."""
# Record start time for access logging
g.__request_started_ts = time.perf_counter()

if not logger.isEnabledFor(logging.DEBUG):
return

Expand All @@ -42,8 +47,39 @@ def _log_request_started(_sender, **_extra):


def _log_request_finished(_sender, response, **_extra):
"""Log the end of a request."""
if not logger.isEnabledFor(logging.DEBUG) or response is None:
"""Log the end of a request.

Safe to call with or without an active Flask request context.
"""
if response is None:
return

# Always emit a compact access line at INFO with trace_id so it can be grepped
has_ctx = flask.has_request_context()
start_ts = getattr(g, "__request_started_ts", None) if has_ctx else None
duration_ms = None
if start_ts is not None:
duration_ms = round((time.perf_counter() - start_ts) * 1000, 3)

# Request attributes are available only when a request context exists
if has_ctx:
req_method = flask.request.method
req_path = flask.request.path
else:
req_method = "-"
req_path = "-"

trace_id = get_trace_id_from_otel_context() or response.headers.get("X-Trace-Id") or ""
logger.info(
"%s %s %s %s %s",
req_method,
req_path,
getattr(response, "status_code", "-"),
duration_ms if duration_ms is not None else "-",
trace_id,
)

if not logger.isEnabledFor(logging.DEBUG):
return

if not _is_content_type_json(response.content_type):
Expand Down
59 changes: 59 additions & 0 deletions api/tests/unit_tests/extensions/test_ext_request_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,62 @@ def test_when_request_logging_enabled(self, enable_request_logging):
)
assert response.text == _RESPONSE_NEEDLE
assert response.status_code == 200


class TestRequestFinishedInfoAccessLine:
def test_info_access_log_includes_method_path_status_duration_trace_id(self, monkeypatch, caplog):
"""Ensure INFO access line contains expected fields with computed duration and trace id."""
app = _get_test_app()
# Push a real request context so flask.request and g are available
with app.test_request_context("/foo", method="GET"):
# Seed start timestamp via the extension's own start hook and control perf_counter deterministically
seq = iter([100.0, 100.123456])
monkeypatch.setattr(ext_request_logging.time, "perf_counter", lambda: next(seq))
# Provide a deterministic trace id
monkeypatch.setattr(
ext_request_logging,
"get_trace_id_from_otel_context",
lambda: "trace-xyz",
)
# Simulate request_started to record start timestamp on g
ext_request_logging._log_request_started(app)

# Capture logs from the real logger at INFO level only (skip DEBUG branch)
caplog.set_level(logging.INFO, logger=ext_request_logging.__name__)
response = Response(json.dumps({"ok": True}), mimetype="application/json", status=200)
_log_request_finished(app, response)

# Verify a single INFO record with the five fields in order
info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO]
assert len(info_records) == 1
msg = info_records[0].getMessage()
# Expected format: METHOD PATH STATUS DURATION_MS TRACE_ID
assert "GET" in msg
assert "/foo" in msg
assert "200" in msg
assert "123.456" in msg # rounded to 3 decimals
assert "trace-xyz" in msg

def test_info_access_log_uses_dash_without_start_timestamp(self, monkeypatch, caplog):
app = _get_test_app()
with app.test_request_context("/bar", method="POST"):
# No g.__request_started_ts set -> duration should be '-'
monkeypatch.setattr(
ext_request_logging,
"get_trace_id_from_otel_context",
lambda: "tid-no-start",
)
caplog.set_level(logging.INFO, logger=ext_request_logging.__name__)
response = Response("OK", mimetype="text/plain", status=204)
_log_request_finished(app, response)

info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO]
assert len(info_records) == 1
msg = info_records[0].getMessage()
assert "POST" in msg
assert "/bar" in msg
assert "204" in msg
# Duration placeholder
# The fields are space separated; ensure a standalone '-' appears
assert " - " in msg or msg.endswith(" -")
assert "tid-no-start" in msg