mirror of https://github.com/langgenius/dify.git
refactor: reuse redis connection instead of create new one (#32678)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
cbb19cce39
commit
9970f4449a
|
|
@ -21,6 +21,10 @@ celery_redis = Redis(
|
||||||
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
|
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
|
||||||
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
|
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||||
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
|
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||||
|
# Add conservative socket timeouts and health checks to avoid long-lived half-open sockets
|
||||||
|
socket_timeout=5,
|
||||||
|
socket_connect_timeout=5,
|
||||||
|
health_check_interval=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import math
|
||||||
import time
|
import time
|
||||||
from collections.abc import Iterable, Sequence
|
from collections.abc import Iterable, Sequence
|
||||||
|
|
||||||
|
from celery import group
|
||||||
from sqlalchemy import ColumnElement, and_, func, or_, select
|
from sqlalchemy import ColumnElement, and_, func, or_, select
|
||||||
from sqlalchemy.engine.row import Row
|
from sqlalchemy.engine.row import Row
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
@ -85,20 +86,25 @@ def trigger_provider_refresh() -> None:
|
||||||
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
|
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
|
||||||
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
|
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
|
||||||
|
|
||||||
enqueued: int = 0
|
if not any(acquired):
|
||||||
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
|
continue
|
||||||
if not is_locked:
|
|
||||||
continue
|
jobs = [
|
||||||
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
|
trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id)
|
||||||
enqueued += 1
|
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired)
|
||||||
|
if is_locked
|
||||||
|
]
|
||||||
|
result = group(jobs).apply_async()
|
||||||
|
enqueued = len(jobs)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
|
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s",
|
||||||
page + 1,
|
page + 1,
|
||||||
pages,
|
pages,
|
||||||
len(subscriptions),
|
len(subscriptions),
|
||||||
sum(1 for x in acquired if x),
|
sum(1 for x in acquired if x),
|
||||||
enqueued,
|
enqueued,
|
||||||
|
result,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Trigger refresh scan done: due=%d", total_due)
|
logger.info("Trigger refresh scan done: due=%d", total_due)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from celery import group, shared_task
|
from celery import current_app, group, shared_task
|
||||||
from sqlalchemy import and_, select
|
from sqlalchemy import and_, select
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
|
@ -29,31 +29,27 @@ def poll_workflow_schedules() -> None:
|
||||||
with session_factory() as session:
|
with session_factory() as session:
|
||||||
total_dispatched = 0
|
total_dispatched = 0
|
||||||
|
|
||||||
# Process in batches until we've handled all due schedules or hit the limit
|
|
||||||
while True:
|
while True:
|
||||||
due_schedules = _fetch_due_schedules(session)
|
due_schedules = _fetch_due_schedules(session)
|
||||||
|
|
||||||
if not due_schedules:
|
if not due_schedules:
|
||||||
break
|
break
|
||||||
|
|
||||||
dispatched_count = _process_schedules(session, due_schedules)
|
with current_app.producer_or_acquire() as producer: # type: ignore
|
||||||
total_dispatched += dispatched_count
|
dispatched_count = _process_schedules(session, due_schedules, producer)
|
||||||
|
total_dispatched += dispatched_count
|
||||||
|
|
||||||
logger.debug("Batch processed: %d dispatched", dispatched_count)
|
logger.debug("Batch processed: %d dispatched", dispatched_count)
|
||||||
|
|
||||||
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
|
||||||
if (
|
|
||||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
|
|
||||||
and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
|
||||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
|
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
||||||
|
if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched:
|
||||||
|
logger.warning(
|
||||||
|
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
||||||
|
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
||||||
|
)
|
||||||
|
break
|
||||||
if total_dispatched > 0:
|
if total_dispatched > 0:
|
||||||
logger.info("Total processed: %d dispatched", total_dispatched)
|
logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched)
|
||||||
|
|
||||||
|
|
||||||
def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||||
|
|
@ -90,7 +86,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||||
return list(due_schedules)
|
return list(due_schedules)
|
||||||
|
|
||||||
|
|
||||||
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int:
|
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int:
|
||||||
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
|
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
|
||||||
if not schedules:
|
if not schedules:
|
||||||
return 0
|
return 0
|
||||||
|
|
@ -107,7 +103,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan])
|
||||||
|
|
||||||
if tasks_to_dispatch:
|
if tasks_to_dispatch:
|
||||||
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
|
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
|
||||||
job.apply_async()
|
job.apply_async(producer=producer)
|
||||||
|
|
||||||
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))
|
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Sequence
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from celery import shared_task
|
from celery import current_app, shared_task
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.db.session_factory import session_factory
|
from core.db.session_factory import session_factory
|
||||||
|
|
@ -19,6 +20,12 @@ from tasks.generate_summary_index_task import generate_summary_index_task
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CeleryTaskLike(Protocol):
|
||||||
|
def delay(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||||
|
|
||||||
|
def apply_async(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||||
|
|
||||||
|
|
||||||
@shared_task(queue="dataset")
|
@shared_task(queue="dataset")
|
||||||
def document_indexing_task(dataset_id: str, document_ids: list):
|
def document_indexing_task(dataset_id: str, document_ids: list):
|
||||||
"""
|
"""
|
||||||
|
|
@ -179,8 +186,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||||
|
|
||||||
|
|
||||||
def _document_indexing_with_tenant_queue(
|
def _document_indexing_with_tenant_queue(
|
||||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
|
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike
|
||||||
):
|
) -> None:
|
||||||
try:
|
try:
|
||||||
_document_indexing(dataset_id, document_ids)
|
_document_indexing(dataset_id, document_ids)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -201,16 +208,20 @@ def _document_indexing_with_tenant_queue(
|
||||||
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
|
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
|
||||||
|
|
||||||
if next_tasks:
|
if next_tasks:
|
||||||
for next_task in next_tasks:
|
with current_app.producer_or_acquire() as producer: # type: ignore
|
||||||
document_task = DocumentTask(**next_task)
|
for next_task in next_tasks:
|
||||||
# Process the next waiting task
|
document_task = DocumentTask(**next_task)
|
||||||
# Keep the flag set to indicate a task is running
|
# Keep the flag set to indicate a task is running
|
||||||
tenant_isolated_task_queue.set_task_waiting_time()
|
tenant_isolated_task_queue.set_task_waiting_time()
|
||||||
task_func.delay( # type: ignore
|
task_func.apply_async(
|
||||||
tenant_id=document_task.tenant_id,
|
kwargs={
|
||||||
dataset_id=document_task.dataset_id,
|
"tenant_id": document_task.tenant_id,
|
||||||
document_ids=document_task.document_ids,
|
"dataset_id": document_task.dataset_id,
|
||||||
)
|
"document_ids": document_task.document_ids,
|
||||||
|
},
|
||||||
|
producer=producer,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# No more waiting tasks, clear the flag
|
# No more waiting tasks, clear the flag
|
||||||
tenant_isolated_task_queue.delete_task_key()
|
tenant_isolated_task_queue.delete_task_key()
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,13 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from itertools import islice
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from celery import shared_task # type: ignore
|
from celery import group, shared_task
|
||||||
from flask import current_app, g
|
from flask import current_app, g
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
|
@ -27,6 +28,11 @@ from services.file_service import FileService
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def chunked(iterable: Sequence, size: int):
|
||||||
|
it = iter(iterable)
|
||||||
|
return iter(lambda: list(islice(it, size)), [])
|
||||||
|
|
||||||
|
|
||||||
@shared_task(queue="pipeline")
|
@shared_task(queue="pipeline")
|
||||||
def rag_pipeline_run_task(
|
def rag_pipeline_run_task(
|
||||||
rag_pipeline_invoke_entities_file_id: str,
|
rag_pipeline_invoke_entities_file_id: str,
|
||||||
|
|
@ -83,16 +89,24 @@ def rag_pipeline_run_task(
|
||||||
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
|
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
|
||||||
|
|
||||||
if next_file_ids:
|
if next_file_ids:
|
||||||
for next_file_id in next_file_ids:
|
for batch in chunked(next_file_ids, 100):
|
||||||
# Process the next waiting task
|
jobs = []
|
||||||
# Keep the flag set to indicate a task is running
|
for next_file_id in batch:
|
||||||
tenant_isolated_task_queue.set_task_waiting_time()
|
tenant_isolated_task_queue.set_task_waiting_time()
|
||||||
rag_pipeline_run_task.delay( # type: ignore
|
|
||||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
file_id = (
|
||||||
if isinstance(next_file_id, bytes)
|
next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id
|
||||||
else next_file_id,
|
)
|
||||||
tenant_id=tenant_id,
|
|
||||||
)
|
jobs.append(
|
||||||
|
rag_pipeline_run_task.s(
|
||||||
|
rag_pipeline_invoke_entities_file_id=file_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if jobs:
|
||||||
|
group(jobs).apply_async()
|
||||||
else:
|
else:
|
||||||
# No more waiting tasks, clear the flag
|
# No more waiting tasks, clear the flag
|
||||||
tenant_isolated_task_queue.delete_task_key()
|
tenant_isolated_task_queue.delete_task_key()
|
||||||
|
|
|
||||||
|
|
@ -322,11 +322,14 @@ class TestDatasetIndexingTaskIntegration:
|
||||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
task_dispatch_spy.delay.assert_called_once_with(
|
# apply_async is used by implementation; assert it was called once with expected kwargs
|
||||||
tenant_id=next_task["tenant_id"],
|
assert task_dispatch_spy.apply_async.call_count == 1
|
||||||
dataset_id=next_task["dataset_id"],
|
call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {})
|
||||||
document_ids=next_task["document_ids"],
|
assert call_kwargs == {
|
||||||
)
|
"tenant_id": next_task["tenant_id"],
|
||||||
|
"dataset_id": next_task["dataset_id"],
|
||||||
|
"document_ids": next_task["document_ids"],
|
||||||
|
}
|
||||||
set_waiting_spy.assert_called_once()
|
set_waiting_spy.assert_called_once()
|
||||||
delete_key_spy.assert_not_called()
|
delete_key_spy.assert_not_called()
|
||||||
|
|
||||||
|
|
@ -352,7 +355,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
task_dispatch_spy.delay.assert_not_called()
|
task_dispatch_spy.apply_async.assert_not_called()
|
||||||
delete_key_spy.assert_called_once()
|
delete_key_spy.assert_called_once()
|
||||||
|
|
||||||
def test_validation_failure_sets_error_status_when_vector_space_at_limit(
|
def test_validation_failure_sets_error_status_when_vector_space_at_limit(
|
||||||
|
|
@ -447,7 +450,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
task_dispatch_spy.delay.assert_called_once()
|
task_dispatch_spy.apply_async.assert_called_once()
|
||||||
|
|
||||||
def test_sessions_close_on_successful_indexing(
|
def test_sessions_close_on_successful_indexing(
|
||||||
self,
|
self,
|
||||||
|
|
@ -534,7 +537,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert task_dispatch_spy.delay.call_count == concurrency_limit
|
assert task_dispatch_spy.apply_async.call_count == concurrency_limit
|
||||||
assert set_waiting_spy.call_count == concurrency_limit
|
assert set_waiting_spy.call_count == concurrency_limit
|
||||||
|
|
||||||
def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
|
def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
|
||||||
|
|
@ -565,9 +568,10 @@ class TestDatasetIndexingTaskIntegration:
|
||||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert task_dispatch_spy.delay.call_count == 3
|
assert task_dispatch_spy.apply_async.call_count == 3
|
||||||
for index, expected_task in enumerate(ordered_tasks):
|
for index, expected_task in enumerate(ordered_tasks):
|
||||||
assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"]
|
call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
|
||||||
|
assert call_kwargs.get("document_ids") == expected_task["document_ids"]
|
||||||
|
|
||||||
def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
|
def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
|
||||||
"""Skip limit checks when billing feature is disabled."""
|
"""Skip limit checks when billing feature is disabled."""
|
||||||
|
|
|
||||||
|
|
@ -762,11 +762,12 @@ class TestDocumentIndexingTasks:
|
||||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||||
|
|
||||||
# Verify task function was called for each waiting task
|
# Verify task function was called for each waiting task
|
||||||
assert mock_task_func.delay.call_count == 1
|
assert mock_task_func.apply_async.call_count == 1
|
||||||
|
|
||||||
# Verify correct parameters for each call
|
# Verify correct parameters for each call
|
||||||
calls = mock_task_func.delay.call_args_list
|
calls = mock_task_func.apply_async.call_args_list
|
||||||
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
sent_kwargs = calls[0][1]["kwargs"]
|
||||||
|
assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||||
|
|
||||||
# Verify queue is empty after processing (tasks were pulled)
|
# Verify queue is empty after processing (tasks were pulled)
|
||||||
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
|
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
|
||||||
|
|
@ -830,11 +831,15 @@ class TestDocumentIndexingTasks:
|
||||||
assert updated_document.processing_started_at is not None
|
assert updated_document.processing_started_at is not None
|
||||||
|
|
||||||
# Verify waiting task was still processed despite core processing error
|
# Verify waiting task was still processed despite core processing error
|
||||||
mock_task_func.delay.assert_called_once()
|
mock_task_func.apply_async.assert_called_once()
|
||||||
|
|
||||||
# Verify correct parameters for the call
|
# Verify correct parameters for the call
|
||||||
call = mock_task_func.delay.call_args
|
call = mock_task_func.apply_async.call_args
|
||||||
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
assert call[1]["kwargs"] == {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"dataset_id": dataset_id,
|
||||||
|
"document_ids": ["waiting-doc-1"],
|
||||||
|
}
|
||||||
|
|
||||||
# Verify queue is empty after processing (task was pulled)
|
# Verify queue is empty after processing (task was pulled)
|
||||||
remaining_tasks = queue.pull_tasks(count=10)
|
remaining_tasks = queue.pull_tasks(count=10)
|
||||||
|
|
@ -896,9 +901,13 @@ class TestDocumentIndexingTasks:
|
||||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||||
|
|
||||||
# Verify only tenant1's waiting task was processed
|
# Verify only tenant1's waiting task was processed
|
||||||
mock_task_func.delay.assert_called_once()
|
mock_task_func.apply_async.assert_called_once()
|
||||||
call = mock_task_func.delay.call_args
|
call = mock_task_func.apply_async.call_args
|
||||||
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
|
assert call[1]["kwargs"] == {
|
||||||
|
"tenant_id": tenant1_id,
|
||||||
|
"dataset_id": dataset1_id,
|
||||||
|
"document_ids": ["tenant1-doc-1"],
|
||||||
|
}
|
||||||
|
|
||||||
# Verify tenant1's queue is empty
|
# Verify tenant1's queue is empty
|
||||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
|
|
@ -388,8 +388,10 @@ class TestRagPipelineRunTasks:
|
||||||
# Set the task key to indicate there are waiting tasks (legacy behavior)
|
# Set the task key to indicate there are waiting tasks (legacy behavior)
|
||||||
redis_client.set(legacy_task_key, 1, ex=60 * 60)
|
redis_client.set(legacy_task_key, 1, ex=60 * 60)
|
||||||
|
|
||||||
# Mock the task function calls
|
# Mock the Celery group scheduling used by the implementation
|
||||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||||
|
mock_group.return_value.apply_async = MagicMock()
|
||||||
|
|
||||||
# Act: Execute the priority task with new code but legacy queue data
|
# Act: Execute the priority task with new code but legacy queue data
|
||||||
rag_pipeline_run_task(file_id, tenant.id)
|
rag_pipeline_run_task(file_id, tenant.id)
|
||||||
|
|
||||||
|
|
@ -398,13 +400,14 @@ class TestRagPipelineRunTasks:
|
||||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||||
assert mock_pipeline_generator.call_count == 1
|
assert mock_pipeline_generator.call_count == 1
|
||||||
|
|
||||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
# Verify waiting tasks were processed via group, pull 1 task a time by default
|
||||||
assert mock_delay.call_count == 1
|
assert mock_group.return_value.apply_async.called
|
||||||
|
|
||||||
# Verify correct parameters for the call
|
# Verify correct parameters for the first scheduled job signature
|
||||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||||
assert call_kwargs.get("tenant_id") == tenant.id
|
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||||
|
assert first_kwargs.get("tenant_id") == tenant.id
|
||||||
|
|
||||||
# Verify that new code can process legacy queue entries
|
# Verify that new code can process legacy queue entries
|
||||||
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
|
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
|
||||||
|
|
@ -446,8 +449,10 @@ class TestRagPipelineRunTasks:
|
||||||
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
|
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||||
queue.push_tasks(waiting_file_ids)
|
queue.push_tasks(waiting_file_ids)
|
||||||
|
|
||||||
# Mock the task function calls
|
# Mock the Celery group scheduling used by the implementation
|
||||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||||
|
mock_group.return_value.apply_async = MagicMock()
|
||||||
|
|
||||||
# Act: Execute the regular task
|
# Act: Execute the regular task
|
||||||
rag_pipeline_run_task(file_id, tenant.id)
|
rag_pipeline_run_task(file_id, tenant.id)
|
||||||
|
|
||||||
|
|
@ -456,13 +461,14 @@ class TestRagPipelineRunTasks:
|
||||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||||
assert mock_pipeline_generator.call_count == 1
|
assert mock_pipeline_generator.call_count == 1
|
||||||
|
|
||||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
# Verify waiting tasks were processed via group.apply_async
|
||||||
assert mock_delay.call_count == 1
|
assert mock_group.return_value.apply_async.called
|
||||||
|
|
||||||
# Verify correct parameters for the call
|
# Verify correct parameters for the first scheduled job signature
|
||||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||||
assert call_kwargs.get("tenant_id") == tenant.id
|
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||||
|
assert first_kwargs.get("tenant_id") == tenant.id
|
||||||
|
|
||||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||||
remaining_tasks = queue.pull_tasks(count=10)
|
remaining_tasks = queue.pull_tasks(count=10)
|
||||||
|
|
@ -557,8 +563,10 @@ class TestRagPipelineRunTasks:
|
||||||
waiting_file_id = str(uuid.uuid4())
|
waiting_file_id = str(uuid.uuid4())
|
||||||
queue.push_tasks([waiting_file_id])
|
queue.push_tasks([waiting_file_id])
|
||||||
|
|
||||||
# Mock the task function calls
|
# Mock the Celery group scheduling used by the implementation
|
||||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||||
|
mock_group.return_value.apply_async = MagicMock()
|
||||||
|
|
||||||
# Act: Execute the regular task (should not raise exception)
|
# Act: Execute the regular task (should not raise exception)
|
||||||
rag_pipeline_run_task(file_id, tenant.id)
|
rag_pipeline_run_task(file_id, tenant.id)
|
||||||
|
|
||||||
|
|
@ -569,12 +577,13 @@ class TestRagPipelineRunTasks:
|
||||||
assert mock_pipeline_generator.call_count == 1
|
assert mock_pipeline_generator.call_count == 1
|
||||||
|
|
||||||
# Verify waiting task was still processed despite core processing error
|
# Verify waiting task was still processed despite core processing error
|
||||||
mock_delay.assert_called_once()
|
assert mock_group.return_value.apply_async.called
|
||||||
|
|
||||||
# Verify correct parameters for the call
|
# Verify correct parameters for the first scheduled job signature
|
||||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||||
assert call_kwargs.get("tenant_id") == tenant.id
|
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||||
|
assert first_kwargs.get("tenant_id") == tenant.id
|
||||||
|
|
||||||
# Verify queue is empty after processing (task was pulled)
|
# Verify queue is empty after processing (task was pulled)
|
||||||
remaining_tasks = queue.pull_tasks(count=10)
|
remaining_tasks = queue.pull_tasks(count=10)
|
||||||
|
|
@ -684,8 +693,10 @@ class TestRagPipelineRunTasks:
|
||||||
queue1.push_tasks([waiting_file_id1])
|
queue1.push_tasks([waiting_file_id1])
|
||||||
queue2.push_tasks([waiting_file_id2])
|
queue2.push_tasks([waiting_file_id2])
|
||||||
|
|
||||||
# Mock the task function calls
|
# Mock the Celery group scheduling used by the implementation
|
||||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||||
|
mock_group.return_value.apply_async = MagicMock()
|
||||||
|
|
||||||
# Act: Execute the regular task for tenant1 only
|
# Act: Execute the regular task for tenant1 only
|
||||||
rag_pipeline_run_task(file_id1, tenant1.id)
|
rag_pipeline_run_task(file_id1, tenant1.id)
|
||||||
|
|
||||||
|
|
@ -694,11 +705,12 @@ class TestRagPipelineRunTasks:
|
||||||
assert mock_file_service["delete_file"].call_count == 1
|
assert mock_file_service["delete_file"].call_count == 1
|
||||||
assert mock_pipeline_generator.call_count == 1
|
assert mock_pipeline_generator.call_count == 1
|
||||||
|
|
||||||
# Verify only tenant1's waiting task was processed
|
# Verify only tenant1's waiting task was processed (via group)
|
||||||
mock_delay.assert_called_once()
|
assert mock_group.return_value.apply_async.called
|
||||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||||
assert call_kwargs.get("tenant_id") == tenant1.id
|
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||||
|
assert first_kwargs.get("tenant_id") == tenant1.id
|
||||||
|
|
||||||
# Verify tenant1's queue is empty
|
# Verify tenant1's queue is empty
|
||||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||||
|
|
@ -913,8 +925,10 @@ class TestRagPipelineRunTasks:
|
||||||
waiting_file_id = str(uuid.uuid4())
|
waiting_file_id = str(uuid.uuid4())
|
||||||
queue.push_tasks([waiting_file_id])
|
queue.push_tasks([waiting_file_id])
|
||||||
|
|
||||||
# Mock the task function calls
|
# Mock the Celery group scheduling used by the implementation
|
||||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||||
|
mock_group.return_value.apply_async = MagicMock()
|
||||||
|
|
||||||
# Act & Assert: Execute the regular task (should raise Exception)
|
# Act & Assert: Execute the regular task (should raise Exception)
|
||||||
with pytest.raises(Exception, match="File not found"):
|
with pytest.raises(Exception, match="File not found"):
|
||||||
rag_pipeline_run_task(file_id, tenant.id)
|
rag_pipeline_run_task(file_id, tenant.id)
|
||||||
|
|
@ -924,12 +938,13 @@ class TestRagPipelineRunTasks:
|
||||||
mock_pipeline_generator.assert_not_called()
|
mock_pipeline_generator.assert_not_called()
|
||||||
|
|
||||||
# Verify waiting task was still processed despite file error
|
# Verify waiting task was still processed despite file error
|
||||||
mock_delay.assert_called_once()
|
assert mock_group.return_value.apply_async.called
|
||||||
|
|
||||||
# Verify correct parameters for the call
|
# Verify correct parameters for the first scheduled job signature
|
||||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||||
assert call_kwargs.get("tenant_id") == tenant.id
|
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||||
|
assert first_kwargs.get("tenant_id") == tenant.id
|
||||||
|
|
||||||
# Verify queue is empty after processing (task was pulled)
|
# Verify queue is empty after processing (task was pulled)
|
||||||
remaining_tasks = queue.pull_tasks(count=10)
|
remaining_tasks = queue.pull_tasks(count=10)
|
||||||
|
|
|
||||||
|
|
@ -105,18 +105,26 @@ def app_model(
|
||||||
|
|
||||||
|
|
||||||
class MockCeleryGroup:
|
class MockCeleryGroup:
|
||||||
"""Mock for celery group() function that collects dispatched tasks."""
|
"""Mock for celery group() function that collects dispatched tasks.
|
||||||
|
|
||||||
|
Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async
|
||||||
|
(e.g. producer) so production code can pass broker-related options without
|
||||||
|
breaking tests.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.collected: list[dict[str, Any]] = []
|
self.collected: list[dict[str, Any]] = []
|
||||||
self._applied = False
|
self._applied = False
|
||||||
|
self.last_apply_async_kwargs: dict[str, Any] | None = None
|
||||||
|
|
||||||
def __call__(self, items: Any) -> MockCeleryGroup:
|
def __call__(self, items: Any) -> MockCeleryGroup:
|
||||||
self.collected = list(items)
|
self.collected = list(items)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def apply_async(self) -> None:
|
def apply_async(self, **kwargs: Any) -> None:
|
||||||
|
# Accept arbitrary kwargs like producer to be compatible with Celery
|
||||||
self._applied = True
|
self._applied = True
|
||||||
|
self.last_apply_async_kwargs = kwargs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def applied(self) -> bool:
|
def applied(self) -> bool:
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue