From eaf86c521fd6085fbbe1cb0069adf82c87563841 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 9 Mar 2026 23:37:26 -0500 Subject: [PATCH] feat: Improve SQL Comment Context for Celery Worker Queries (#33058) --- api/extensions/otel/celery_sqlcommenter.py | 114 ++++++++++++ api/extensions/otel/runtime.py | 3 + .../otel/test_celery_sqlcommenter.py | 172 ++++++++++++++++++ 3 files changed, 289 insertions(+) create mode 100644 api/extensions/otel/celery_sqlcommenter.py create mode 100644 api/tests/unit_tests/extensions/otel/test_celery_sqlcommenter.py diff --git a/api/extensions/otel/celery_sqlcommenter.py b/api/extensions/otel/celery_sqlcommenter.py new file mode 100644 index 0000000000..8abb1ce15a --- /dev/null +++ b/api/extensions/otel/celery_sqlcommenter.py @@ -0,0 +1,114 @@ +""" +Celery SQL comment context for OpenTelemetry SQLCommenter. + +Injects Celery-specific metadata (framework, task_name, traceparent, celery_retries, +routing_key) into SQL comments for queries executed by Celery workers. This improves +trace-to-SQL correlation and debugging in production. + +Uses the OpenTelemetry context key SQLCOMMENTER_ORM_TAGS_AND_VALUES, which is read +by opentelemetry.instrumentation.sqlcommenter_utils._add_framework_tags() when the +SQLAlchemy instrumentor appends comments to SQL statements. +""" + +import logging +from typing import Any + +from celery.signals import task_postrun, task_prerun +from opentelemetry import context +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + +logger = logging.getLogger(__name__) +_TRACE_PROPAGATOR = TraceContextTextMapPropagator() + +_SQLCOMMENTER_CONTEXT_KEY = "SQLCOMMENTER_ORM_TAGS_AND_VALUES" +_TOKEN_ATTR = "_dify_sqlcommenter_context_token" + + +def _build_celery_sqlcommenter_tags(task: Any) -> dict[str, str | int]: + """Build SQL commenter tags from the current Celery task and OpenTelemetry context.""" + tags: dict[str, str | int] = {} + + try: + tags["framework"] = f"celery:{_get_celery_version()}" + except Exception: + tags["framework"] = "celery:unknown" + + if task and getattr(task, "name", None): + tags["task_name"] = str(task.name) + + traceparent = _get_traceparent() + if traceparent: + tags["traceparent"] = traceparent + + if task and hasattr(task, "request"): + request = task.request + retries = getattr(request, "retries", None) + if retries is not None and retries > 0: + tags["celery_retries"] = int(retries) + + delivery_info = getattr(request, "delivery_info", None) or {} + if isinstance(delivery_info, dict): + routing_key = delivery_info.get("routing_key") + if routing_key: + tags["routing_key"] = str(routing_key) + + return tags + + +def _get_celery_version() -> str: + import celery + + return getattr(celery, "__version__", "unknown") + + +def _get_traceparent() -> str | None: + """Extract traceparent from the current OpenTelemetry context.""" + carrier: dict[str, str] = {} + _TRACE_PROPAGATOR.inject(carrier) + return carrier.get("traceparent") + + +def _on_task_prerun(*args: object, **kwargs: object) -> None: + task = kwargs.get("task") + if not task: + return + + tags = _build_celery_sqlcommenter_tags(task) + if not tags: + return + + current = context.get_current() + new_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, tags, current) + token = context.attach(new_ctx) + setattr(task, _TOKEN_ATTR, token) + + +def _on_task_postrun(*args: object, **kwargs: object) -> None: + task = kwargs.get("task") + if not task: + return + + token = getattr(task, _TOKEN_ATTR, None) + if token is None: + return + + try: + context.detach(token) + except Exception: + logger.debug("Failed to detach SQL commenter context", exc_info=True) + finally: + try: + delattr(task, _TOKEN_ATTR) + except AttributeError: + pass + + +def setup_celery_sqlcommenter() -> None: + """ + Connect Celery task_prerun and task_postrun handlers to inject SQL comment + context for worker queries. Call this from init_celery_worker after + CeleryInstrumentor().instrument() so our handlers run after the OTEL + instrumentor's and the trace context is already attached. + """ + task_prerun.connect(_on_task_prerun, weak=False) + task_postrun.connect(_on_task_postrun, weak=False) diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py index a7181d2683..a9ff0eed22 100644 --- a/api/extensions/otel/runtime.py +++ b/api/extensions/otel/runtime.py @@ -67,11 +67,14 @@ def init_celery_worker(*args, **kwargs): from opentelemetry.metrics import get_meter_provider from opentelemetry.trace import get_tracer_provider + from extensions.otel.celery_sqlcommenter import setup_celery_sqlcommenter + tracer_provider = get_tracer_provider() metric_provider = get_meter_provider() if dify_config.DEBUG: logger.info("Initializing OpenTelemetry for Celery worker") CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument() + setup_celery_sqlcommenter() def is_instrument_flag_enabled() -> bool: diff --git a/api/tests/unit_tests/extensions/otel/test_celery_sqlcommenter.py b/api/tests/unit_tests/extensions/otel/test_celery_sqlcommenter.py new file mode 100644 index 0000000000..7a537b0502 --- /dev/null +++ b/api/tests/unit_tests/extensions/otel/test_celery_sqlcommenter.py @@ -0,0 +1,172 @@ +"""Tests for Celery SQL comment context injection.""" + +from unittest.mock import MagicMock, patch + +from opentelemetry import context + + +class TestBuildCelerySqlcommenterTags: + """Tests for _build_celery_sqlcommenter_tags.""" + + def test_includes_framework_and_task_name(self): + """Tags include celery framework version and task name.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.async_workflow_tasks.execute_workflow_team" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert "framework" in tags + assert tags["framework"].startswith("celery:") + assert tags["task_name"] == "tasks.async_workflow_tasks.execute_workflow_team" + + def test_includes_celery_retries_when_nonzero(self): + """celery_retries is included when retries > 0.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 3 + task.request.delivery_info = {} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert tags["celery_retries"] == 3 + + def test_omits_celery_retries_when_zero(self): + """celery_retries is omitted when retries is 0.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert "celery_retries" not in tags + + def test_includes_routing_key_from_delivery_info(self): + """routing_key is included when present in delivery_info.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {"routing_key": "workflow_based_app_execution"} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert tags["routing_key"] == "workflow_based_app_execution" + + def test_includes_traceparent_when_available(self): + """traceparent is included when injectable from current context.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {} + + traceparent = "00-5db86c23fa8d05b67db315694b518684-737bbf30cdcda066-00" + with patch( + "extensions.otel.celery_sqlcommenter._get_traceparent", + return_value=traceparent, + ): + tags = _build_celery_sqlcommenter_tags(task) + + assert tags["traceparent"] == traceparent + + def test_handles_task_without_request(self): + """Gracefully handles task without request attribute.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + del task.request + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert "framework" in tags + assert "task_name" in tags + + +class TestTaskPrerunPostrunHandlers: + """Tests for task_prerun and task_postrun signal handlers.""" + + def test_prerun_sets_context_postrun_detaches(self): + """task_prerun attaches SQLCOMMENTER context; task_postrun detaches it.""" + from extensions.otel.celery_sqlcommenter import ( + _SQLCOMMENTER_CONTEXT_KEY, + _TOKEN_ATTR, + _on_task_postrun, + _on_task_prerun, + ) + + clean_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, None) + token = context.attach(clean_ctx) + try: + task = MagicMock() + task.name = "tasks.async_workflow_tasks.execute_workflow_team" + task.request = MagicMock() + task.request.retries = 1 + task.request.delivery_info = {"routing_key": "workflow_based_app_execution"} + + with patch( + "extensions.otel.celery_sqlcommenter._get_traceparent", + return_value="00-abc123-def456-00", + ): + _on_task_prerun(task=task) + + tags = context.get_value(_SQLCOMMENTER_CONTEXT_KEY) + assert tags is not None + assert tags["framework"].startswith("celery:") + assert tags["task_name"] == "tasks.async_workflow_tasks.execute_workflow_team" + assert tags["celery_retries"] == 1 + assert tags["routing_key"] == "workflow_based_app_execution" + assert tags["traceparent"] == "00-abc123-def456-00" + assert hasattr(task, _TOKEN_ATTR) + + _on_task_postrun(task=task) + + tags_after = context.get_value(_SQLCOMMENTER_CONTEXT_KEY) + assert tags_after is None + assert not hasattr(task, _TOKEN_ATTR) + finally: + context.detach(token) + + def test_prerun_skips_when_no_task(self): + """prerun does nothing when task is missing from kwargs.""" + from extensions.otel.celery_sqlcommenter import ( + _SQLCOMMENTER_CONTEXT_KEY, + _on_task_prerun, + ) + + clean_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, None) + token = context.attach(clean_ctx) + try: + _on_task_prerun() + tags = context.get_value(_SQLCOMMENTER_CONTEXT_KEY) + assert tags is None + finally: + context.detach(token) + + def test_postrun_skips_when_no_token(self): + """postrun does nothing when task has no token (e.g. prerun was skipped).""" + from extensions.otel.celery_sqlcommenter import _on_task_postrun + + task = MagicMock() + _on_task_postrun(task=task)