refactor: reuse redis connection instead of create new one (#32678)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei 2026-03-09 15:53:21 +08:00 committed by GitHub
parent cbb19cce39
commit 9970f4449a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1360 additions and 112 deletions

View File

@ -21,6 +21,10 @@ celery_redis = Redis(
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
# Add conservative socket timeouts and health checks to avoid long-lived half-open sockets
socket_timeout=5,
socket_connect_timeout=5,
health_check_interval=30,
)
logger = logging.getLogger(__name__)

View File

@ -3,6 +3,7 @@ import math
import time
from collections.abc import Iterable, Sequence
from celery import group
from sqlalchemy import ColumnElement, and_, func, or_, select
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session
@ -85,20 +86,25 @@ def trigger_provider_refresh() -> None:
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
enqueued: int = 0
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
if not is_locked:
continue
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
enqueued += 1
if not any(acquired):
continue
jobs = [
trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id)
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired)
if is_locked
]
result = group(jobs).apply_async()
enqueued = len(jobs)
logger.info(
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s",
page + 1,
pages,
len(subscriptions),
sum(1 for x in acquired if x),
enqueued,
result,
)
logger.info("Trigger refresh scan done: due=%d", total_due)

View File

@ -1,6 +1,6 @@
import logging
from celery import group, shared_task
from celery import current_app, group, shared_task
from sqlalchemy import and_, select
from sqlalchemy.orm import Session, sessionmaker
@ -29,31 +29,27 @@ def poll_workflow_schedules() -> None:
with session_factory() as session:
total_dispatched = 0
# Process in batches until we've handled all due schedules or hit the limit
while True:
due_schedules = _fetch_due_schedules(session)
if not due_schedules:
break
dispatched_count = _process_schedules(session, due_schedules)
total_dispatched += dispatched_count
with current_app.producer_or_acquire() as producer: # type: ignore
dispatched_count = _process_schedules(session, due_schedules, producer)
total_dispatched += dispatched_count
logger.debug("Batch processed: %d dispatched", dispatched_count)
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
if (
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
):
logger.warning(
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
)
break
logger.debug("Batch processed: %d dispatched", dispatched_count)
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched:
logger.warning(
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
)
break
if total_dispatched > 0:
logger.info("Total processed: %d dispatched", total_dispatched)
logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched)
def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
@ -90,7 +86,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
return list(due_schedules)
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int:
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int:
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
if not schedules:
return 0
@ -107,7 +103,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan])
if tasks_to_dispatch:
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
job.apply_async()
job.apply_async(producer=producer)
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))

View File

@ -1,9 +1,10 @@
import logging
import time
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from typing import Any, Protocol
import click
from celery import shared_task
from celery import current_app, shared_task
from configs import dify_config
from core.db.session_factory import session_factory
@ -19,6 +20,12 @@ from tasks.generate_summary_index_task import generate_summary_index_task
logger = logging.getLogger(__name__)
class CeleryTaskLike(Protocol):
def delay(self, *args: Any, **kwargs: Any) -> Any: ...
def apply_async(self, *args: Any, **kwargs: Any) -> Any: ...
@shared_task(queue="dataset")
def document_indexing_task(dataset_id: str, document_ids: list):
"""
@ -179,8 +186,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
def _document_indexing_with_tenant_queue(
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
):
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike
) -> None:
try:
_document_indexing(dataset_id, document_ids)
except Exception:
@ -201,16 +208,20 @@ def _document_indexing_with_tenant_queue(
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
if next_tasks:
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=document_task.tenant_id,
dataset_id=document_task.dataset_id,
document_ids=document_task.document_ids,
)
with current_app.producer_or_acquire() as producer: # type: ignore
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
task_func.apply_async(
kwargs={
"tenant_id": document_task.tenant_id,
"dataset_id": document_task.dataset_id,
"document_ids": document_task.document_ids,
},
producer=producer,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()

View File

@ -3,12 +3,13 @@ import json
import logging
import time
import uuid
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from typing import Any
import click
from celery import shared_task # type: ignore
from celery import group, shared_task
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
@ -27,6 +28,11 @@ from services.file_service import FileService
logger = logging.getLogger(__name__)
def chunked(iterable: Sequence, size: int):
it = iter(iterable)
return iter(lambda: list(islice(it, size)), [])
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
rag_pipeline_invoke_entities_file_id: str,
@ -83,16 +89,24 @@ def rag_pipeline_run_task(
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
for batch in chunked(next_file_ids, 100):
jobs = []
for next_file_id in batch:
tenant_isolated_task_queue.set_task_waiting_time()
file_id = (
next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id
)
jobs.append(
rag_pipeline_run_task.s(
rag_pipeline_invoke_entities_file_id=file_id,
tenant_id=tenant_id,
)
)
if jobs:
group(jobs).apply_async()
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()

View File

@ -322,11 +322,14 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
task_dispatch_spy.delay.assert_called_once_with(
tenant_id=next_task["tenant_id"],
dataset_id=next_task["dataset_id"],
document_ids=next_task["document_ids"],
)
# apply_async is used by implementation; assert it was called once with expected kwargs
assert task_dispatch_spy.apply_async.call_count == 1
call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {})
assert call_kwargs == {
"tenant_id": next_task["tenant_id"],
"dataset_id": next_task["dataset_id"],
"document_ids": next_task["document_ids"],
}
set_waiting_spy.assert_called_once()
delete_key_spy.assert_not_called()
@ -352,7 +355,7 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
task_dispatch_spy.delay.assert_not_called()
task_dispatch_spy.apply_async.assert_not_called()
delete_key_spy.assert_called_once()
def test_validation_failure_sets_error_status_when_vector_space_at_limit(
@ -447,7 +450,7 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
task_dispatch_spy.delay.assert_called_once()
task_dispatch_spy.apply_async.assert_called_once()
def test_sessions_close_on_successful_indexing(
self,
@ -534,7 +537,7 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
assert task_dispatch_spy.delay.call_count == concurrency_limit
assert task_dispatch_spy.apply_async.call_count == concurrency_limit
assert set_waiting_spy.call_count == concurrency_limit
def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
@ -565,9 +568,10 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
assert task_dispatch_spy.delay.call_count == 3
assert task_dispatch_spy.apply_async.call_count == 3
for index, expected_task in enumerate(ordered_tasks):
assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"]
call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
assert call_kwargs.get("document_ids") == expected_task["document_ids"]
def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
"""Skip limit checks when billing feature is disabled."""

View File

@ -762,11 +762,12 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify task function was called for each waiting task
assert mock_task_func.delay.call_count == 1
assert mock_task_func.apply_async.call_count == 1
# Verify correct parameters for each call
calls = mock_task_func.delay.call_args_list
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
calls = mock_task_func.apply_async.call_args_list
sent_kwargs = calls[0][1]["kwargs"]
assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
# Verify queue is empty after processing (tasks were pulled)
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
@ -830,11 +831,15 @@ class TestDocumentIndexingTasks:
assert updated_document.processing_started_at is not None
# Verify waiting task was still processed despite core processing error
mock_task_func.delay.assert_called_once()
mock_task_func.apply_async.assert_called_once()
# Verify correct parameters for the call
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
call = mock_task_func.apply_async.call_args
assert call[1]["kwargs"] == {
"tenant_id": tenant_id,
"dataset_id": dataset_id,
"document_ids": ["waiting-doc-1"],
}
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@ -896,9 +901,13 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only tenant1's waiting task was processed
mock_task_func.delay.assert_called_once()
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
mock_task_func.apply_async.assert_called_once()
call = mock_task_func.apply_async.call_args
assert call[1]["kwargs"] == {
"tenant_id": tenant1_id,
"dataset_id": dataset1_id,
"document_ids": ["tenant1-doc-1"],
}
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)

View File

@ -1,6 +1,6 @@
import json
import uuid
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
@ -388,8 +388,10 @@ class TestRagPipelineRunTasks:
# Set the task key to indicate there are waiting tasks (legacy behavior)
redis_client.set(legacy_task_key, 1, ex=60 * 60)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the priority task with new code but legacy queue data
rag_pipeline_run_task(file_id, tenant.id)
@ -398,13 +400,14 @@ class TestRagPipelineRunTasks:
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed, pull 1 task a time by default
assert mock_delay.call_count == 1
# Verify waiting tasks were processed via group, pull 1 task a time by default
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
assert first_kwargs.get("tenant_id") == tenant.id
# Verify that new code can process legacy queue entries
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
@ -446,8 +449,10 @@ class TestRagPipelineRunTasks:
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
queue.push_tasks(waiting_file_ids)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task
rag_pipeline_run_task(file_id, tenant.id)
@ -456,13 +461,14 @@ class TestRagPipelineRunTasks:
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed, pull 1 task a time by default
assert mock_delay.call_count == 1
# Verify waiting tasks were processed via group.apply_async
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@ -557,8 +563,10 @@ class TestRagPipelineRunTasks:
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task (should not raise exception)
rag_pipeline_run_task(file_id, tenant.id)
@ -569,12 +577,13 @@ class TestRagPipelineRunTasks:
assert mock_pipeline_generator.call_count == 1
# Verify waiting task was still processed despite core processing error
mock_delay.assert_called_once()
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@ -684,8 +693,10 @@ class TestRagPipelineRunTasks:
queue1.push_tasks([waiting_file_id1])
queue2.push_tasks([waiting_file_id2])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task for tenant1 only
rag_pipeline_run_task(file_id1, tenant1.id)
@ -694,11 +705,12 @@ class TestRagPipelineRunTasks:
assert mock_file_service["delete_file"].call_count == 1
assert mock_pipeline_generator.call_count == 1
# Verify only tenant1's waiting task was processed
mock_delay.assert_called_once()
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert call_kwargs.get("tenant_id") == tenant1.id
# Verify only tenant1's waiting task was processed (via group)
assert mock_group.return_value.apply_async.called
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert first_kwargs.get("tenant_id") == tenant1.id
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
@ -913,8 +925,10 @@ class TestRagPipelineRunTasks:
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act & Assert: Execute the regular task (should raise Exception)
with pytest.raises(Exception, match="File not found"):
rag_pipeline_run_task(file_id, tenant.id)
@ -924,12 +938,13 @@ class TestRagPipelineRunTasks:
mock_pipeline_generator.assert_not_called()
# Verify waiting task was still processed despite file error
mock_delay.assert_called_once()
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)

View File

@ -105,18 +105,26 @@ def app_model(
class MockCeleryGroup:
"""Mock for celery group() function that collects dispatched tasks."""
"""Mock for celery group() function that collects dispatched tasks.
Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async
(e.g. producer) so production code can pass broker-related options without
breaking tests.
"""
def __init__(self) -> None:
self.collected: list[dict[str, Any]] = []
self._applied = False
self.last_apply_async_kwargs: dict[str, Any] | None = None
def __call__(self, items: Any) -> MockCeleryGroup:
self.collected = list(items)
return self
def apply_async(self) -> None:
def apply_async(self, **kwargs: Any) -> None:
# Accept arbitrary kwargs like producer to be compatible with Celery
self._applied = True
self.last_apply_async_kwargs = kwargs
@property
def applied(self) -> bool:

File diff suppressed because it is too large Load Diff