mirror of https://github.com/langgenius/dify.git
feat: Improve SQL Comment Context for Celery Worker Queries (#33058)
This commit is contained in:
parent
08b3bce53c
commit
eaf86c521f
|
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
Celery SQL comment context for OpenTelemetry SQLCommenter.
|
||||
|
||||
Injects Celery-specific metadata (framework, task_name, traceparent, celery_retries,
|
||||
routing_key) into SQL comments for queries executed by Celery workers. This improves
|
||||
trace-to-SQL correlation and debugging in production.
|
||||
|
||||
Uses the OpenTelemetry context key SQLCOMMENTER_ORM_TAGS_AND_VALUES, which is read
|
||||
by opentelemetry.instrumentation.sqlcommenter_utils._add_framework_tags() when the
|
||||
SQLAlchemy instrumentor appends comments to SQL statements.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from celery.signals import task_postrun, task_prerun
|
||||
from opentelemetry import context
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_TRACE_PROPAGATOR = TraceContextTextMapPropagator()
|
||||
|
||||
_SQLCOMMENTER_CONTEXT_KEY = "SQLCOMMENTER_ORM_TAGS_AND_VALUES"
|
||||
_TOKEN_ATTR = "_dify_sqlcommenter_context_token"
|
||||
|
||||
|
||||
def _build_celery_sqlcommenter_tags(task: Any) -> dict[str, str | int]:
|
||||
"""Build SQL commenter tags from the current Celery task and OpenTelemetry context."""
|
||||
tags: dict[str, str | int] = {}
|
||||
|
||||
try:
|
||||
tags["framework"] = f"celery:{_get_celery_version()}"
|
||||
except Exception:
|
||||
tags["framework"] = "celery:unknown"
|
||||
|
||||
if task and getattr(task, "name", None):
|
||||
tags["task_name"] = str(task.name)
|
||||
|
||||
traceparent = _get_traceparent()
|
||||
if traceparent:
|
||||
tags["traceparent"] = traceparent
|
||||
|
||||
if task and hasattr(task, "request"):
|
||||
request = task.request
|
||||
retries = getattr(request, "retries", None)
|
||||
if retries is not None and retries > 0:
|
||||
tags["celery_retries"] = int(retries)
|
||||
|
||||
delivery_info = getattr(request, "delivery_info", None) or {}
|
||||
if isinstance(delivery_info, dict):
|
||||
routing_key = delivery_info.get("routing_key")
|
||||
if routing_key:
|
||||
tags["routing_key"] = str(routing_key)
|
||||
|
||||
return tags
|
||||
|
||||
|
||||
def _get_celery_version() -> str:
|
||||
import celery
|
||||
|
||||
return getattr(celery, "__version__", "unknown")
|
||||
|
||||
|
||||
def _get_traceparent() -> str | None:
|
||||
"""Extract traceparent from the current OpenTelemetry context."""
|
||||
carrier: dict[str, str] = {}
|
||||
_TRACE_PROPAGATOR.inject(carrier)
|
||||
return carrier.get("traceparent")
|
||||
|
||||
|
||||
def _on_task_prerun(*args: object, **kwargs: object) -> None:
|
||||
task = kwargs.get("task")
|
||||
if not task:
|
||||
return
|
||||
|
||||
tags = _build_celery_sqlcommenter_tags(task)
|
||||
if not tags:
|
||||
return
|
||||
|
||||
current = context.get_current()
|
||||
new_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, tags, current)
|
||||
token = context.attach(new_ctx)
|
||||
setattr(task, _TOKEN_ATTR, token)
|
||||
|
||||
|
||||
def _on_task_postrun(*args: object, **kwargs: object) -> None:
|
||||
task = kwargs.get("task")
|
||||
if not task:
|
||||
return
|
||||
|
||||
token = getattr(task, _TOKEN_ATTR, None)
|
||||
if token is None:
|
||||
return
|
||||
|
||||
try:
|
||||
context.detach(token)
|
||||
except Exception:
|
||||
logger.debug("Failed to detach SQL commenter context", exc_info=True)
|
||||
finally:
|
||||
try:
|
||||
delattr(task, _TOKEN_ATTR)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def setup_celery_sqlcommenter() -> None:
|
||||
"""
|
||||
Connect Celery task_prerun and task_postrun handlers to inject SQL comment
|
||||
context for worker queries. Call this from init_celery_worker after
|
||||
CeleryInstrumentor().instrument() so our handlers run after the OTEL
|
||||
instrumentor's and the trace context is already attached.
|
||||
"""
|
||||
task_prerun.connect(_on_task_prerun, weak=False)
|
||||
task_postrun.connect(_on_task_postrun, weak=False)
|
||||
|
|
@ -67,11 +67,14 @@ def init_celery_worker(*args, **kwargs):
|
|||
from opentelemetry.metrics import get_meter_provider
|
||||
from opentelemetry.trace import get_tracer_provider
|
||||
|
||||
from extensions.otel.celery_sqlcommenter import setup_celery_sqlcommenter
|
||||
|
||||
tracer_provider = get_tracer_provider()
|
||||
metric_provider = get_meter_provider()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Initializing OpenTelemetry for Celery worker")
|
||||
CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
|
||||
setup_celery_sqlcommenter()
|
||||
|
||||
|
||||
def is_instrument_flag_enabled() -> bool:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,172 @@
|
|||
"""Tests for Celery SQL comment context injection."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from opentelemetry import context
|
||||
|
||||
|
||||
class TestBuildCelerySqlcommenterTags:
|
||||
"""Tests for _build_celery_sqlcommenter_tags."""
|
||||
|
||||
def test_includes_framework_and_task_name(self):
|
||||
"""Tags include celery framework version and task name."""
|
||||
from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags
|
||||
|
||||
task = MagicMock()
|
||||
task.name = "tasks.async_workflow_tasks.execute_workflow_team"
|
||||
task.request = MagicMock()
|
||||
task.request.retries = 0
|
||||
task.request.delivery_info = {}
|
||||
|
||||
with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None):
|
||||
tags = _build_celery_sqlcommenter_tags(task)
|
||||
|
||||
assert "framework" in tags
|
||||
assert tags["framework"].startswith("celery:")
|
||||
assert tags["task_name"] == "tasks.async_workflow_tasks.execute_workflow_team"
|
||||
|
||||
def test_includes_celery_retries_when_nonzero(self):
|
||||
"""celery_retries is included when retries > 0."""
|
||||
from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags
|
||||
|
||||
task = MagicMock()
|
||||
task.name = "tasks.my_task"
|
||||
task.request = MagicMock()
|
||||
task.request.retries = 3
|
||||
task.request.delivery_info = {}
|
||||
|
||||
with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None):
|
||||
tags = _build_celery_sqlcommenter_tags(task)
|
||||
|
||||
assert tags["celery_retries"] == 3
|
||||
|
||||
def test_omits_celery_retries_when_zero(self):
|
||||
"""celery_retries is omitted when retries is 0."""
|
||||
from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags
|
||||
|
||||
task = MagicMock()
|
||||
task.name = "tasks.my_task"
|
||||
task.request = MagicMock()
|
||||
task.request.retries = 0
|
||||
task.request.delivery_info = {}
|
||||
|
||||
with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None):
|
||||
tags = _build_celery_sqlcommenter_tags(task)
|
||||
|
||||
assert "celery_retries" not in tags
|
||||
|
||||
def test_includes_routing_key_from_delivery_info(self):
|
||||
"""routing_key is included when present in delivery_info."""
|
||||
from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags
|
||||
|
||||
task = MagicMock()
|
||||
task.name = "tasks.my_task"
|
||||
task.request = MagicMock()
|
||||
task.request.retries = 0
|
||||
task.request.delivery_info = {"routing_key": "workflow_based_app_execution"}
|
||||
|
||||
with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None):
|
||||
tags = _build_celery_sqlcommenter_tags(task)
|
||||
|
||||
assert tags["routing_key"] == "workflow_based_app_execution"
|
||||
|
||||
def test_includes_traceparent_when_available(self):
|
||||
"""traceparent is included when injectable from current context."""
|
||||
from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags
|
||||
|
||||
task = MagicMock()
|
||||
task.name = "tasks.my_task"
|
||||
task.request = MagicMock()
|
||||
task.request.retries = 0
|
||||
task.request.delivery_info = {}
|
||||
|
||||
traceparent = "00-5db86c23fa8d05b67db315694b518684-737bbf30cdcda066-00"
|
||||
with patch(
|
||||
"extensions.otel.celery_sqlcommenter._get_traceparent",
|
||||
return_value=traceparent,
|
||||
):
|
||||
tags = _build_celery_sqlcommenter_tags(task)
|
||||
|
||||
assert tags["traceparent"] == traceparent
|
||||
|
||||
def test_handles_task_without_request(self):
|
||||
"""Gracefully handles task without request attribute."""
|
||||
from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags
|
||||
|
||||
task = MagicMock()
|
||||
task.name = "tasks.my_task"
|
||||
del task.request
|
||||
|
||||
with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None):
|
||||
tags = _build_celery_sqlcommenter_tags(task)
|
||||
|
||||
assert "framework" in tags
|
||||
assert "task_name" in tags
|
||||
|
||||
|
||||
class TestTaskPrerunPostrunHandlers:
|
||||
"""Tests for task_prerun and task_postrun signal handlers."""
|
||||
|
||||
def test_prerun_sets_context_postrun_detaches(self):
|
||||
"""task_prerun attaches SQLCOMMENTER context; task_postrun detaches it."""
|
||||
from extensions.otel.celery_sqlcommenter import (
|
||||
_SQLCOMMENTER_CONTEXT_KEY,
|
||||
_TOKEN_ATTR,
|
||||
_on_task_postrun,
|
||||
_on_task_prerun,
|
||||
)
|
||||
|
||||
clean_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, None)
|
||||
token = context.attach(clean_ctx)
|
||||
try:
|
||||
task = MagicMock()
|
||||
task.name = "tasks.async_workflow_tasks.execute_workflow_team"
|
||||
task.request = MagicMock()
|
||||
task.request.retries = 1
|
||||
task.request.delivery_info = {"routing_key": "workflow_based_app_execution"}
|
||||
|
||||
with patch(
|
||||
"extensions.otel.celery_sqlcommenter._get_traceparent",
|
||||
return_value="00-abc123-def456-00",
|
||||
):
|
||||
_on_task_prerun(task=task)
|
||||
|
||||
tags = context.get_value(_SQLCOMMENTER_CONTEXT_KEY)
|
||||
assert tags is not None
|
||||
assert tags["framework"].startswith("celery:")
|
||||
assert tags["task_name"] == "tasks.async_workflow_tasks.execute_workflow_team"
|
||||
assert tags["celery_retries"] == 1
|
||||
assert tags["routing_key"] == "workflow_based_app_execution"
|
||||
assert tags["traceparent"] == "00-abc123-def456-00"
|
||||
assert hasattr(task, _TOKEN_ATTR)
|
||||
|
||||
_on_task_postrun(task=task)
|
||||
|
||||
tags_after = context.get_value(_SQLCOMMENTER_CONTEXT_KEY)
|
||||
assert tags_after is None
|
||||
assert not hasattr(task, _TOKEN_ATTR)
|
||||
finally:
|
||||
context.detach(token)
|
||||
|
||||
def test_prerun_skips_when_no_task(self):
|
||||
"""prerun does nothing when task is missing from kwargs."""
|
||||
from extensions.otel.celery_sqlcommenter import (
|
||||
_SQLCOMMENTER_CONTEXT_KEY,
|
||||
_on_task_prerun,
|
||||
)
|
||||
|
||||
clean_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, None)
|
||||
token = context.attach(clean_ctx)
|
||||
try:
|
||||
_on_task_prerun()
|
||||
tags = context.get_value(_SQLCOMMENTER_CONTEXT_KEY)
|
||||
assert tags is None
|
||||
finally:
|
||||
context.detach(token)
|
||||
|
||||
def test_postrun_skips_when_no_token(self):
|
||||
"""postrun does nothing when task has no token (e.g. prerun was skipped)."""
|
||||
from extensions.otel.celery_sqlcommenter import _on_task_postrun
|
||||
|
||||
task = MagicMock()
|
||||
_on_task_postrun(task=task)
|
||||
Loading…
Reference in New Issue