From 20fc69ae7fe9f5d3b0d57ae231d4b0476f31411d Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:44:46 +0100 Subject: [PATCH 01/32] refactor: use EnumText for WorkflowAppLog.created_from and WorkflowArchiveLog columns (#33954) --- .../apps/workflow/generate_task_pipeline.py | 2 +- api/models/workflow.py | 12 ++++++-- api/tasks/trigger_processing_tasks.py | 2 +- ..._sqlalchemy_api_workflow_run_repository.py | 6 ++-- .../services/test_workflow_app_service.py | 29 ++++++++++--------- 5 files changed, 29 insertions(+), 22 deletions(-) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 96dd8c5445..bd6e2a0302 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): app_id=self._application_generate_entity.app_config.app_id, workflow_id=self._workflow.id, workflow_run_id=workflow_run_id, - created_from=created_from.value, + created_from=created_from, created_by_role=self._created_by_role, created_by=self._user_id, ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 6e8dda429d..334ec42058 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1221,7 +1221,9 @@ class WorkflowAppLog(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False + ) created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) 
created_at: Mapped[datetime] = mapped_column( @@ -1301,10 +1303,14 @@ class WorkflowArchiveLog(TypeBase): log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True) + log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True + ) run_version: Mapped[str] = mapped_column(String(255), nullable=False) - run_status: Mapped[str] = mapped_column(String(255), nullable=False) + run_status: Mapped[WorkflowExecutionStatus] = mapped_column( + EnumText(WorkflowExecutionStatus, length=255), nullable=False + ) run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column( EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False ) diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 75ae1f6316..f8c7964805 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -179,7 +179,7 @@ def _record_trigger_failure_log( app_id=workflow.app_id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value, + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=created_by_role, created_by=created_by, ) diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index c3ed79656f..49b370990a 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -27,7 +27,7 @@ from models.human_input import ( HumanInputFormRecipient, RecipientType, ) -from 
models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun +from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, @@ -218,7 +218,7 @@ class TestDeleteRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) @@ -278,7 +278,7 @@ class TestCountRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 8ab8df2a5a..84ce6364df 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from dify_graph.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun from models.enums import CreatorUserRole +from models.workflow import WorkflowAppLogCreatedFrom from services.account_service import AccountService, TenantService # Delay import of AppService to avoid circular dependency @@ -221,7 +222,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + 
created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -357,7 +358,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_1.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -399,7 +400,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_2.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -441,7 +442,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_4.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -521,7 +522,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -627,7 +628,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -732,7 +733,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -860,7 +861,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, 
created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -902,7 +903,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="web-app", + created_from=WorkflowAppLogCreatedFrom.WEB_APP, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) @@ -1037,7 +1038,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1125,7 +1126,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1279,7 +1280,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1379,7 +1380,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1481,7 +1482,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) From 4a2e9633db85acde1a4b42cfd0dfad0672628d5f Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:46:06 +0100 Subject: [PATCH 02/32] refactor: use EnumText for ApiToken.type (#33961) --- api/controllers/console/apikey.py | 14 
++++++++------ api/controllers/console/datasets/datasets.py | 6 +++--- api/models/enums.py | 7 +++++++ api/models/model.py | 3 ++- .../libs/test_api_token_cache_integration.py | 3 ++- .../unit_tests/controllers/console/test_apikey.py | 5 +++-- 6 files changed, 25 insertions(+), 13 deletions(-) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 6c54be84a8..783cb5c444 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -9,6 +9,7 @@ from extensions.ext_database import db from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset +from models.enums import ApiTokenType from models.model import ApiToken, App from services.api_token_service import ApiTokenCache @@ -47,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model): class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None token_prefix: str | None = None @@ -91,6 +92,7 @@ class BaseApiKeyListResource(Resource): ) key = ApiToken.generate_api_key(self.token_prefix or "", 24) + assert self.resource_type is not None, "resource_type must be set" api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) api_token.tenant_id = current_tenant_id @@ -104,7 +106,7 @@ class BaseApiKeyListResource(Resource): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None @@ -159,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): """Create a new API key for an app""" return super().post(resource_id) - 
resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" token_prefix = "app-" @@ -175,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource): """Delete an API key for an app""" return super().delete(resource_id, api_key_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" @@ -199,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): """Create a new API key for a dataset""" return super().post(resource_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" token_prefix = "ds-" @@ -215,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource): """Delete an API key for a dataset""" return super().delete(resource_id, api_key_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 725a8380cd..fb98932269 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -54,7 +54,7 @@ from fields.document_fields import document_status_fields from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum -from models.enums import SegmentStatus +from models.enums import ApiTokenType, SegmentStatus from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -777,7 +777,7 @@ class DatasetIndexingStatusApi(Resource): class DatasetApiKeyApi(Resource): max_keys = 10 token_prefix = "dataset-" - resource_type = "dataset" + resource_type = ApiTokenType.DATASET 
@console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get dataset API keys") @@ -826,7 +826,7 @@ class DatasetApiKeyApi(Resource): @console_ns.route("/datasets/api-keys/") class DatasetApiDeleteApi(Resource): - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("delete_dataset_api_key") @console_ns.doc(description="Delete dataset API key") diff --git a/api/models/enums.py b/api/models/enums.py index 4849099d30..8aca1df2b4 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -323,3 +323,10 @@ class ProviderQuotaType(StrEnum): if member.value == value: return member raise ValueError(f"No matching enum found for value '{value}'") + + +class ApiTokenType(StrEnum): + """API Token type""" + + APP = "app" + DATASET = "dataset" diff --git a/api/models/model.py b/api/models/model.py index b098966052..331a5b7d8c 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -31,6 +31,7 @@ from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db from .enums import ( + ApiTokenType, AppMCPServerStatus, AppStatus, BannerStatus, @@ -2095,7 +2096,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field. 
id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(String(16), nullable=False) + type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False) token: Mapped[str] = mapped_column(String(255), nullable=False) last_used_at = mapped_column(sa.DateTime, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/tests/integration_tests/libs/test_api_token_cache_integration.py b/api/tests/integration_tests/libs/test_api_token_cache_integration.py index 1d7b835fd2..a942690cbd 100644 --- a/api/tests/integration_tests/libs/test_api_token_cache_integration.py +++ b/api/tests/integration_tests/libs/test_api_token_cache_integration.py @@ -13,6 +13,7 @@ from unittest.mock import patch import pytest from extensions.ext_redis import redis_client +from models.enums import ApiTokenType from models.model import ApiToken from services.api_token_service import ApiTokenCache, CachedApiToken @@ -279,7 +280,7 @@ class TestEndToEndCacheFlow: test_token = ApiToken() test_token.id = "test-e2e-id" test_token.token = test_token_value - test_token.type = test_scope + test_token.type = ApiTokenType.APP test_token.app_id = "test-app" test_token.tenant_id = "test-tenant" test_token.last_used_at = None diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py index c18dd044a7..2dff9c4037 100644 --- a/api/tests/unit_tests/controllers/console/test_apikey.py +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -8,6 +8,7 @@ from controllers.console.apikey import ( BaseApiKeyResource, _get_resource, ) +from models.enums import ApiTokenType @pytest.fixture @@ -45,14 +46,14 @@ def bypass_permissions(): class DummyApiKeyListResource(BaseApiKeyListResource): - resource_type = "app" + 
resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" token_prefix = "app-" class DummyApiKeyResource(BaseApiKeyResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" From 3f086b97b6a0aad1d78d367b901aa4e374f3feaf Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 12:46:54 -0500 Subject: [PATCH 03/32] test: remove mock tests superseded by testcontainers (#33957) --- .../test_document_service_display_status.py | 8 + .../test_document_service_display_status.py | 8 - .../services/test_web_conversation_service.py | 259 ------------------ 3 files changed, 8 insertions(+), 267 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_document_service_display_status.py delete mode 100644 api/tests/unit_tests/services/test_web_conversation_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index c6aa89c733..47d259d8a0 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -142,3 +142,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c rows = db_session_with_containers.scalars(filtered).all() assert {row.id for row in rows} == {doc1.id, doc2.id} + + +def test_normalize_display_status_alias_mapping(): + """Test that normalize_display_status maps aliases correctly.""" + assert DocumentService.normalize_display_status("ACTIVE") == "available" + assert DocumentService.normalize_display_status("enabled") == "available" + assert DocumentService.normalize_display_status("archived") == "archived" + assert DocumentService.normalize_display_status("unknown") is None diff --git 
a/api/tests/unit_tests/services/test_document_service_display_status.py b/api/tests/unit_tests/services/test_document_service_display_status.py deleted file mode 100644 index cb2e2940c8..0000000000 --- a/api/tests/unit_tests/services/test_document_service_display_status.py +++ /dev/null @@ -1,8 +0,0 @@ -from services.dataset_service import DocumentService - - -def test_normalize_display_status_alias_mapping(): - assert DocumentService.normalize_display_status("ACTIVE") == "available" - assert DocumentService.normalize_display_status("enabled") == "available" - assert DocumentService.normalize_display_status("archived") == "archived" - assert DocumentService.normalize_display_status("unknown") is None diff --git a/api/tests/unit_tests/services/test_web_conversation_service.py b/api/tests/unit_tests/services/test_web_conversation_service.py deleted file mode 100644 index 7687d355e9..0000000000 --- a/api/tests/unit_tests/services/test_web_conversation_service.py +++ /dev/null @@ -1,259 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from core.app.entities.app_invoke_entities import InvokeFrom -from models import Account -from models.model import App, EndUser -from services.web_conversation_service import WebConversationService - - -@pytest.fixture -def app_model() -> App: - return cast(App, SimpleNamespace(id="app-1")) - - -def _account(**kwargs: Any) -> Account: - return cast(Account, SimpleNamespace(**kwargs)) - - -def _end_user(**kwargs: Any) -> EndUser: - return cast(EndUser, SimpleNamespace(**kwargs)) - - -def test_pagination_by_last_id_should_raise_error_when_user_is_none( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - - # Act + Assert - with pytest.raises(ValueError, 
match="User is required"): - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=None, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - ) - - -def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - fake_user = _account(id="user-1") - mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - mock_pagination.return_value = MagicMock() - - # Act - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=fake_user, - last_id="conv-9", - limit=10, - invoke_from=InvokeFrom.WEB_APP, - pinned=None, - ) - - # Assert - call_kwargs = mock_pagination.call_args.kwargs - assert call_kwargs["include_ids"] is None - assert call_kwargs["exclude_ids"] is None - assert call_kwargs["last_id"] == "conv-9" - assert call_kwargs["sort_by"] == "-updated_at" - - -def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - fake_account_cls = type("FakeAccount", (), {}) - fake_user = cast(Account, fake_account_cls()) - fake_user.id = "account-1" - mocker.patch("services.web_conversation_service.Account", fake_account_cls) - mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {})) - session.scalars.return_value.all.return_value = ["conv-1", "conv-2"] - mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - mock_pagination.return_value = MagicMock() - - # Act - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=fake_user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - pinned=True, - ) - - # Assert - call_kwargs = mock_pagination.call_args.kwargs - assert 
call_kwargs["include_ids"] == ["conv-1", "conv-2"] - assert call_kwargs["exclude_ids"] is None - - -def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - fake_end_user_cls = type("FakeEndUser", (), {}) - fake_user = cast(EndUser, fake_end_user_cls()) - fake_user.id = "end-user-1" - mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) - mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) - session.scalars.return_value.all.return_value = ["conv-3"] - mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - mock_pagination.return_value = MagicMock() - - # Act - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=fake_user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - pinned=False, - ) - - # Assert - call_kwargs = mock_pagination.call_args.kwargs - assert call_kwargs["include_ids"] is None - assert call_kwargs["exclude_ids"] == ["conv-3"] - - -def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None: - # Arrange - mock_db = mocker.patch("services.web_conversation_service.db") - mocker.patch("services.web_conversation_service.ConversationService.get_conversation") - - # Act - WebConversationService.pin(app_model, "conv-1", None) - - # Assert - mock_db.session.add.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_pin_should_return_early_when_conversation_is_already_pinned( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_account_cls = type("FakeAccount", (), {}) - fake_user = cast(Account, fake_account_cls()) - fake_user.id = "account-1" - mocker.patch("services.web_conversation_service.Account", fake_account_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - 
mock_db.session.query.return_value.where.return_value.first.return_value = object() - mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation") - - # Act - WebConversationService.pin(app_model, "conv-1", fake_user) - - # Assert - mock_get_conversation.assert_not_called() - mock_db.session.add.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_pin_should_create_pinned_conversation_when_not_already_pinned( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_account_cls = type("FakeAccount", (), {}) - fake_user = cast(Account, fake_account_cls()) - fake_user.id = "account-2" - mocker.patch("services.web_conversation_service.Account", fake_account_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - mock_db.session.query.return_value.where.return_value.first.return_value = None - mock_conversation = SimpleNamespace(id="conv-2") - mock_get_conversation = mocker.patch( - "services.web_conversation_service.ConversationService.get_conversation", - return_value=mock_conversation, - ) - - # Act - WebConversationService.pin(app_model, "conv-2", fake_user) - - # Assert - mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user) - added_obj = mock_db.session.add.call_args.args[0] - assert added_obj.app_id == "app-1" - assert added_obj.conversation_id == "conv-2" - assert added_obj.created_by_role == "account" - assert added_obj.created_by == "account-2" - mock_db.session.commit.assert_called_once() - - -def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None: - # Arrange - mock_db = mocker.patch("services.web_conversation_service.db") - - # Act - WebConversationService.unpin(app_model, "conv-1", None) - - # Assert - mock_db.session.delete.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def 
test_unpin_should_return_early_when_conversation_is_not_pinned( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_end_user_cls = type("FakeEndUser", (), {}) - fake_user = cast(EndUser, fake_end_user_cls()) - fake_user.id = "end-user-3" - mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) - mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act - WebConversationService.unpin(app_model, "conv-7", fake_user) - - # Assert - mock_db.session.delete.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_unpin_should_delete_pinned_conversation_when_exists( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_end_user_cls = type("FakeEndUser", (), {}) - fake_user = cast(EndUser, fake_end_user_cls()) - fake_user.id = "end-user-4" - mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) - mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - pinned_obj = SimpleNamespace(id="pin-1") - mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj - - # Act - WebConversationService.unpin(app_model, "conv-8", fake_user) - - # Assert - mock_db.session.delete.assert_called_once_with(pinned_obj) - mock_db.session.commit.assert_called_once() From 8ca1ebb96d905c3ff00a916596669c9cd2768a6c Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 12:50:10 -0500 Subject: [PATCH 04/32] test: migrate workflow tools manage service tests to testcontainers (#33955) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../test_workflow_tools_manage_service.py | 109 ++ .../test_workflow_tools_manage_service.py | 955 
------------------ 2 files changed, 109 insertions(+), 955 deletions(-) delete mode 100644 api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 34906a4e54..e3c0749494 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService: # After the fix, this should always be 0 # For now, we document that the record may exist, demonstrating the bug # assert tool_count == 0 # Expected after fix + + def test_delete_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of a workflow tool.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + tool_name = fake.unique.word() + + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + ) + + tool = ( + db_session_with_containers.query(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name) + .first() + ) + assert tool is not None + + result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == 
tool.id).first() + ) + assert deleted is None + + def test_list_tenant_workflow_tools_empty( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing workflow tools when none exist returns empty list.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id) + + assert result == [] + + def test_get_workflow_tool_by_tool_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_tool_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_get_workflow_tool_by_app_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_app_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_list_single_workflow_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that list_single_workflow_tools raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with 
pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4()) + + def test_create_workflow_tool_with_labels( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that labels are forwarded to ToolLabelManager when provided.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=fake.unique.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + labels=["label-1", "label-2"], + ) + + assert result == {"result": "success"} + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once() diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py deleted file mode 100644 index e9bcc89445..0000000000 --- a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py +++ /dev/null @@ -1,955 +0,0 @@ -""" -Unit tests for services.tools.workflow_tools_manage_service - -Covers WorkflowToolManageService: create, update, list, delete, get, list_single. 
-""" - -import json -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from models.model import App -from models.tools import WorkflowToolProvider -from services.tools import workflow_tools_manage_service -from services.tools.workflow_tools_manage_service import WorkflowToolManageService - -# --------------------------------------------------------------------------- -# Shared helpers / fake infrastructure -# --------------------------------------------------------------------------- - - -class DummyWorkflow: - """Minimal in-memory Workflow substitute.""" - - def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: - self._graph_dict = graph_dict - self.version = version - - @property - def graph_dict(self) -> dict: - return self._graph_dict - - -class FakeQuery: - """Chainable query object that always returns a fixed result.""" - - def __init__(self, result: object) -> None: - self._result = result - - def where(self, *args: object, **kwargs: object) -> "FakeQuery": - return self - - def first(self) -> object: - return self._result - - def delete(self) -> int: - return 1 - - -class DummySession: - """Minimal SQLAlchemy session substitute.""" - - def __init__(self) -> None: - self.added: list[WorkflowToolProvider] = [] - self.committed: bool = False - - def __enter__(self) -> "DummySession": - return self - - def __exit__(self, exc_type: object, exc: object, tb: object) -> bool: - return False - - def add(self, obj: WorkflowToolProvider) -> None: - self.added.append(obj) - - def begin(self) -> "DummySession": - return self - - def commit(self) -> None: - self.committed = True - - -def _build_parameters() -> list[WorkflowToolParameterConfiguration]: - return [ - WorkflowToolParameterConfiguration(name="input", description="input", 
form=ToolParameter.ToolParameterForm.LLM), - ] - - -def _build_fake_db( - *, - existing_tool: WorkflowToolProvider | None = None, - app: object | None = None, - tool_by_id: WorkflowToolProvider | None = None, -) -> tuple[MagicMock, DummySession]: - """ - Build a fake db object plus a DummySession for Session context-manager. - - query(WorkflowToolProvider) returns existing_tool on first call, - then tool_by_id on subsequent calls (or None if not provided). - query(App) returns app. - """ - call_counts: dict[str, int] = {"wftp": 0} - - def query(model: type) -> FakeQuery: - if model is WorkflowToolProvider: - call_counts["wftp"] += 1 - if call_counts["wftp"] == 1: - return FakeQuery(existing_tool) - return FakeQuery(tool_by_id) - if model is App: - return FakeQuery(app) - return FakeQuery(None) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query, commit=MagicMock()) - dummy_session = DummySession() - return fake_db, dummy_session - - -# --------------------------------------------------------------------------- -# TestCreateWorkflowTool -# --------------------------------------------------------------------------- - - -class TestCreateWorkflowTool: - """Tests for WorkflowToolManageService.create_workflow_tool.""" - - def test_should_raise_when_human_input_nodes_present(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Human-input nodes must be rejected before any provider is created.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "n1", "data": {"type": "human-input"}}]}) - app = SimpleNamespace(workflow=workflow) - fake_session = SimpleNamespace(query=lambda m: FakeQuery(None) if m is WorkflowToolProvider else FakeQuery(app)) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - - # Act + Assert - with 
pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon={"type": "emoji", "emoji": "🔧"}, - description="desc", - parameters=_build_parameters(), - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - mock_from_db.assert_not_called() - - def test_should_raise_when_duplicate_name_or_app_id(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Existing provider with same name or app_id raises ValueError.""" - # Arrange - existing = MagicMock(spec=WorkflowToolProvider) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(existing)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-1", - name="dup", - label="Dup", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the referenced App does not exist.""" - # Arrange - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(None) # App returns None - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="missing-app", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the App has no attached Workflow.""" - # Arrange - app_no_workflow = 
SimpleNamespace(workflow=None) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app_no_workflow) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="Workflow not found"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-id", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Exceptions from WorkflowToolProviderController.from_db are wrapped as ValueError.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query) - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - monkeypatch.setattr( - workflow_tools_manage_service.WorkflowToolProviderController, - "from_db", - MagicMock(side_effect=RuntimeError("bad config")), - ) - - # Act + Assert - with pytest.raises(ValueError, match="bad config"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-id", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_succeed_and_persist_provider(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: provider is added to session and success dict is returned.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": []}, version="2.0.0") - app = SimpleNamespace(workflow=workflow) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - 
return FakeQuery(None) - return FakeQuery(app) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query) - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - - icon = {"type": "emoji", "emoji": "🔧"} - - # Act - result = WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon=icon, - description="desc", - parameters=_build_parameters(), - ) - - # Assert - assert result == {"result": "success"} - assert len(dummy_session.added) == 1 - created: WorkflowToolProvider = dummy_session.added[0] - assert created.name == "tool_name" - assert created.label == "Tool" - assert created.icon == json.dumps(icon) - assert created.version == "2.0.0" - - def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Labels are forwarded to ToolLabelManager when provided.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query) - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - mock_label_mgr = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr) - mock_to_ctrl = MagicMock() - monkeypatch.setattr( - 
workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", mock_to_ctrl - ) - - # Act - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-id", - name="n", - label="L", - icon={}, - description="", - parameters=[], - labels=["tag1", "tag2"], - ) - - # Assert - mock_label_mgr.assert_called_once() - - -# --------------------------------------------------------------------------- -# TestUpdateWorkflowTool -# --------------------------------------------------------------------------- - - -class TestUpdateWorkflowTool: - """Tests for WorkflowToolManageService.update_workflow_tool.""" - - def _make_provider(self) -> WorkflowToolProvider: - p = MagicMock(spec=WorkflowToolProvider) - p.app_id = "app-id" - p.tenant_id = "tenant-id" - return p - - def test_should_raise_when_name_duplicated(self, monkeypatch: pytest.MonkeyPatch) -> None: - """If another tool with the given name already exists, raise ValueError.""" - # Arrange - existing = MagicMock(spec=WorkflowToolProvider) - - def query(m: type) -> FakeQuery: - return FakeQuery(existing) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="dup", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the workflow tool to update does not exist.""" - # Arrange - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - # 1st call: name uniqueness check → None (no duplicate) - # 2nd call: fetch tool by id → None (not found) - return FakeQuery(None) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - 
with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="missing", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the tool's referenced App has been removed.""" - # Arrange - provider = self._make_provider() - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - # 1st: duplicate name check (None), 2nd: fetch provider - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(None) # App not found - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the App exists but has no Workflow.""" - # Arrange - provider = self._make_provider() - app_no_wf = SimpleNamespace(workflow=None) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app_no_wf) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="Workflow not found"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_from_db_fails(self, monkeypatch: 
pytest.MonkeyPatch) -> None: - """Exceptions from from_db are re-raised as ValueError.""" - # Arrange - provider = self._make_provider() - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app) - - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=query, commit=MagicMock()), - ) - monkeypatch.setattr( - workflow_tools_manage_service.WorkflowToolProviderController, - "from_db", - MagicMock(side_effect=RuntimeError("from_db error")), - ) - - # Act + Assert - with pytest.raises(ValueError, match="from_db error"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_succeed_and_call_commit(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: provider fields are updated and session committed.""" - # Arrange - provider = self._make_provider() - workflow = DummyWorkflow(graph_dict={"nodes": []}, version="3.0.0") - app = SimpleNamespace(workflow=workflow) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app) - - mock_commit = MagicMock() - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=query, commit=mock_commit), - ) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - - icon = {"type": "emoji", "emoji": "🛠"} - - # Act - result = WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="new_name", - 
label="New Label", - icon=icon, - description="new desc", - parameters=_build_parameters(), - ) - - # Assert - assert result == {"result": "success"} - mock_commit.assert_called_once() - assert provider.name == "new_name" - assert provider.label == "New Label" - assert provider.icon == json.dumps(icon) - assert provider.version == "3.0.0" - - def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Labels are forwarded to ToolLabelManager during update.""" - # Arrange - provider = self._make_provider() - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app) - - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=query, commit=MagicMock()), - ) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - mock_label_mgr = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr) - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", MagicMock() - ) - - # Act - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - labels=["a"], - ) - - # Assert - mock_label_mgr.assert_called_once() - - -# --------------------------------------------------------------------------- -# TestListTenantWorkflowTools -# --------------------------------------------------------------------------- - - -class TestListTenantWorkflowTools: - """Tests for WorkflowToolManageService.list_tenant_workflow_tools.""" - - def 
test_should_return_empty_list_when_no_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: - """An empty database yields an empty result list.""" - # Arrange - fake_scalars = MagicMock() - fake_scalars.all.return_value = [] - fake_db = MagicMock() - fake_db.session.scalars.return_value = fake_scalars - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - # Act - result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") - - # Assert - assert result == [] - - def test_should_skip_broken_providers_and_log(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Providers that fail to load are logged and skipped.""" - # Arrange - good_provider = MagicMock(spec=WorkflowToolProvider) - good_provider.id = "good-id" - good_provider.app_id = "app-good" - bad_provider = MagicMock(spec=WorkflowToolProvider) - bad_provider.id = "bad-id" - bad_provider.app_id = "app-bad" - - fake_scalars = MagicMock() - fake_scalars.all.return_value = [good_provider, bad_provider] - fake_db = MagicMock() - fake_db.session.scalars.return_value = fake_scalars - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - good_ctrl = MagicMock() - good_ctrl.provider_id = "good-id" - - def to_controller(provider: WorkflowToolProvider) -> MagicMock: - if provider is bad_provider: - raise RuntimeError("broken provider") - return good_ctrl - - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", to_controller - ) - mock_get_labels = MagicMock(return_value={}) - monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", mock_get_labels) - mock_to_user = MagicMock() - mock_to_user.return_value.tools = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_user_provider", mock_to_user - ) - monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock()) - mock_get_tools = 
MagicMock(return_value=[MagicMock()]) - good_ctrl.get_tools = mock_get_tools - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock() - ) - - # Act - result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") - - # Assert - only good provider contributed - assert len(result) == 1 - - def test_should_return_tools_for_all_providers(self, monkeypatch: pytest.MonkeyPatch) -> None: - """All successfully loaded providers appear in the result.""" - # Arrange - provider = MagicMock(spec=WorkflowToolProvider) - provider.id = "p-1" - provider.app_id = "app-1" - - fake_scalars = MagicMock() - fake_scalars.all.return_value = [provider] - fake_db = MagicMock() - fake_db.session.scalars.return_value = fake_scalars - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - ctrl = MagicMock() - ctrl.provider_id = "p-1" - ctrl.get_tools.return_value = [MagicMock()] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - monkeypatch.setattr( - workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", MagicMock(return_value={"p-1": []}) - ) - user_provider = MagicMock() - user_provider.tools = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_user_provider", - MagicMock(return_value=user_provider), - ) - monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock()) - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock() - ) - - # Act - result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") - - # Assert - assert len(result) == 1 - assert result[0] is user_provider - - -# --------------------------------------------------------------------------- -# TestDeleteWorkflowTool -# 
--------------------------------------------------------------------------- - - -class TestDeleteWorkflowTool: - """Tests for WorkflowToolManageService.delete_workflow_tool.""" - - def test_should_delete_and_commit(self, monkeypatch: pytest.MonkeyPatch) -> None: - """delete_workflow_tool queries, deletes, commits, and returns success.""" - # Arrange - mock_query = MagicMock() - mock_query.where.return_value.delete.return_value = 1 - mock_commit = MagicMock() - fake_session = SimpleNamespace(query=lambda m: mock_query, commit=mock_commit) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) - - # Act - result = WorkflowToolManageService.delete_workflow_tool("u", "t", "tool-1") - - # Assert - assert result == {"result": "success"} - mock_commit.assert_called_once() - - -# --------------------------------------------------------------------------- -# TestGetWorkflowToolByToolId / ByAppId -# --------------------------------------------------------------------------- - - -class TestGetWorkflowToolByToolIdAndAppId: - """Tests for get_workflow_tool_by_tool_id and get_workflow_tool_by_app_id.""" - - def test_get_by_tool_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Raises ValueError when no WorkflowToolProvider found by tool id.""" - # Arrange - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="Tool not found"): - WorkflowToolManageService.get_workflow_tool_by_tool_id("u", "t", "missing") - - def test_get_by_app_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Raises ValueError when no WorkflowToolProvider found by app id.""" - # Arrange - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="Tool not 
found"): - WorkflowToolManageService.get_workflow_tool_by_app_id("u", "t", "missing-app") - - -# --------------------------------------------------------------------------- -# TestGetWorkflowTool (private _get_workflow_tool) -# --------------------------------------------------------------------------- - - -class TestGetWorkflowTool: - """Tests for the internal _get_workflow_tool helper.""" - - def test_should_raise_when_db_tool_none(self) -> None: - """_get_workflow_tool raises ValueError when db_tool is None.""" - with pytest.raises(ValueError, match="Tool not found"): - WorkflowToolManageService._get_workflow_tool("t", None) - - def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the corresponding App row is missing.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService._get_workflow_tool("t", db_tool) - - def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when App has no attached Workflow.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - app = SimpleNamespace(workflow=None) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(app)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="Workflow not found"): - WorkflowToolManageService._get_workflow_tool("t", db_tool) - - def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the controller returns no WorkflowTool instances.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" 
- db_tool.tenant_id = "t" - db_tool.id = "tool-1" - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(app)), - ) - ctrl = MagicMock() - ctrl.get_tools.return_value = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService._get_workflow_tool("t", db_tool) - - def test_should_return_dict_on_success(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: returns a dict with name, label, icon, synced, etc.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - db_tool.id = "tool-1" - db_tool.name = "my_tool" - db_tool.label = "My Tool" - db_tool.icon = json.dumps({"emoji": "🔧"}) - db_tool.description = "some desc" - db_tool.privacy_policy = "" - db_tool.version = "1.0" - db_tool.parameter_configurations = [] - workflow = DummyWorkflow(graph_dict={"nodes": []}, version="1.0") - app = SimpleNamespace(workflow=workflow) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(app)), - ) - - workflow_tool = MagicMock() - workflow_tool.entity.output_schema = {"type": "object"} - ctrl = MagicMock() - ctrl.get_tools.return_value = [workflow_tool] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - mock_convert = MagicMock(return_value={"tool": "api_entity"}) - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", mock_convert - ) - monkeypatch.setattr( - workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", 
MagicMock(return_value=[]) - ) - - # Act - result = WorkflowToolManageService._get_workflow_tool("t", db_tool) - - # Assert - assert result["name"] == "my_tool" - assert result["label"] == "My Tool" - assert result["synced"] is True - assert "icon" in result - assert "output_schema" in result - - -# --------------------------------------------------------------------------- -# TestListSingleWorkflowTools -# --------------------------------------------------------------------------- - - -class TestListSingleWorkflowTools: - """Tests for WorkflowToolManageService.list_single_workflow_tools.""" - - def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the specified tool does not exist in DB.""" - # Arrange - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") - - def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the controller yields no tools for the provider.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.id = "tool-1" - db_tool.tenant_id = "t" - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(db_tool)), - ) - ctrl = MagicMock() - ctrl.get_tools.return_value = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") - - def test_should_return_api_entity_list(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: returns list with one ToolApiEntity.""" - # Arrange - db_tool = 
MagicMock(spec=WorkflowToolProvider) - db_tool.id = "tool-1" - db_tool.tenant_id = "t" - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(db_tool)), - ) - workflow_tool = MagicMock() - ctrl = MagicMock() - ctrl.get_tools.return_value = [workflow_tool] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - api_entity = MagicMock() - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "convert_tool_entity_to_api_entity", - MagicMock(return_value=api_entity), - ) - monkeypatch.setattr( - workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[]) - ) - - # Act - result = WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") - - # Assert - assert result == [api_entity] From 75c3ef82d99b71ad30a5eed2fb6182185c229dc6 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:51:10 +0100 Subject: [PATCH 05/32] refactor: use EnumText for TenantCreditPool.pool_type (#33959) --- api/core/provider_manager.py | 4 ++-- api/models/model.py | 5 ++++- api/services/credit_pool_service.py | 6 +++++- .../services/test_credit_pool_service.py | 9 +++++---- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3c3fbd6dd2..6d2be0ab7a 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -918,11 +918,11 @@ class ProviderManager: trail_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.TRIAL.value, + pool_type=ProviderQuotaType.TRIAL, ) paid_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.PAID.value, + pool_type=ProviderQuotaType.PAID, ) else: trail_pool = None diff --git a/api/models/model.py b/api/models/model.py index 
331a5b7d8c..4541a3b23a 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -44,6 +44,7 @@ from .enums import ( MessageChainType, MessageFileBelongsTo, MessageStatus, + ProviderQuotaType, TagType, ) from .provider_ids import GenericProviderID @@ -2491,7 +2492,9 @@ class TenantCreditPool(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial") + pool_type: Mapped[ProviderQuotaType] = mapped_column( + EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial" + ) quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) created_at: Mapped[datetime] = mapped_column( diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 1954602571..2894826935 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -7,6 +7,7 @@ from configs import dify_config from core.errors.error import QuotaExceededError from extensions.ext_database import db from models import TenantCreditPool +from models.enums import ProviderQuotaType logger = logging.getLogger(__name__) @@ -16,7 +17,10 @@ class CreditPoolService: def create_default_pool(cls, tenant_id: str) -> TenantCreditPool: """create default credit pool for new tenant""" credit_pool = TenantCreditPool( - tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" + tenant_id=tenant_id, + quota_limit=dify_config.HOSTED_POOL_CREDITS, + quota_used=0, + pool_type=ProviderQuotaType.TRIAL, ) db.session.add(credit_pool) db.session.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py 
b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 25de0588fa..0f63d98642 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -6,6 +6,7 @@ import pytest from core.errors.error import QuotaExceededError from models import TenantCreditPool +from models.enums import ProviderQuotaType from services.credit_pool_service import CreditPoolService @@ -20,7 +21,7 @@ class TestCreditPoolService: assert isinstance(pool, TenantCreditPool) assert pool.tenant_id == tenant_id - assert pool.pool_type == "trial" + assert pool.pool_type == ProviderQuotaType.TRIAL assert pool.quota_used == 0 assert pool.quota_limit > 0 @@ -28,14 +29,14 @@ class TestCreditPoolService: tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) - result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type="trial") + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) assert result is not None assert result.tenant_id == tenant_id - assert result.pool_type == "trial" + assert result.pool_type == ProviderQuotaType.TRIAL def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): - result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type="trial") + result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) assert result is None From dd4f504b39d1afbfeaf51199789da779e006f654 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:53:05 +0100 Subject: [PATCH 06/32] refactor: select in remaining console app controllers (#33969) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/conversation.py | 10 +++------- api/controllers/console/app/generator.py | 2 +- 
api/controllers/console/app/mcp_server.py | 14 +++++++------- api/controllers/console/app/model_config.py | 4 +--- api/controllers/console/app/site.py | 5 +++-- api/controllers/console/app/wraps.py | 10 +++++----- .../controllers/console/app/test_app_apis.py | 8 ++------ .../console/app/test_conversation_api.py | 12 ++---------- .../app/test_conversation_read_timestamp.py | 2 +- .../controllers/console/app/test_generator_api.py | 12 ++++-------- .../console/app/test_model_config_api.py | 5 +---- .../controllers/console/app/test_wraps.py | 8 ++------ 12 files changed, 32 insertions(+), 60 deletions(-) diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 74750981dd..d329d22309 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -458,9 +458,7 @@ class ChatConversationApi(Resource): args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore subquery = ( - db.session.query( - Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") - ) + sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() ) @@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource): def _get_conversation(app_model, conversation_id): current_user, _ = current_account_with_tenant() - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() + conversation = db.session.scalar( + sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1) ) if not conversation: diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index af4ac450bb..442d0d2324 100644 --- a/api/controllers/console/app/generator.py +++ 
b/api/controllers/console/app/generator.py @@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource): try: # Generate from nothing for a workflow node if (args.current in (code_template, "")) and args.node_id != "": - app = db.session.query(App).where(App.id == args.flow_id).first() + app = db.session.get(App, args.flow_id) if not app: return {"error": f"app {args.flow_id} not found"}, 400 workflow = WorkflowService().get_draft_workflow(app_model=app) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 4b20418b53..412fc8795a 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,6 +2,7 @@ import json from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -47,7 +48,7 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_model) def get(self, app_model): - server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() + server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1)) return server @console_ns.doc("create_app_mcp_server") @@ -98,7 +99,7 @@ class AppMCPServerController(Resource): @edit_permission_required def put(self, app_model): payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) - server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first() + server = db.session.get(AppMCPServer, payload.id) if not server: raise NotFound() @@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource): @edit_permission_required def get(self, server_id): _, current_tenant_id = current_account_with_tenant() - server = ( - db.session.query(AppMCPServer) - .where(AppMCPServer.id == server_id) - .where(AppMCPServer.tenant_id == current_tenant_id) - .first() + server = 
db.session.scalar( + select(AppMCPServer) + .where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id) + .limit(1) ) if not server: raise NotFound() diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a85e54fb51..e9bd30ba7e 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -69,9 +69,7 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config - original_app_model_config = ( - db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() - ) + original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id) if original_app_model_config is None: raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index db218d8b81..7f44a99ff1 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -2,6 +2,7 @@ from typing import Literal from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from werkzeug.exceptions import NotFound from constants.languages import supported_language @@ -75,7 +76,7 @@ class AppSite(Resource): def post(self, app_model): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound @@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource): @marshal_with(app_site_model) def post(self, app_model): current_user, _ = current_account_with_tenant() - site = 
db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index e687d980fa..493022ffea 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar, Union +from sqlalchemy import select + from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -15,16 +17,14 @@ R1 = TypeVar("R1") def _load_app_model(app_id: str) -> App | None: _, current_tenant_id = current_account_with_tenant() - app_model = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app_model = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) return app_model def _load_app_model_with_trial(app_id: str) -> App | None: - app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first() + app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1)) return app_model diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py index 60b8ee96fe..beb8ff55a5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -281,12 +281,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: 
None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr( site_module, @@ -305,12 +303,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code") monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 5db8e5c332..11b3b3470d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: conversation = SimpleNamespace(id="c1", app_id="app-1") - query = MagicMock() - query.where.return_value = query - query.first.return_value = conversation - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = conversation monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) @@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr(conversation_module, 
"current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py index 460da06ecc..f588ab261d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged(): ), patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, ): - mock_session.query.return_value.where.return_value.first.return_value = conversation + mock_session.scalar.return_value = conversation _get_conversation(app_model, "conversation-id") diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py index f83bc18da3..e64c508b82 100644 --- a/api/tests/unit_tests/controllers/console/app/test_generator_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None)) with app.test_request_context( "/console/api/instruction-generate", @@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") 
- query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) _install_workflow_service(monkeypatch, workflow=None) with app.test_request_context( @@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace(graph_dict={"nodes": []}) _install_workflow_service(monkeypatch, workflow=workflow) @@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace( graph_dict={ diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py index 61d92bb5c7..a0e2edb8cf 100644 --- a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -92,10 +92,7 @@ 
def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc ) session = MagicMock() - query = MagicMock() - query.where.return_value = query - query.first.return_value = original_config - session.query.return_value = query + session.get.return_value = original_config monkeypatch.setattr(model_config_module.db, "session", session) monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_wraps.py b/api/tests/unit_tests/controllers/console/app/test_wraps.py index 7664e492da..b5f751f5a5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/app/test_wraps.py @@ -11,10 +11,8 @@ from models.model import AppMode def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model def handler(app_model): @@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda 
*_args, **_kwargs: app_model)) @wraps_module.get_app_model(mode=[AppMode.COMPLETION]) def handler(app_model): From 0492ed703457f186b5b7d29d4d8e813d088539c4 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 12:54:33 -0500 Subject: [PATCH 07/32] test: migrate api tools manage service tests to testcontainers (#33956) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../tools/test_api_tools_manage_service.py | 148 ++++ .../tools/test_api_tools_manage_service.py | 643 ------------------ 2 files changed, 148 insertions(+), 643 deletions(-) delete mode 100644 api/tests/unit_tests/services/tools/test_api_tools_manage_service.py diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index bffdca623a..d3e765055a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -536,3 +536,151 @@ class TestApiToolManageService: # Verify mock interactions mock_external_service_dependencies["encrypter"].assert_called_once() mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() + + def test_delete_api_tool_provider_success( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of an API tool provider.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": 
"none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + provider = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert provider is not None + + result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert deleted is None + + def test_delete_api_tool_provider_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test deletion raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") + + def test_update_api_tool_provider_not_found( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when original provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="does not exists"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name="new-name", + original_provider="nonexistent", + icon={}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def 
test_update_api_tool_provider_missing_auth_type( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when auth_type is missing from credentials.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + original_provider=provider_name, + icon={}, + credentials={}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_list_api_tool_provider_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing tools raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent") + + def test_test_api_tool_preview_invalid_schema_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test preview raises ValueError for invalid schema type.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + 
db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="invalid schema type"): + ApiToolManageService.test_api_tool_preview( + tenant_id=tenant.id, + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type="bad-schema-type", + schema="schema", + ) diff --git a/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py deleted file mode 100644 index ce44818886..0000000000 --- a/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py +++ /dev/null @@ -1,643 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from core.tools.entities.tool_entities import ApiProviderSchemaType -from services.tools.api_tools_manage_service import ApiToolManageService - - -@pytest.fixture -def mock_db(mocker: MockerFixture) -> MagicMock: - # Arrange - mocked_db = mocker.patch("services.tools.api_tools_manage_service.db") - mocked_db.session = MagicMock() - return mocked_db - - -def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace: - return SimpleNamespace(operation_id=operation_id) - - -def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value), - ) - - # Act - result = ApiToolManageService.parser_api_schema("valid-schema") - - # Assert - assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value - assert len(result["credentials_schema"]) == 3 - assert "warning" in result - - -def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None: - # Arrange - 
mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=RuntimeError("bad schema"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"): - ApiToolManageService.parser_api_schema("invalid") - - -def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None: - # Arrange - expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER) - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=expected, - ) - extra_info: dict[str, str] = {} - - # Act - result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info) - - # Assert - assert result == expected - - -def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=ValueError("parse failed"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema: parse failed"): - ApiToolManageService.convert_schema_to_tool_bundles("schema") - - -def test_create_api_tool_provider_should_raise_error_when_provider_already_exists( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="provider provider-a already exists"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name=" provider-a ", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def 
test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - many_tools = [_tool_bundle(str(i)) for i in range(101)] - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=(many_tools, ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="the number of apis should be less than 100"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_create_provider_when_input_is_valid( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], 
ApiProviderSchemaType.OPENAPI), - ) - mock_controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=mock_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.encrypt.return_value = {"auth_type": "none"} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") - - # Act - result = ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=["news"], - ) - - # Assert - assert result == {"result": "success"} - mock_controller.load_bundled_tools.assert_called_once() - mock_db.session.add.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.get", - return_value=SimpleNamespace(status_code=200, text="schema-content"), - ) - mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True}) - - # Act - result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") - - # Assert - assert result == {"schema": "schema-content"} - - -@pytest.mark.parametrize("status_code", [400, 404, 500]) -def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid( - status_code: int, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.get", - return_value=SimpleNamespace(status_code=status_code, 
text="schema-content"), - ) - mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger") - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema, please check the url you provided"): - ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") - mock_logger.exception.assert_called_once() - - -def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found( - mock_db: MagicMock, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="you have not added provider provider-a"): - ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") - - -def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")]) - mock_db.session.query.return_value.where.return_value.first.return_value = provider - controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", - return_value=controller, - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"]) - mock_convert = mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", - side_effect=[{"name": "tool-a"}, {"name": "tool-b"}], - ) - - # Act - result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") - - # Assert - assert result == [{"name": "tool-a"}, {"name": "tool-b"}] - assert mock_convert.call_count == 2 - - -def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found( - mock_db: MagicMock, -) -> None: - # Arrange - 
mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="api provider provider-a does not exists"): - ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - original_provider="provider-a", - icon={}, - credentials={"auth_type": "none"}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy=None, - custom_disclaimer="custom", - labels=[], - ) - - -def test_update_api_tool_provider_should_raise_error_when_auth_type_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace(credentials={}, name="old") - mock_db.session.query.return_value.where.return_value.first.return_value = provider - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - original_provider="provider-a", - icon={}, - credentials={}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy=None, - custom_disclaimer="custom", - labels=[], - ) - - -def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace( - credentials={"auth_type": "none", "api_key_value": "encrypted-old"}, - name="old", - icon="", - schema="", - description="", - schema_type_str="", - tools_str="", - privacy_policy="", - custom_disclaimer="", - credentials_str="", - ) - mock_db.session.query.return_value.where.return_value.first.return_value = provider - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - 
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=controller, - ) - cache = MagicMock() - encrypter = MagicMock() - encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"} - encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"} - encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(encrypter, cache), - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") - - # Act - result = ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-new", - original_provider="provider-old", - icon={"emoji": "E"}, - credentials={"auth_type": "none", "api_key_value": "***"}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=["news"], - ) - - # Assert - assert result == {"result": "success"} - assert provider.name == "provider-new" - assert provider.privacy_policy == "privacy" - assert provider.credentials_str != "" - cache.delete.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="you have not added provider provider-a"): - ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") - - -def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None: - # Arrange - provider = object() - 
mock_db.session.query.return_value.where.return_value.first.return_value = provider - - # Act - result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") - - # Assert - assert result == {"result": "success"} - mock_db.session.delete.assert_called_once_with(provider) - mock_db.session.commit.assert_called_once() - - -def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None: - # Arrange - expected = {"provider": "value"} - mock_get = mocker.patch( - "services.tools.api_tools_manage_service.ToolManager.user_get_api_provider", - return_value=expected, - ) - - # Act - result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a") - - # Assert - assert result == expected - mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1") - - -def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None: - # Arrange - schema_type = "bad-schema-type" - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema type"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=schema_type, # type: ignore[arg-type] - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=RuntimeError("invalid"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid( 
- mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") - - # Act + Assert - with pytest.raises(ValueError, match="invalid tool name tool-b"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-b", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_auth_type_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) - mock_db.session.query.return_value.where.return_value.first.return_value = db_provider - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - 
provider_controller = MagicMock() - tool_obj = MagicMock() - tool_obj.fork_tool_runtime.return_value = tool_obj - tool_obj.validate_credentials.side_effect = ValueError("validation failed") - provider_controller.get_tool.return_value = tool_obj - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=provider_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.decrypt.return_value = {"auth_type": "none"} - mock_encrypter.mask_plugin_credentials.return_value = {} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - - # Act - result = ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - # Assert - assert result == {"error": "validation failed"} - - -def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) - mock_db.session.query.return_value.where.return_value.first.return_value = db_provider - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - provider_controller = MagicMock() - tool_obj = MagicMock() - tool_obj.fork_tool_runtime.return_value = tool_obj - tool_obj.validate_credentials.return_value = {"ok": True} - provider_controller.get_tool.return_value = tool_obj - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=provider_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.decrypt.return_value = {"auth_type": "none"} - 
mock_encrypter.mask_plugin_credentials.return_value = {} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - - # Act - result = ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={"x": "1"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - # Assert - assert result == {"result": {"ok": True}} - - -def test_list_api_tools_should_return_all_user_providers_with_converted_tools( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider_one = SimpleNamespace(name="p1") - provider_two = SimpleNamespace(name="p2") - mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two] - - controller_one = MagicMock() - controller_one.get_tools.return_value = ["tool-a"] - controller_two = MagicMock() - controller_two.get_tools.return_value = ["tool-b", "tool-c"] - - user_provider_one = SimpleNamespace(labels=[], tools=[]) - user_provider_two = SimpleNamespace(labels=[], tools=[]) - - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", - side_effect=[controller_one, controller_two], - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"]) - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider", - side_effect=[user_provider_one, user_provider_two], - ) - mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider") - mock_convert = mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", - side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}], - ) - - # Act - result = ApiToolManageService.list_api_tools("tenant-1") - - # Assert - 
assert len(result) == 2 - assert user_provider_one.tools == [{"name": "tool-a"}] - assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}] - assert mock_convert.call_count == 3 From f2c71f3668227097f8d3770d1c28020bf17d51a5 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 13:15:22 -0500 Subject: [PATCH 08/32] test: migrate oauth server service tests to testcontainers (#33958) --- .../services/test_oauth_server_service.py | 174 ++++++++++++++ .../services/test_oauth_server_service.py | 224 ------------------ 2 files changed, 174 insertions(+), 224 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/services/test_oauth_server_service.py delete mode 100644 api/tests/unit_tests/services/test_oauth_server_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py new file mode 100644 index 0000000000..c146a5924b --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -0,0 +1,174 @@ +"""Testcontainers integration tests for OAuthServerService.""" + +from __future__ import annotations + +import uuid +from typing import cast +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import BadRequest + +from models.model import OAuthProviderApp +from services.oauth_server import ( + OAUTH_ACCESS_TOKEN_EXPIRES_IN, + OAUTH_ACCESS_TOKEN_REDIS_KEY, + OAUTH_AUTHORIZATION_CODE_REDIS_KEY, + OAUTH_REFRESH_TOKEN_EXPIRES_IN, + OAUTH_REFRESH_TOKEN_REDIS_KEY, + OAuthGrantType, + OAuthServerService, +) + + +class TestOAuthServerServiceGetProviderApp: + """DB-backed tests for get_oauth_provider_app.""" + + def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + app = OAuthProviderApp( + app_icon="icon.png", + client_id=client_id, + 
client_secret=str(uuid4()), + app_label={"en-US": "Test OAuth App"}, + redirect_uris=["https://example.com/callback"], + scope="read", + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + client_id = f"client-{uuid4()}" + created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) + + result = OAuthServerService.get_oauth_provider_app(client_id) + + assert result is not None + assert result.client_id == client_id + assert result.id == created.id + + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") + + assert result is None + + +class TestOAuthServerServiceTokenOperations: + """Redis-backed tests for token sign/validate operations.""" + + @pytest.fixture + def mock_redis(self): + with patch("services.oauth_server.redis_client") as mock: + yield mock + + def test_sign_authorization_code_stores_and_returns_code(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") + + assert code == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code), + "user-1", + ex=600, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid code"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="bad-code", + client_id="client-1", + ) + + def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis): + token_uuids = [ + 
uuid.UUID("00000000-0000-0000-0000-000000000201"), + uuid.UUID("00000000-0000-0000-0000-000000000202"), + ] + with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids): + mock_redis.get.return_value = b"user-1" + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="code-1", + client_id="client-1", + ) + + assert access_token == str(token_uuids[0]) + assert refresh_token == str(token_uuids[1]) + code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") + mock_redis.delete.assert_called_once_with(code_key) + mock_redis.set.assert_any_call( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + mock_redis.set.assert_any_call( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), + b"user-1", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid refresh token"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="stale-token", + client_id="client-1", + ) + + def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + mock_redis.get.return_value = b"user-1" + + access_token, returned_refresh = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="refresh-1", + client_id="client-1", + ) + + assert access_token == str(deterministic_uuid) + assert returned_refresh == "refresh-1" + + def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis): + grant_type = cast(OAuthGrantType, 
"invalid-grant-type") + + result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1") + + assert result is None + + def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") + + assert refresh_token == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), + "user-2", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_validate_access_token_returns_none_when_not_found(self, mock_redis): + mock_redis.get.return_value = None + + result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") + + assert result is None + + def test_validate_access_token_loads_user_when_exists(self, mock_redis): + mock_redis.get.return_value = b"user-88" + expected_user = MagicMock() + + with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load: + result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") + + assert result is expected_user + mock_load.assert_called_once_with("user-88") diff --git a/api/tests/unit_tests/services/test_oauth_server_service.py b/api/tests/unit_tests/services/test_oauth_server_service.py deleted file mode 100644 index 231ceb74dc..0000000000 --- a/api/tests/unit_tests/services/test_oauth_server_service.py +++ /dev/null @@ -1,224 +0,0 @@ -from __future__ import annotations - -import uuid -from types import SimpleNamespace -from typing import cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture -from werkzeug.exceptions import BadRequest - -from services.oauth_server import ( - OAUTH_ACCESS_TOKEN_EXPIRES_IN, - OAUTH_ACCESS_TOKEN_REDIS_KEY, - 
OAUTH_AUTHORIZATION_CODE_REDIS_KEY, - OAUTH_REFRESH_TOKEN_EXPIRES_IN, - OAUTH_REFRESH_TOKEN_REDIS_KEY, - OAuthGrantType, - OAuthServerService, -) - - -@pytest.fixture -def mock_redis_client(mocker: MockerFixture) -> MagicMock: - return mocker.patch("services.oauth_server.redis_client") - - -@pytest.fixture -def mock_session(mocker: MockerFixture) -> MagicMock: - """Mock the OAuth server Session context manager.""" - mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object())) - session = MagicMock() - session_cm = MagicMock() - session_cm.__enter__.return_value = session - mocker.patch("services.oauth_server.Session", return_value=session_cm) - return session - - -def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None: - # Arrange - mock_execute_result = MagicMock() - expected_app = MagicMock() - mock_execute_result.scalar_one_or_none.return_value = expected_app - mock_session.execute.return_value = mock_execute_result - - # Act - result = OAuthServerService.get_oauth_provider_app("client-1") - - # Assert - assert result is expected_app - mock_session.execute.assert_called_once() - mock_execute_result.scalar_one_or_none.assert_called_once() - - -def test_sign_oauth_authorization_code_should_store_code_and_return_value( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - - # Act - code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") - - # Assert - expected_code = str(deterministic_uuid) - assert code == expected_code - mock_redis_client.set.assert_called_once_with( - OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code), - "user-1", - ex=600, - ) - - -def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid( - mock_redis_client: 
MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act + Assert - with pytest.raises(BadRequest, match="invalid code"): - OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - code="bad-code", - client_id="client-1", - ) - - -def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - token_uuids = [ - uuid.UUID("00000000-0000-0000-0000-000000000201"), - uuid.UUID("00000000-0000-0000-0000-000000000202"), - ] - mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids) - mock_redis_client.get.return_value = b"user-1" - code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") - - # Act - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - code="code-1", - client_id="client-1", - ) - - # Assert - assert access_token == str(token_uuids[0]) - assert refresh_token == str(token_uuids[1]) - mock_redis_client.delete.assert_called_once_with(code_key) - mock_redis_client.set.assert_any_call( - OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), - b"user-1", - ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, - ) - mock_redis_client.set.assert_any_call( - OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), - b"user-1", - ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, - ) - - -def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act + Assert - with pytest.raises(BadRequest, match="invalid refresh token"): - OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.REFRESH_TOKEN, - refresh_token="stale-token", - client_id="client-1", - ) - - -def 
test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - mock_redis_client.get.return_value = b"user-1" - - # Act - access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.REFRESH_TOKEN, - refresh_token="refresh-1", - client_id="client-1", - ) - - # Assert - assert access_token == str(deterministic_uuid) - assert returned_refresh_token == "refresh-1" - mock_redis_client.set.assert_called_once_with( - OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), - b"user-1", - ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, - ) - - -def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None: - # Arrange - grant_type = cast(OAuthGrantType, "invalid-grant-type") - - # Act - result = OAuthServerService.sign_oauth_access_token( - grant_type=grant_type, - client_id="client-1", - ) - - # Assert - assert result is None - - -def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - - # Act - refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") - - # Assert - assert refresh_token == str(deterministic_uuid) - mock_redis_client.set.assert_called_once_with( - OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), - "user-2", - ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, - ) - - -def test_validate_oauth_access_token_should_return_none_when_token_not_found( - mock_redis_client: MagicMock, -) -> None: - # Arrange - 
mock_redis_client.get.return_value = None - - # Act - result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") - - # Assert - assert result is None - - -def test_validate_oauth_access_token_should_load_user_when_token_exists( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - mock_redis_client.get.return_value = b"user-88" - expected_user = MagicMock() - mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user) - - # Act - result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") - - # Assert - assert result is expected_user - mock_load_user.assert_called_once_with("user-88") From 5d2cb3cd803c0ef152c4c90a1e2752297538f175 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:37:51 +0100 Subject: [PATCH 09/32] refactor: use EnumText for DocumentSegment.type (#33979) --- api/models/dataset.py | 5 ++++- api/models/enums.py | 7 +++++++ api/services/dataset_service.py | 5 +++-- api/tests/unit_tests/services/segment_service.py | 3 ++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/api/models/dataset.py b/api/models/dataset.py index d0163e6984..e3cbbf9cb9 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -43,6 +43,7 @@ from .enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, SummaryStatus, ) from .model import App, Tag, TagBinding, UploadFile @@ -998,7 +999,9 @@ class ChildChunk(Base): # indexing fields index_node_id = mapped_column(String(255), nullable=True) index_node_hash = mapped_column(String(255), nullable=True) - type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'")) + type: Mapped[SegmentType] = mapped_column( + EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'") + ) created_by = mapped_column(StringUUID, nullable=False) created_at: 
Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) diff --git a/api/models/enums.py b/api/models/enums.py index 8aca1df2b4..cdec7b2f12 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum): TIME = "time" +class SegmentType(StrEnum): + """Document segment type""" + + AUTOMATIC = "automatic" + CUSTOMIZED = "customized" + + class SegmentStatus(StrEnum): """Document segment status""" diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index cdab90a3dc..ba4ab6757f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -58,6 +58,7 @@ from models.enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, ) from models.model import UploadFile from models.provider_ids import ModelProviderID @@ -3786,7 +3787,7 @@ class SegmentService: child_chunk.word_count = len(child_chunk.content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED update_child_chunks.append(child_chunk) else: new_child_chunks_args.append(child_chunk_update_args) @@ -3845,7 +3846,7 @@ class SegmentService: child_chunk.word_count = len(content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED db.session.add(child_chunk) VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) db.session.commit() diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index affbc8d0b5..cc2c0a8032 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -4,6 +4,7 @@ import pytest from models.account import Account from models.dataset 
import ChildChunk, Dataset, Document, DocumentSegment +from models.enums import SegmentType from services.dataset_service import SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError @@ -77,7 +78,7 @@ class SegmentTestDataFactory: chunk.word_count = word_count chunk.index_node_id = f"node-{chunk_id}" chunk.index_node_hash = "hash-123" - chunk.type = "automatic" + chunk.type = SegmentType.AUTOMATIC chunk.created_by = "user-123" chunk.updated_by = None chunk.updated_at = None From cc17c8e883accea6a07a8be059b61da0fa4e2455 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:38:29 +0100 Subject: [PATCH 10/32] refactor: use EnumText for TidbAuthBinding.status and MessageFile.type (#33975) --- .../vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py | 3 ++- .../rag/datasource/vdb/tidb_on_qdrant/tidb_service.py | 3 ++- api/models/dataset.py | 5 ++++- api/models/model.py | 4 ++-- api/schedule/create_tidb_serverless_task.py | 3 ++- api/schedule/update_tidb_serverless_status_task.py | 6 +++++- .../services/test_messages_clean_service.py | 3 ++- .../app/task_pipeline/test_easy_ui_message_end_files.py | 8 ++++---- 8 files changed, 23 insertions(+), 12 deletions(-) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 71b6fa0a9b..3c1d5e015f 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -33,6 +33,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, TidbAuthBinding +from models.enums import TidbAuthBindingStatus if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ 
-452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): password=new_cluster["password"], tenant_id=dataset.tenant_id, active=True, - status="ACTIVE", + status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 754c149241..06b17b9e62 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -9,6 +9,7 @@ from configs import dify_config from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus class TidbService: @@ -170,7 +171,7 @@ class TidbService: userPrefix = item["userPrefix"] if state == "ACTIVE" and len(userPrefix) > 0: cluster_info = tidb_serverless_list_map[item["clusterId"]] - cluster_info.status = "ACTIVE" + cluster_info.status = TidbAuthBindingStatus.ACTIVE cluster_info.account = f"{userPrefix}.root" db.session.add(cluster_info) db.session.commit() diff --git a/api/models/dataset.py b/api/models/dataset.py index e3cbbf9cb9..4c6152ed3f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -45,6 +45,7 @@ from .enums import ( SegmentStatus, SegmentType, SummaryStatus, + TidbAuthBindingStatus, ) from .model import App, Tag, TagBinding, UploadFile from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index @@ -1242,7 +1243,9 @@ class TidbAuthBinding(TypeBase): cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) + status: 
Mapped[TidbAuthBindingStatus] = mapped_column( + EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'") + ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/model.py b/api/models/model.py index 4541a3b23a..05233f8711 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -21,7 +21,7 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from dify_graph.file import helpers as file_helpers from extensions.storage.storage_type import StorageType from libs.helper import generate_string # type: ignore[import-not-found] @@ -1785,7 +1785,7 @@ class MessageFile(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False) transfer_method: Mapped[FileTransferMethod] = mapped_column( EnumText(FileTransferMethod, length=255), nullable=False ) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 8b9d973d6d..6ceb3ef856 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -8,6 +8,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus 
@app.celery.task(queue="dataset") @@ -57,7 +58,7 @@ def create_clusters(batch_size): account=new_cluster["account"], password=new_cluster["password"], active=False, - status="CREATING", + status=TidbAuthBindingStatus.CREATING, ) db.session.add(tidb_auth_binding) db.session.commit() diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 1befa0e8b5..10003b1b97 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -9,6 +9,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -18,7 +19,10 @@ def update_tidb_serverless_status_task(): try: # check the number of idle tidb serverless tidb_serverless_list = db.session.scalars( - select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + select(TidbAuthBinding).where( + TidbAuthBinding.active == False, + TidbAuthBinding.status == TidbAuthBindingStatus.CREATING, + ) ).all() if len(tidb_serverless_list) == 0: return diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 8707f2e827..57bbc73b50 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from dify_graph.file.enums import FileType from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -253,7 +254,7 @@ class 
TestMessagesCleanServiceIntegration: # MessageFile file = MessageFile( message_id=message.id, - type="image", + type=FileType.IMAGE, transfer_method="local_file", url="http://example.com/test.jpg", belongs_to=MessageFileBelongsTo.USER, diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 582990c88a..37dd116470 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -21,7 +21,7 @@ from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from dify_graph.file.enums import FileTransferMethod +from dify_graph.file.enums import FileTransferMethod, FileType from models.model import MessageFile, UploadFile @@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.LOCAL_FILE message_file.upload_file_id = str(uuid.uuid4()) message_file.url = None - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.REMOTE_URL message_file.upload_file_id = None message_file.url = "https://example.com/image.jpg" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.TOOL_FILE message_file.upload_file_id = None message_file.url = "tool_file_123.png" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture From 49a1fae55561d8924df19f7085bef4a5a944264d Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 
23 Mar 2026 16:04:34 -0500 Subject: [PATCH 11/32] test: migrate password reset tests to testcontainers (#33974) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../console/auth/test_password_reset.py | 109 +++--------------- 1 file changed, 17 insertions(+), 92 deletions(-) rename api/tests/{unit_tests => test_containers_integration_tests}/controllers/console/auth/test_password_reset.py (81%) diff --git a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py similarity index 81% rename from api/tests/unit_tests/controllers/console/auth/test_password_reset.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 9488cf528e..8f9db287e3 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -1,17 +1,10 @@ -""" -Test suite for password reset authentication flows. 
+"""Testcontainers integration tests for password reset authentication flows.""" -This module tests the password reset mechanism including: -- Password reset email sending -- Verification code validation -- Password reset with token -- Rate limiting and security checks -""" +from __future__ import annotations from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import ( from controllers.console.error import AccountNotFound, EmailSendIpLimitError -@pytest.fixture(autouse=True) -def _mock_forgot_password_session(): - with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - mock_session_cls.return_value.__exit__.return_value = None - yield mock_session - - -@pytest.fixture(autouse=True) -def _mock_forgot_password_db(): - with patch("controllers.console.auth.forgot_password.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi: mock_send_email, 
mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, ): - """ - Test successful password reset email sending. - - Verifies that: - - Email is sent to valid account - - Reset token is generated and returned - - IP rate limiting is checked - """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "reset_token_123" @@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi: assert response["data"] == "reset_token_123" mock_send_email.assert_called_once() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): """ Test password reset email blocked by IP rate limit. @@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi: - No email is sent when rate limited """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi: (None, "en-US"), # Defaults to en-US when not provided ], ) - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, language_input, @@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi: - Unsupported languages default to en-US """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() 
mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "token" @@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): """ @@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi: - Rate limit is reset on success """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_generate_token.return_value = (None, "new_token") @@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi: ) mock_reset_rate_limit.assert_called_once_with("test@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", 
"code": "999888"} mock_generate_token.return_value = (None, "fresh-token") @@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token.assert_called_once_with("upper_token") mock_reset_rate_limit.assert_called_once_with("user@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app): """ Test code verification blocked by rate limit. @@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi: - Prevents brute force attacks on verification codes """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True # Act & Assert @@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(EmailPasswordResetLimitError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with invalid token. 
@@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = None @@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with mismatched email. @@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi: - Prevents token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} @@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidEmailError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): """ Test code verification with incorrect code. 
@@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi: - Rate limit counter is incremented """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} @@ -380,11 +321,8 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -394,7 +332,6 @@ class TestForgotPasswordResetApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -405,7 +342,6 @@ class TestForgotPasswordResetApi: mock_get_account, mock_revoke_token, mock_get_data, - mock_wraps_db, app, mock_account, ): @@ -418,7 +354,6 @@ class TestForgotPasswordResetApi: - Success response is returned """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -436,9 +371,8 @@ class TestForgotPasswordResetApi: assert response["result"] == "success" mock_revoke_token.assert_called_once_with("valid_token") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, mock_db, app): + def test_reset_password_mismatch(self, mock_get_data, app): """ 
Test password reset with mismatched passwords. @@ -447,7 +381,6 @@ class TestForgotPasswordResetApi: - No password update occurs """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} # Act & Assert @@ -460,9 +393,8 @@ class TestForgotPasswordResetApi: with pytest.raises(PasswordMismatchError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, mock_db, app): + def test_reset_password_invalid_token(self, mock_get_data, app): """ Test password reset with invalid token. @@ -470,7 +402,6 @@ class TestForgotPasswordResetApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -483,9 +414,8 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app): + def test_reset_password_wrong_phase(self, mock_get_data, app): """ Test password reset with token not in reset phase. 
@@ -494,7 +424,6 @@ class TestForgotPasswordResetApi: - Prevents use of verification-phase tokens for reset """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"} # Act & Assert @@ -507,13 +436,10 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found( - self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app - ): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): """ Test password reset for non-existent account. 
@@ -521,7 +447,6 @@ class TestForgotPasswordResetApi: - AccountNotFound is raised when account doesn't exist """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"} mock_get_account.return_value = None From 075b8bf1aeac34195c9e72049b04cb80613a59bc Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 24 Mar 2026 10:04:08 +0800 Subject: [PATCH 12/32] fix(web): update account settings header (#33965) --- .../account-setting/__tests__/index.spec.tsx | 4 +-- .../header/account-setting/index.tsx | 31 +++++++------------ 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/web/app/components/header/account-setting/__tests__/index.spec.tsx b/web/app/components/header/account-setting/__tests__/index.spec.tsx index 2aa9db4771..279af0b114 100644 --- a/web/app/components/header/account-setting/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/__tests__/index.spec.tsx @@ -315,14 +315,14 @@ describe('AccountSetting', () => { it('should handle scroll event in panel', () => { // Act renderAccountSetting() - const scrollContainer = screen.getByRole('dialog').querySelector('.overflow-y-auto') + const scrollContainer = screen.getByRole('dialog').querySelector('.overscroll-contain') // Assert expect(scrollContainer).toBeInTheDocument() if (scrollContainer) { // Scroll down fireEvent.scroll(scrollContainer, { target: { scrollTop: 100 } }) - expect(scrollContainer).toHaveClass('overflow-y-auto') + expect(scrollContainer).toHaveClass('overscroll-contain') // Scroll back up fireEvent.scroll(scrollContainer, { target: { scrollTop: 0 } }) diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index 7e77af2e5f..bfceaeb059 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx 
@@ -1,8 +1,9 @@ 'use client' import type { AccountSettingTab } from '@/app/components/header/account-setting/constants' -import { useCallback, useEffect, useRef, useState } from 'react' +import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import SearchInput from '@/app/components/base/search-input' +import { ScrollArea } from '@/app/components/base/ui/scroll-area' import BillingPage from '@/app/components/billing/billing-page' import CustomPage from '@/app/components/custom/custom-page' import { @@ -129,20 +130,6 @@ export default function AccountSetting({ ], }, ] - const scrollRef = useRef(null) - const [scrolled, setScrolled] = useState(false) - useEffect(() => { - const targetElement = scrollRef.current - const scrollHandle = (e: Event) => { - const userScrolled = (e.target as HTMLDivElement).scrollTop > 0 - setScrolled(userScrolled) - } - targetElement?.addEventListener('scroll', scrollHandle) - return () => { - targetElement?.removeEventListener('scroll', scrollHandle) - } - }, []) - const activeItem = [...menuItems[0].items, ...menuItems[1].items].find(item => item.key === activeMenu) const [searchValue, setSearchValue] = useState('') @@ -201,7 +188,7 @@ export default function AccountSetting({ } -
+
-
-
+ +
{activeItem?.name} {activeItem?.description && ( @@ -241,7 +234,7 @@ export default function AccountSetting({ {activeMenu === ACCOUNT_SETTING_TAB.CUSTOM && } {activeMenu === ACCOUNT_SETTING_TAB.LANGUAGE && }
-
+
From fbd558762dc7dc0304c4e9efaa169505732c6098 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 24 Mar 2026 10:36:48 +0800 Subject: [PATCH 13/32] fix: fix chunk not display in indexed document (#33942) --- .../__tests__/document-settings.spec.tsx | 59 +++++++++++++++++++ .../detail/settings/document-settings.tsx | 26 +++++++- web/models/datasets.ts | 6 +- web/service/knowledge/use-create-dataset.ts | 6 +- 4 files changed, 92 insertions(+), 5 deletions(-) diff --git a/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx b/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx index 4ac30289e1..bf516d432b 100644 --- a/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx +++ b/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx @@ -224,6 +224,20 @@ describe('DocumentSettings', () => { // Data source types describe('Data Source Types', () => { + it('should handle upload_file_id data source format', () => { + mockDocumentDetail = { + name: 'test-document', + data_source_type: 'upload_file', + data_source_info: { + upload_file_id: '4a807f05-45d6-4fc4-b7a8-b009a4568b36', + }, + } + + render() + + expect(screen.getByTestId('files-count')).toHaveTextContent('1') + }) + it('should handle legacy upload_file data source', () => { mockDocumentDetail = { name: 'test-document', @@ -307,6 +321,18 @@ describe('DocumentSettings', () => { expect(screen.getByTestId('files-count')).toHaveTextContent('0') }) + it('should handle empty data_source_info object', () => { + mockDocumentDetail = { + name: 'test-document', + data_source_type: 'upload_file', + data_source_info: {}, + } + + render() + + expect(screen.getByTestId('files-count')).toHaveTextContent('0') + }) + it('should maintain structure when rerendered', () => { const { rerender } = render( , @@ -317,4 +343,37 @@ describe('DocumentSettings', () => { 
expect(screen.getByTestId('step-two')).toBeInTheDocument() }) }) + + describe('Files Extraction Regression Tests', () => { + it('should correctly extract file ID from upload_file_id format', () => { + const fileId = '4a807f05-45d6-4fc4-b7a8-b009a4568b36' + mockDocumentDetail = { + name: 'test-document.pdf', + data_source_type: 'upload_file', + data_source_info: { + upload_file_id: fileId, + }, + } + + render() + + // Verify files array is populated with correct file ID + expect(screen.getByTestId('files-count')).toHaveTextContent('1') + }) + + it('should preserve document name when using upload_file_id format', () => { + const documentName = 'my-uploaded-document.txt' + mockDocumentDetail = { + name: documentName, + data_source_type: 'upload_file', + data_source_info: { + upload_file_id: 'some-file-id', + }, + } + + render() + + expect(screen.getByTestId('files-count')).toHaveTextContent('1') + }) + }) }) diff --git a/web/app/components/datasets/documents/detail/settings/document-settings.tsx b/web/app/components/datasets/documents/detail/settings/document-settings.tsx index bcbc149231..2b6cc77683 100644 --- a/web/app/components/datasets/documents/detail/settings/document-settings.tsx +++ b/web/app/components/datasets/documents/detail/settings/document-settings.tsx @@ -8,6 +8,7 @@ import type { LegacyDataSourceInfo, LocalFileInfo, OnlineDocumentInfo, + UploadFileIdInfo, WebsiteCrawlInfo, } from '@/models/datasets' import { useBoolean } from 'ahooks' @@ -61,6 +62,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { const dataSourceInfo = documentDetail?.data_source_info + // Type guards for DataSourceInfo union const isLegacyDataSourceInfo = (info: DataSourceInfo | undefined): info is LegacyDataSourceInfo => { return !!info && 'upload_file' in info } @@ -73,10 +75,15 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { const isLocalFileInfo = (info: DataSourceInfo | undefined): info is LocalFileInfo => 
{ return !!info && 'related_id' in info && 'transfer_method' in info } + const isUploadFileIdInfo = (info: DataSourceInfo | undefined): info is UploadFileIdInfo => { + return !!info && 'upload_file_id' in info + } + const legacyInfo = isLegacyDataSourceInfo(dataSourceInfo) ? dataSourceInfo : undefined const websiteInfo = isWebsiteCrawlInfo(dataSourceInfo) ? dataSourceInfo : undefined const onlineDocumentInfo = isOnlineDocumentInfo(dataSourceInfo) ? dataSourceInfo : undefined const localFileInfo = isLocalFileInfo(dataSourceInfo) ? dataSourceInfo : undefined + const uploadFileIdInfo = isUploadFileIdInfo(dataSourceInfo) ? dataSourceInfo : undefined const currentPage = useMemo(() => { if (legacyInfo) { @@ -101,8 +108,20 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { }, [documentDetail?.data_source_type, documentDetail?.name, legacyInfo, onlineDocumentInfo]) const files = useMemo(() => { - if (legacyInfo?.upload_file) - return [legacyInfo.upload_file as CustomFile] + // Handle upload_file_id format + if (uploadFileIdInfo) { + return [{ + id: uploadFileIdInfo.upload_file_id, + name: documentDetail?.name || '', + } as unknown as CustomFile] + } + + // Handle legacy upload_file format + if (legacyInfo?.upload_file) { + return [legacyInfo.upload_file as unknown as CustomFile] + } + + // Handle local file info format if (localFileInfo) { const { related_id, name, extension } = localFileInfo return [{ @@ -111,8 +130,9 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { extension, } as unknown as CustomFile] } + return [] - }, [legacyInfo?.upload_file, localFileInfo]) + }, [uploadFileIdInfo, legacyInfo?.upload_file, localFileInfo, documentDetail?.name]) const websitePages = useMemo(() => { if (!websiteInfo) diff --git a/web/models/datasets.ts b/web/models/datasets.ts index ed16e1a67c..e4793357f4 100644 --- a/web/models/datasets.ts +++ b/web/models/datasets.ts @@ -381,7 +381,11 @@ export type OnlineDriveInfo = 
{ type: 'file' | 'folder' } -export type DataSourceInfo = LegacyDataSourceInfo | LocalFileInfo | OnlineDocumentInfo | WebsiteCrawlInfo +export type UploadFileIdInfo = { + upload_file_id: string +} + +export type DataSourceInfo = LegacyDataSourceInfo | LocalFileInfo | OnlineDocumentInfo | WebsiteCrawlInfo | UploadFileIdInfo export type InitialDocumentDetail = { id: string diff --git a/web/service/knowledge/use-create-dataset.ts b/web/service/knowledge/use-create-dataset.ts index a0d55eeb99..297bb44827 100644 --- a/web/service/knowledge/use-create-dataset.ts +++ b/web/service/knowledge/use-create-dataset.ts @@ -91,11 +91,15 @@ const getFileIndexingEstimateParamsForFile = ({ processRule, dataset_id, }: GetFileIndexingEstimateParamsOptionFile): IndexingEstimateParams => { + const fileIds = files + .map(file => file.id) + .filter((id): id is string => Boolean(id)) + return { info_list: { data_source_type: dataSourceType, file_info_list: { - file_ids: files.map(file => file.id) as string[], + file_ids: fileIds, }, }, indexing_technique: indexingTechnique, From 27c4faad4f717624564d0f537ca865eaf8559099 Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Tue, 24 Mar 2026 10:52:27 +0800 Subject: [PATCH 14/32] ci: update actions version, fix cache (#33950) --- .github/actions/setup-web/action.yml | 9 +++---- .github/workflows/style.yml | 29 +++++++++++++-------- .github/workflows/translate-i18n-claude.yml | 2 +- web/package.json | 6 ++--- 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml index 6f3b3c08b4..24af948732 100644 --- a/.github/actions/setup-web/action.yml +++ b/.github/actions/setup-web/action.yml @@ -4,10 +4,9 @@ runs: using: composite steps: - name: Setup Vite+ - uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0 + uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0 with: - node-version-file: web/.nvmrc + working-directory: 
web + node-version-file: .nvmrc cache: true - cache-dependency-path: web/pnpm-lock.yaml - run-install: | - cwd: ./web + run-install: true diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 657a481f74..23ae36f7b1 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -84,20 +84,20 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' uses: ./.github/actions/setup-web + - name: Restore ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' + id: eslint-cache-restore + uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: | - vp run lint:ci - # pnpm run lint:report - # continue-on-error: true - - # - name: Annotate Code - # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request' - # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae - # with: - # eslint-report: web/eslint_report.json - # github-token: ${{ secrets.GITHUB_TOKEN }} + run: vp run lint:ci - name: Web tsslint if: steps.changed-files.outputs.any_changed == 'true' @@ -114,6 +114,13 @@ jobs: working-directory: ./web run: vp run knip + - name: Save ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' + uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ 
steps.eslint-cache-restore.outputs.cache-primary-key }} + superlinter: name: SuperLinter runs-on: ubuntu-latest diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 84f8000a01..1869254295 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -120,7 +120,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.detect_changes.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76 + uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/web/package.json b/web/package.json index 7d82b6afde..ee8dbb466e 100644 --- a/web/package.json +++ b/web/package.json @@ -39,9 +39,9 @@ "i18n:check": "tsx ./scripts/check-i18n.js", "knip": "knip", "lint": "eslint --cache --concurrency=auto", - "lint:ci": "eslint --cache --concurrency 2", - "lint:fix": "pnpm lint --fix", - "lint:quiet": "pnpm lint --quiet", + "lint:ci": "eslint --cache --cache-strategy content --concurrency 2", + "lint:fix": "vp run lint --fix", + "lint:quiet": "vp run lint --quiet", "lint:tss": "tsslint --project tsconfig.json", "preinstall": "npx only-allow pnpm", "prepare": "cd ../ && node -e \"if (process.env.NODE_ENV !== 'production'){process.exit(1)} \" || husky ./web/.husky", From 0589fa423bdee5a4e7f907fa2459bb956c6af48f Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 24 Mar 2026 11:24:31 +0800 Subject: [PATCH 15/32] fix(sdk): patch flatted vulnerability in nodejs client lockfile (#33996) --- sdks/nodejs-client/package.json | 1 + sdks/nodejs-client/pnpm-lock.yaml | 22 ++++++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index 7c8a293446..728aa0d054 
100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -69,6 +69,7 @@ }, "pnpm": { "overrides": { + "flatted@<=3.4.1": "3.4.2", "rollup@>=4.0.0,<4.59.0": "4.59.0" } } diff --git a/sdks/nodejs-client/pnpm-lock.yaml b/sdks/nodejs-client/pnpm-lock.yaml index c4b299cd73..c9081420f5 100644 --- a/sdks/nodejs-client/pnpm-lock.yaml +++ b/sdks/nodejs-client/pnpm-lock.yaml @@ -5,6 +5,7 @@ settings: excludeLinksFromLockfile: false overrides: + flatted@<=3.4.1: 3.4.2 rollup@>=4.0.0,<4.59.0: 4.59.0 importers: @@ -324,66 +325,79 @@ packages: resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==} cpu: [arm] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm-musleabihf@4.59.0': resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==} cpu: [arm] os: [linux] + libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.59.0': resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==} cpu: [arm64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.59.0': resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==} cpu: [arm64] os: [linux] + libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.59.0': resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==} cpu: [loong64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-loong64-musl@4.59.0': resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==} cpu: [loong64] os: [linux] + libc: [musl] '@rollup/rollup-linux-ppc64-gnu@4.59.0': resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==} cpu: [ppc64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-ppc64-musl@4.59.0': 
resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==} cpu: [ppc64] os: [linux] + libc: [musl] '@rollup/rollup-linux-riscv64-gnu@4.59.0': resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==} cpu: [riscv64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.59.0': resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==} cpu: [riscv64] os: [linux] + libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.59.0': resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==} cpu: [s390x] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-gnu@4.59.0': resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==} cpu: [x64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-musl@4.59.0': resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==} cpu: [x64] os: [linux] + libc: [musl] '@rollup/rollup-openbsd-x64@4.59.0': resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==} @@ -741,8 +755,8 @@ packages: resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==} engines: {node: '>=16'} - flatted@3.4.1: - resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==} + flatted@3.4.2: + resolution: {integrity: sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==} follow-redirects@1.15.11: resolution: {integrity: sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==} @@ -1836,10 +1850,10 @@ snapshots: flat-cache@4.0.1: 
dependencies: - flatted: 3.4.1 + flatted: 3.4.2 keyv: 4.5.4 - flatted@3.4.1: {} + flatted@3.4.2: {} follow-redirects@1.15.11: {} From ecd3a964c1f5599642c25c9c8d8e753e8763d07f Mon Sep 17 00:00:00 2001 From: BitToby <218712309+bittoby@users.noreply.github.com> Date: Tue, 24 Mar 2026 06:22:17 +0200 Subject: [PATCH 16/32] refactor(api): type auth service credentials with TypedDict (#33867) --- api/services/auth/api_key_auth_base.py | 10 +++++++++- api/services/auth/api_key_auth_factory.py | 4 ++-- api/services/auth/firecrawl/firecrawl.py | 4 ++-- api/services/auth/jina.py | 4 ++-- api/services/auth/jina/jina.py | 4 ++-- api/services/auth/watercrawl/watercrawl.py | 4 ++-- .../unit_tests/services/auth/test_api_key_auth_base.py | 6 +++--- .../services/auth/test_api_key_auth_factory.py | 4 ++-- 8 files changed, 24 insertions(+), 16 deletions(-) diff --git a/api/services/auth/api_key_auth_base.py b/api/services/auth/api_key_auth_base.py index dd74a8f1b5..2e1b723e82 100644 --- a/api/services/auth/api_key_auth_base.py +++ b/api/services/auth/api_key_auth_base.py @@ -1,8 +1,16 @@ from abc import ABC, abstractmethod +from typing import Any + +from typing_extensions import TypedDict + + +class AuthCredentials(TypedDict): + auth_type: str + config: dict[str, Any] class ApiKeyAuthBase(ABC): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): self.credentials = credentials @abstractmethod diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index 7ae31b0768..6e183b70e3 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,9 +1,9 @@ -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials from services.auth.auth_type import AuthType class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): + def __init__(self, provider: str, credentials: 
AuthCredentials): auth_factory = self.get_apikey_auth_factory(provider) self.auth = auth_factory(credentials) diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index b002706931..c9e5610aea 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class FirecrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/watercrawl/watercrawl.py 
b/api/services/auth/watercrawl/watercrawl.py index b2d28a83d1..cbdc908690 100644 --- a/api/services/auth/watercrawl/watercrawl.py +++ b/api/services/auth/watercrawl/watercrawl.py @@ -3,11 +3,11 @@ from urllib.parse import urljoin import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class WatercrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "x-api-key": diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_base.py b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py index b5d91ef3fb..388504c07f 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_base.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py @@ -13,13 +13,13 @@ class ConcreteApiKeyAuth(ApiKeyAuthBase): class TestApiKeyAuthBase: def test_should_store_credentials_on_init(self): """Test that credentials are properly stored during initialization""" - credentials = {"api_key": "test_key", "auth_type": "bearer"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} auth = ConcreteApiKeyAuth(credentials) assert auth.credentials == credentials def test_should_not_instantiate_abstract_class(self): """Test that ApiKeyAuthBase cannot be instantiated directly""" - credentials = {"api_key": "test_key"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} with pytest.raises(TypeError) as exc_info: ApiKeyAuthBase(credentials) @@ -29,7 +29,7 @@ class TestApiKeyAuthBase: def test_should_allow_subclass_implementation(self): """Test that subclasses can properly implement the abstract method""" - credentials = {"api_key": "test_key", "auth_type": "bearer"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} auth = ConcreteApiKeyAuth(credentials) # Should 
not raise any exception diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py index 60af6e20c2..b1f7cf24f3 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py @@ -58,7 +58,7 @@ class TestApiKeyAuthFactory: mock_get_factory.return_value = mock_auth_class # Act - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"}) + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}}) result = factory.validate_credentials() # Assert @@ -75,7 +75,7 @@ class TestApiKeyAuthFactory: mock_get_factory.return_value = mock_auth_class # Act & Assert - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"}) + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}}) with pytest.raises(Exception) as exc_info: factory.validate_credentials() assert str(exc_info.value) == "Authentication error" From 8b634a9bee3a3238f45540082459031dd155f0f2 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Tue, 24 Mar 2026 05:27:50 +0100 Subject: [PATCH 17/32] =?UTF-8?q?refactor:=20use=20EnumText=20for=20ApiToo?= =?UTF-8?q?lProvider.schema=5Ftype=5Fstr=20and=20Docume=E2=80=A6=20(#33983?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/commands/vector.py | 3 +- api/models/dataset.py | 4 +- api/models/tools.py | 4 +- api/services/dataset_service.py | 20 ++++---- .../rag_pipeline_transform_service.py | 9 ++-- .../batch_create_segment_to_index_task.py | 5 +- api/tasks/document_indexing_task.py | 3 +- api/tasks/regenerate_summary_index_task.py | 5 +- .../test_dataset_retrieval_integration.py | 15 +++--- .../services/document_service_status.py | 3 +- .../services/test_dataset_service.py | 3 +- 
...et_service_batch_update_document_status.py | 3 +- .../test_dataset_service_delete_dataset.py | 5 +- .../test_document_service_display_status.py | 3 +- .../test_document_service_rename_document.py | 3 +- .../services/test_metadata_service.py | 3 +- .../tools/test_tools_transform_service.py | 10 ++-- .../tasks/test_batch_clean_document_task.py | 16 ++++-- ...test_batch_create_segment_to_index_task.py | 19 +++---- .../tasks/test_clean_dataset_task.py | 3 +- .../tasks/test_clean_notion_document_task.py | 5 +- .../test_create_segment_to_index_task.py | 11 ++-- .../test_deal_dataset_vector_index_task.py | 51 ++++++++++--------- .../test_disable_segment_from_index_task.py | 9 +++- .../test_disable_segments_from_index_task.py | 9 +++- .../tasks/test_document_indexing_sync_task.py | 3 +- .../test_document_indexing_update_task.py | 3 +- .../test_duplicate_document_indexing_task.py | 7 +-- .../console/datasets/test_data_source.py | 3 +- .../console/datasets/test_datasets.py | 3 +- .../datasets/test_datasets_document.py | 19 +++---- .../datasets/test_datasets_segments.py | 5 +- .../controllers/service_api/conftest.py | 3 +- .../dataset/test_dataset_segment.py | 11 ++-- .../service_api/dataset/test_document.py | 33 ++++++++---- .../rag/retrieval/test_dataset_retrieval.py | 4 +- .../unit_tests/models/test_tool_models.py | 30 +++++------ .../services/document_service_validation.py | 15 +++--- .../unit_tests/services/segment_service.py | 7 +-- .../test_dataset_service_lock_not_owned.py | 7 +-- .../services/test_summary_index_service.py | 15 +++--- .../services/test_vector_service.py | 13 ++--- .../unit_tests/services/vector_service.py | 9 ++-- .../tasks/test_clean_dataset_task.py | 15 +++--- .../tasks/test_dataset_indexing_task.py | 3 +- .../tasks/test_document_indexing_sync_task.py | 3 +- 46 files changed, 255 insertions(+), 180 deletions(-) diff --git a/api/commands/vector.py b/api/commands/vector.py index 4cf11c9ad1..bef18bf73b 100644 --- a/api/commands/vector.py +++ 
b/api/commands/vector.py @@ -10,6 +10,7 @@ from configs import dify_config from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment @@ -269,7 +270,7 @@ def migrate_knowledge_vector_database(): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == "hierarchical_model": + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] diff --git a/api/models/dataset.py b/api/models/dataset.py index 4c6152ed3f..b4fb03a7f4 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -496,7 +496,9 @@ class Document(Base): ) doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True) doc_metadata = mapped_column(AdjustedJSON, nullable=True) - doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) + doc_form: Mapped[IndexStructureType] = mapped_column( + EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'") + ) doc_language = mapped_column(String(255), nullable=True) need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) diff --git a/api/models/tools.py b/api/models/tools.py index 01182af867..63b27b9413 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -145,7 +145,9 @@ class ApiToolProvider(TypeBase): icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema schema: Mapped[str] = mapped_column(LongText, nullable=False) - schema_type_str: Mapped[str] = 
mapped_column(String(40), nullable=False) + schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column( + EnumText(ApiProviderSchemaType, length=40), nullable=False + ) # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ba4ab6757f..65e112f1e9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1440,7 +1440,7 @@ class DocumentService: .filter( Document.id.in_(document_id_list), Document.dataset_id == dataset_id, - Document.doc_form != "qa_model", # Skip qa_model documents + Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .update({Document.need_summary: need_summary}, synchronize_session=False) ) @@ -2040,7 +2040,7 @@ class DocumentService: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = knowledge_config.doc_form + document.doc_form = IndexStructureType(knowledge_config.doc_form) document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch @@ -2640,7 +2640,7 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = document_data.doc_form + document.doc_form = IndexStructureType(document_data.doc_form) db.session.add(document) db.session.commit() # update document segment @@ -3101,7 +3101,7 @@ class DocumentService: class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") if not args["answer"].strip(): @@ -3158,7 +3158,7 @@ class SegmentService: 
completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] @@ -3232,7 +3232,7 @@ class SegmentService: tokens = 0 if dataset.indexing_technique == "high_quality" and embedding_model: # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: tokens = embedding_model.get_text_embedding_num_tokens( texts=[content + segment_item["answer"]] )[0] @@ -3255,7 +3255,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.answer = segment_item["answer"] segment_document.word_count += len(segment_item["answer"]) increment_word_count += segment_document.word_count @@ -3322,7 +3322,7 @@ class SegmentService: content = args.content or segment.content if segment.content == content: segment.word_count = len(content) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change @@ -3419,7 +3419,7 @@ class SegmentService: ) # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore else: @@ -3436,7 +3436,7 @@ class SegmentService: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - 
word_count_change diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 1d0aafd5fd..7dcfecdd1d 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -9,6 +9,7 @@ from flask_login import current_user from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory @@ -79,9 +80,9 @@ class RagPipelineTransformService: pipeline = self._create_pipeline(pipeline_yaml) # save chunk structure to dataset - if doc_form == "hierarchical_model": + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset.chunk_structure = "hierarchical_model" - elif doc_form == "text_model": + elif doc_form == IndexStructureType.PARAGRAPH_INDEX: dataset.chunk_structure = "text_model" else: raise ValueError("Unsupported doc form") @@ -101,7 +102,7 @@ class RagPipelineTransformService: def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None): pipeline_yaml = {} - if doc_form == "text_model": + if doc_form == IndexStructureType.PARAGRAPH_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: if indexing_technique == "high_quality": @@ -132,7 +133,7 @@ class RagPipelineTransformService: pipeline_yaml = yaml.safe_load(f) case _: raise ValueError("Unsupported datasource type") - elif doc_form == "hierarchical_model": + elif doc_form == IndexStructureType.PARENT_CHILD_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: # get graph from transform.file-parentchild.yml diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 49dee00919..7f810129ef 100644 --- 
a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -11,6 +11,7 @@ from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexStructureType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -109,7 +110,7 @@ def batch_create_segment_to_index_task( df = pd.read_csv(file_path) content = [] for _, row in df.iterrows(): - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: data = {"content": row.iloc[0], "answer": row.iloc[1]} else: data = {"content": row.iloc[0]} @@ -159,7 +160,7 @@ def batch_create_segment_to_index_task( status="completed", completed_at=naive_utc_now(), ) - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: segment_document.answer = segment["answer"] segment_document.word_count += len(segment["answer"]) word_count_change += segment_document.word_count diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index e05d63426c..b5794e33e2 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -10,6 +10,7 @@ from configs import dify_config from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now @@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): ) if ( document.indexing_status == IndexingStatus.COMPLETED - 
and document.doc_form != "qa_model" + and document.doc_form != IndexStructureType.QA_INDEX and document.need_summary is True ): try: diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index 39c2f4103e..ac5d23408a 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -9,6 +9,7 @@ from celery import shared_task from sqlalchemy import or_, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -106,7 +107,7 @@ def regenerate_summary_index_task( ), DatasetDocument.enabled == True, # Document must be enabled DatasetDocument.archived == False, # Document must not be archived - DatasetDocument.doc_form != "qa_model", # Skip qa_model documents + DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) .all() @@ -209,7 +210,7 @@ def regenerate_summary_index_task( for dataset_document in dataset_documents: # Skip qa_model documents - if dataset_document.doc_form == "qa_model": + if dataset_document.doc_form == IndexStructureType.QA_INDEX: continue try: diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 781e297fa4..ea8d04502a 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker 
+from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document @@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration: name=f"Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Archived Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Archived @@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Disabled Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=False, # Disabled archived=False, @@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {status}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=status, # Not completed enabled=True, archived=False, @@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document for {dataset.name}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {i}", created_by=account.id, - doc_form="text_model", + 
doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index f995ac7bef..42d587b7f7 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document @@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id document.indexing_status = indexing_status diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index ac3d9f9604..a484c7be87 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -11,6 +11,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from 
models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory: created_from=DocumentCreatedFrom.WEB, created_by=created_by, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.flush() diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index ab7e2a3f50..c1d088755c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by or str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id or str(uuid4()) document.enabled = enabled diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index ed070527c9..807d18322c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,6 +3,7 @@ from unittest.mock import patch 
from uuid import uuid4 +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory: tenant_id: str, dataset_id: str, created_by: str, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """Persist a document so dataset.doc_form resolves through the real document path.""" document = Document( @@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset: tenant_id=tenant.id, dataset_id=dataset.id, created_by=owner.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Act diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index 47d259d8a0..c0047df810 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -3,6 +3,7 @@ from uuid import uuid4 from sqlalchemy import select +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -42,7 +43,7 @@ def _create_document( name=f"doc-{uuid4()}", created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = str(uuid4()) document.indexing_status = indexing_status diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py 
b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py index bffa520ce6..34532ed7f8 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -7,6 +7,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document @@ -69,7 +70,7 @@ def make_document( name=name, created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) doc.id = document_id doc.indexing_status = "completed" diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index e847329c5b..8b1349be9a 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -5,6 +5,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom @@ -139,7 +140,7 @@ class TestMetadataService: name=fake.file_name(), created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py 
b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index 0f38218c51..7ab059bb75 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService @@ -52,7 +52,7 @@ class TestToolTransformService: user_id="test_user_id", credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) elif provider_type == "builtin": @@ -659,7 +659,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -695,7 +695,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -731,7 +731,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) diff --git 
a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 210d9eb39e..6cbbe43137 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -13,6 +13,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -152,7 +153,7 @@ class TestBatchCleanDocumentTask: created_from=DocumentCreatedFrom.WEB, created_by=account.id, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) @@ -392,7 +393,12 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Execute the task with non-existent dataset - batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) + batch_clean_document_task( + document_ids=[document_id], + dataset_id=dataset_id, + doc_form=IndexStructureType.PARAGRAPH_INDEX, + file_ids=[], + ) # Verify that no index processing occurred mock_external_service_dependencies["index_processor"].clean.assert_not_called() @@ -525,7 +531,11 @@ class TestBatchCleanDocumentTask: account = self._create_test_account(db_session_with_containers) # Test different doc_form types - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: dataset = self._create_test_dataset(db_session_with_containers, account) diff --git 
a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 202ccb0098..5ebf141828 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -19,6 +19,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -179,7 +180,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ) @@ -221,17 +222,17 @@ class TestBatchCreateSegmentToIndexTask: return upload_file - def _create_test_csv_content(self, content_type="text_model"): + def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX): """ Helper method to create test CSV content. 
Args: - content_type: Type of content to create ("text_model" or "qa_model") + content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX) Returns: str: CSV content as string """ - if content_type == "qa_model": + if content_type == IndexStructureType.QA_INDEX: csv_content = "content,answer\n" csv_content += "This is the first segment content,This is the first answer\n" csv_content += "This is the second segment content,This is the second answer\n" @@ -264,7 +265,7 @@ class TestBatchCreateSegmentToIndexTask: upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant) # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] @@ -451,7 +452,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=False, # Document is disabled archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Archived document @@ -467,7 +468,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Document is archived - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Document with incomplete indexing @@ -483,7 +484,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.INDEXING, # Not completed enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), ] @@ -655,7 +656,7 @@ class TestBatchCreateSegmentToIndexTask: db_session_with_containers.commit() # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV 
content mock_storage = mock_external_service_dependencies["storage"] diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 1cd698b870..9449fee0af 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -18,6 +18,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -192,7 +193,7 @@ class TestCleanDatasetTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=100, created_at=datetime.now(), updated_at=datetime.now(), diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index a2a190fd69..926c839c8b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -12,6 +12,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -114,7 +115,7 @@ class TestCleanNotionDocumentTask: name=f"Notion Page {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", # Set 
doc_form to ensure dataset.doc_form works + doc_form=IndexStructureType.PARAGRAPH_INDEX, # Set doc_form to ensure dataset.doc_form works doc_language="en", indexing_status=IndexingStatus.COMPLETED, ) @@ -261,7 +262,7 @@ class TestCleanNotionDocumentTask: # Test different index types # Note: Only testing text_model to avoid dependency on external services - index_types = ["text_model"] + index_types = [IndexStructureType.PARAGRAPH_INDEX] for index_type in index_types: # Create dataset (doc_form will be set via document creation) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 132f43c320..979435282b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -141,7 +142,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -301,7 +302,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -552,7 +553,11 @@ class TestCreateSegmentToIndexTask: - Processing completes successfully for different forms """ # Arrange: Test different doc_forms - doc_forms = 
["qa_model", "text_model", "web_model"] + doc_forms = [ + IndexStructureType.QA_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + ] for doc_form in doc_forms: # Create fresh test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index e80b37ac1b..d457b59d58 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -107,7 +108,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -167,7 +168,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -187,7 +188,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -268,7 +269,7 @@ class TestDealDatasetVectorIndexTask: 
name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -288,7 +289,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -416,7 +417,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -505,7 +506,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -525,7 +526,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -601,7 +602,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="qa_index", + doc_form=IndexStructureType.QA_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -638,7 +639,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type - mock_index_processor_factory.assert_called_once_with("qa_index") + 
mock_index_processor_factory.assert_called_once_with(IndexStructureType.QA_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -677,7 +678,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -714,7 +715,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type - mock_index_processor_factory.assert_called_once_with("text_model") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -753,7 +754,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -775,7 +776,7 @@ class TestDealDatasetVectorIndexTask: name=f"Test Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -856,7 +857,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -876,7 +877,7 @@ class TestDealDatasetVectorIndexTask: 
name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -953,7 +954,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -973,7 +974,7 @@ class TestDealDatasetVectorIndexTask: name="Enabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -992,7 +993,7 @@ class TestDealDatasetVectorIndexTask: name="Disabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=False, # This document should be skipped @@ -1074,7 +1075,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1094,7 +1095,7 @@ class TestDealDatasetVectorIndexTask: name="Active Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1113,7 +1114,7 @@ class TestDealDatasetVectorIndexTask: name="Archived Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", 
indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1195,7 +1196,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1215,7 +1216,7 @@ class TestDealDatasetVectorIndexTask: name="Completed Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1234,7 +1235,7 @@ class TestDealDatasetVectorIndexTask: name="Incomplete Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.INDEXING, # This document should be skipped enabled=True, diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index da42fc7167..d21f1daf23 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -15,6 +15,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -113,7 +114,7 @@ class TestDisableSegmentFromIndexTask: dataset: Dataset, tenant: Tenant, account: Account, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """ Helper method to create a test 
document. @@ -476,7 +477,11 @@ class TestDisableSegmentFromIndexTask: - Index processor clean method is called correctly """ # Test different document forms - doc_forms = ["text_model", "qa_model", "table_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Arrange: Create test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 4bc9bb4749..fbcb7b5264 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule @@ -153,7 +154,7 @@ class TestDisableSegmentsFromIndexTask: document.indexing_status = "completed" document.enabled = True document.archived = False - document.doc_form = "text_model" # Use text_model form for testing + document.doc_form = IndexStructureType.PARAGRAPH_INDEX # Use paragraph-index form for testing document.doc_language = "en" db_session_with_containers.add(document) db_session_with_containers.commit() @@ -500,7 +501,11 @@ class TestDisableSegmentsFromIndexTask: segment_ids = [segment.id for segment in segments] # Test different document forms - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Update document form diff --git
a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index 6a17a19a54..10d97919fb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -14,6 +14,7 @@ from uuid import uuid4 import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -85,7 +86,7 @@ class DocumentIndexingSyncTaskTestDataFactory: created_by=created_by, indexing_status=indexing_status, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index 2fbea1388c..c650d56091 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -80,7 +81,7 @@ class TestDocumentIndexingUpdateTask: created_by=account.id, 
indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index f1f5a4b105..76b6a8ae73 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -130,7 +131,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -265,7 +266,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -524,7 +525,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=dataset.created_by, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) extra_documents.append(document) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py 
b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py index 3060062adf..d841f67f9b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -11,6 +11,7 @@ from controllers.console.datasets.data_source import ( DataSourceNotionDocumentSyncApi, DataSourceNotionListApi, ) +from core.rag.index_processor.constant.index_type import IndexStructureType def unwrap(func): @@ -343,7 +344,7 @@ class TestDataSourceNotionApi: } ], "process_rule": {"rules": {}}, - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", } diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 0ee76e504b..68a7b30b9e 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import ( from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.provider_manager import ProviderManager +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models.enums import CreatorUserRole from models.model import ApiToken, UploadFile @@ -1146,7 +1147,7 @@ class TestDatasetIndexingEstimateApi: }, "process_rule": {"chunk_size": 100}, "indexing_technique": "high_quality", - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", "dataset_id": None, } diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index 
f23dd5b44a..f08f21ee14 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -30,6 +30,7 @@ from controllers.console.datasets.error import ( InvalidActionError, InvalidMetadataError, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import DataSourceType, IndexingStatus @@ -66,7 +67,7 @@ def document(): indexing_status=IndexingStatus.INDEXING, data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, archived=False, is_paused=False, dataset_process_rule=None, @@ -765,8 +766,8 @@ class TestDocumentGenerateSummaryApi: summary_index_setting={"enable": True}, ) - doc1 = MagicMock(id="doc-1", doc_form="qa_model") - doc2 = MagicMock(id="doc-2", doc_form="text") + doc1 = MagicMock(id="doc-1", doc_form=IndexStructureType.QA_INDEX) + doc2 = MagicMock(id="doc-2", doc_form=IndexStructureType.PARAGRAPH_INDEX) payload = {"document_list": ["doc-1", "doc-2"]} @@ -822,7 +823,7 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -849,7 +850,7 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -973,7 +974,7 @@ class TestDocumentBatchIndexingEstimateApi: "mode": "single", "only_main_content": True, }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1001,7 +1002,7 @@ class TestDocumentBatchIndexingEstimateApi: "notion_page_id": "p1", "type": "page", }, - doc_form="text", + 
doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1024,7 +1025,7 @@ class TestDocumentBatchIndexingEstimateApi: indexing_status=IndexingStatus.INDEXING, data_source_type="unknown", data_source_info_dict={}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]): @@ -1353,7 +1354,7 @@ class TestDocumentIndexingEdgeCases: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index e67e4daad9..1482499c41 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -24,6 +24,7 @@ from controllers.console.datasets.error import ( InvalidActionError, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -366,7 +367,7 @@ class TestDatasetDocumentSegmentAddApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() segment.id = "seg-1" @@ -505,7 +506,7 @@ class TestDatasetDocumentSegmentUpdateApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index 
4337a0c8c0..01d2d1e7c0 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -12,6 +12,7 @@ from unittest.mock import Mock import pytest from flask import Flask +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import TenantStatus from models.model import App, AppMode, EndUser from tests.unit_tests.conftest import setup_mock_tenant_account_query @@ -175,7 +176,7 @@ def mock_document(): document.name = "test_document.txt" document.indexing_status = "completed" document.enabled = True - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX return document diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 5c48ef1804..73a87761d5 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -31,6 +31,7 @@ from controllers.service_api.dataset.segment import ( SegmentCreatePayload, SegmentListQuery, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import ChildChunk, Dataset, Document, DocumentSegment from models.enums import IndexingStatus from services.dataset_service import DocumentService, SegmentService @@ -788,7 +789,7 @@ class TestSegmentApiGet: # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc_svc.get_document.return_value = Mock(doc_form="text_model") + mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_seg_svc.get_segments.return_value = ([mock_segment], 1) mock_marshal.return_value = [{"id": mock_segment.id}] @@ -903,7 +904,7 @@ class TestSegmentApiPost: 
mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.segment_create_args_validate.return_value = None @@ -1091,7 +1092,7 @@ class TestDatasetSegmentApiDelete: mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = None # Segment not found @@ -1371,7 +1372,7 @@ class TestDatasetSegmentApiGetSingle: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None - mock_doc = Mock(doc_form="text_model") + mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = mock_segment mock_marshal.return_value = {"id": mock_segment.id} @@ -1390,7 +1391,7 @@ class TestDatasetSegmentApiGetSingle: assert status == 200 assert "data" in response - assert response["doc_form"] == "text_model" + assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX @patch("controllers.service_api.dataset.segment.current_account_with_tenant") @patch("controllers.service_api.dataset.segment.db") diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index e6e841be19..7f77e61ee4 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import ( InvalidMetadataError, ) from 
controllers.service_api.dataset.error import ArchivedDocumentImmutableError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import IndexingStatus from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel @@ -52,7 +53,7 @@ class TestDocumentTextCreatePayload: def test_payload_with_defaults(self): """Test payload default values.""" payload = DocumentTextCreatePayload(name="Doc", text="Content") - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" assert payload.process_rule is None assert payload.indexing_technique is None @@ -62,14 +63,14 @@ class TestDocumentTextCreatePayload: payload = DocumentTextCreatePayload( name="Full Document", text="Complete document content here", - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, doc_language="Chinese", indexing_technique="high_quality", embedding_model="text-embedding-ada-002", embedding_model_provider="openai", ) assert payload.name == "Full Document" - assert payload.doc_form == "qa_model" + assert payload.doc_form == IndexStructureType.QA_INDEX assert payload.doc_language == "Chinese" assert payload.indexing_technique == "high_quality" assert payload.embedding_model == "text-embedding-ada-002" @@ -147,8 +148,8 @@ class TestDocumentTextUpdate: def test_payload_with_doc_form_update(self): """Test payload with doc_form update.""" - payload = DocumentTextUpdate(doc_form="qa_model") - assert payload.doc_form == "qa_model" + payload = DocumentTextUpdate(doc_form=IndexStructureType.QA_INDEX) + assert payload.doc_form == IndexStructureType.QA_INDEX def test_payload_with_language_update(self): """Test payload with doc_language update.""" @@ -158,7 +159,7 @@ class TestDocumentTextUpdate: def test_payload_default_values(self): """Test payload default values.""" payload = 
DocumentTextUpdate() - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" @@ -272,14 +273,24 @@ class TestDocumentDocForm: def test_text_model_form(self): """Test text_model form.""" - doc_form = "text_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.PARAGRAPH_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms def test_qa_model_form(self): """Test qa_model form.""" - doc_form = "qa_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.QA_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms @@ -504,7 +515,7 @@ class TestDocumentApiGet: doc.name = "test_document.txt" doc.indexing_status = "completed" doc.enabled = True - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.doc_type = "book" doc.doc_metadata_details = {"source": "upload"} diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 665e98bd9c..a34ca330ca 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -4800,8 +4800,8 @@ class TestInternalHooksCoverage: dataset_docs = [ SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX), SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX), - SimpleNamespace(id="doc-c", doc_form="qa_model"), - SimpleNamespace(id="doc-d", doc_form="qa_model"), + 
SimpleNamespace(id="doc-c", doc_form=IndexStructureType.QA_INDEX), + SimpleNamespace(id="doc-d", doc_form=IndexStructureType.QA_INDEX), ] child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")] segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")] diff --git a/api/tests/unit_tests/models/test_tool_models.py b/api/tests/unit_tests/models/test_tool_models.py index a6c2eae2c0..8e3c4da904 100644 --- a/api/tests/unit_tests/models/test_tool_models.py +++ b/api/tests/unit_tests/models/test_tool_models.py @@ -238,7 +238,7 @@ class TestApiToolProviderValidation: name=provider_name, icon='{"type": "emoji", "value": "🔧"}', schema=schema, - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Custom API for testing", tools_str=json.dumps(tools), credentials_str=json.dumps(credentials), @@ -249,7 +249,7 @@ class TestApiToolProviderValidation: assert api_provider.user_id == user_id assert api_provider.name == provider_name assert api_provider.schema == schema - assert api_provider.schema_type_str == "openapi" + assert api_provider.schema_type_str == ApiProviderSchemaType.OPENAPI assert api_provider.description == "Custom API for testing" def test_api_tool_provider_schema_type_property(self): @@ -261,7 +261,7 @@ class TestApiToolProviderValidation: name="Test API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str="{}", @@ -314,7 +314,7 @@ class TestApiToolProviderValidation: name="Weather API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Weather API", tools_str=json.dumps(tools_data), credentials_str="{}", @@ -343,7 +343,7 @@ class TestApiToolProviderValidation: name="Secure API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Secure API", tools_str="[]", 
credentials_str=json.dumps(credentials_data), @@ -369,7 +369,7 @@ class TestApiToolProviderValidation: name="Privacy API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API with privacy policy", tools_str="[]", credentials_str="{}", @@ -391,7 +391,7 @@ class TestApiToolProviderValidation: name="Disclaimer API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API with disclaimer", tools_str="[]", credentials_str="{}", @@ -410,7 +410,7 @@ class TestApiToolProviderValidation: name="Default API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API", tools_str="[]", credentials_str="{}", @@ -432,7 +432,7 @@ class TestApiToolProviderValidation: name=provider_name, icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Unique API", tools_str="[]", credentials_str="{}", @@ -454,7 +454,7 @@ class TestApiToolProviderValidation: name="Public API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Public API with no auth", tools_str="[]", credentials_str=json.dumps(credentials), @@ -479,7 +479,7 @@ class TestApiToolProviderValidation: name="Query Auth API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API with query auth", tools_str="[]", credentials_str=json.dumps(credentials), @@ -741,7 +741,7 @@ class TestCredentialStorage: name="Test API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str=json.dumps(credentials), @@ -788,7 +788,7 @@ class TestCredentialStorage: name="Update Test", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, 
description="Test", tools_str="[]", credentials_str=json.dumps(original_credentials), @@ -897,7 +897,7 @@ class TestToolProviderRelationships: name="User API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str="{}", @@ -931,7 +931,7 @@ class TestToolProviderRelationships: name="Custom API 1", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str="{}", diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 6829691507..1f68ff6b3d 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -111,6 +111,7 @@ from unittest.mock import Mock, patch import pytest from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.index_type import IndexStructureType from dify_graph.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService @@ -188,7 +189,7 @@ class DocumentValidationTestDataFactory: def create_knowledge_config_mock( data_source: DataSource | None = None, process_rule: ProcessRule | None = None, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, indexing_technique: str = "high_quality", **kwargs, ) -> Mock: @@ -326,8 +327,8 @@ class TestDatasetServiceCheckDocForm: - Validation logic works correctly """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model") - doc_form = "text_model" + dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + doc_form = IndexStructureType.PARAGRAPH_INDEX 
# Act (should not raise) DatasetService.check_doc_form(dataset, doc_form) @@ -349,7 +350,7 @@ class TestDatasetServiceCheckDocForm: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=None) - doc_form = "text_model" + doc_form = IndexStructureType.PARAGRAPH_INDEX # Act (should not raise) DatasetService.check_doc_form(dataset, doc_form) @@ -370,8 +371,8 @@ class TestDatasetServiceCheckDocForm: - Error type is correct """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model") - doc_form = "table_model" # Different form + dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + doc_form = IndexStructureType.PARENT_CHILD_INDEX # Different form # Act & Assert with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"): @@ -390,7 +391,7 @@ class TestDatasetServiceCheckDocForm: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="knowledge_card") - doc_form = "text_model" # Different form + doc_form = IndexStructureType.PARAGRAPH_INDEX # Different form # Act & Assert with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"): diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index cc2c0a8032..5e625fa0cd 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import Account from models.dataset import ChildChunk, Dataset, Document, DocumentSegment from models.enums import SegmentType @@ -91,7 +92,7 @@ class SegmentTestDataFactory: document_id: str = "doc-123", dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - doc_form: str = "text_model", + 
doc_form: str = IndexStructureType.PARAGRAPH_INDEX, word_count: int = 100, **kwargs, ) -> Mock: @@ -210,7 +211,7 @@ class TestSegmentServiceCreateSegment: def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user): """Test creation of segment with QA model (requires answer).""" # Arrange - document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) + document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]} @@ -429,7 +430,7 @@ class TestSegmentServiceUpdateSegment: """Test update segment with QA model (includes answer).""" # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) - document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) + document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"]) diff --git a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py index bd226f7536..d2287e8982 100644 --- a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -4,6 +4,7 @@ from unittest.mock import Mock, create_autospec import pytest from redis.exceptions import LockNotOwnedError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import Account from models.dataset import Dataset, Document from services.dataset_service import DocumentService, SegmentService @@ -76,7 +77,7 @@ def 
test_save_document_with_dataset_id_ignores_lock_not_owned( info_list = types.SimpleNamespace(data_source_type="upload_file") data_source = types.SimpleNamespace(info_list=info_list) knowledge_config = types.SimpleNamespace( - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, original_document_id=None, # go into "new document" branch data_source=data_source, indexing_technique="high_quality", @@ -131,7 +132,7 @@ def test_add_segment_ignores_lock_not_owned( document.id = "doc-1" document.dataset_id = dataset.id document.word_count = 0 - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX # Minimal args required by add_segment args = { @@ -174,4 +175,4 @@ def test_multi_create_segment_ignores_lock_not_owned( document.id = "doc-1" document.dataset_id = dataset.id document.word_count = 0 - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index be64e431ba..c4285c73a0 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock import pytest import services.summary_index_service as summary_module +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import SegmentStatus, SummaryStatus from services.summary_index_service import SummaryIndexService @@ -48,7 +49,7 @@ def _segment(*, has_document: bool = True) -> MagicMock: if has_document: doc = MagicMock(name="document") doc.doc_language = "en" - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX segment.document = doc else: segment.document = None @@ -623,13 +624,13 @@ def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.Mon dataset = _dataset(indexing_technique="economy") document = 
MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] dataset = _dataset() assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == [] - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] @@ -637,7 +638,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX seg1 = _segment() seg2 = _segment() @@ -673,7 +674,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX session = MagicMock() query = MagicMock() @@ -696,7 +697,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX seg = _segment() session = MagicMock() @@ -935,7 +936,7 @@ def test_update_summary_for_segment_skip_conditions() -> None: SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None ) seg = _segment(has_document=True) - seg.document.doc_form = "qa_model" + seg.document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None diff --git 
a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index 7b0103a2a1..d3a98dd4bb 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -9,6 +9,7 @@ from unittest.mock import MagicMock import pytest import services.vector_service as vector_service_module +from core.rag.index_processor.constant.index_type import IndexStructureType from services.vector_service import VectorService @@ -32,7 +33,7 @@ class _ParentDocStub: def _make_dataset( *, indexing_technique: str = "high_quality", - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, tenant_id: str = "tenant-1", dataset_id: str = "dataset-1", is_multimodal: bool = False, @@ -106,7 +107,7 @@ def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(mo factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) index_processor.load.assert_called_once() args, kwargs = index_processor.load.call_args @@ -131,7 +132,7 @@ def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monk factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) assert index_processor.load.call_count == 2 first_args, first_kwargs = index_processor.load.call_args_list[0] @@ -153,7 +154,7 @@ def 
test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pyte factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector(None, [], dataset, "text_model") + VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) index_processor.load.assert_not_called() @@ -392,7 +393,7 @@ def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkey def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1") + dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX, tenant_id="tenant-1", dataset_id="dataset-1") segment = _make_segment(segment_id="seg-1") dataset_document = MagicMock() @@ -439,7 +440,7 @@ def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(doc_form="text_model") + dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX) segment = _make_segment() dataset_document = MagicMock() dataset_document.doc_language = "en" diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py index c99275c6b2..e180063041 100644 --- a/api/tests/unit_tests/services/vector_service.py +++ b/api/tests/unit_tests/services/vector_service.py @@ -121,6 +121,7 @@ import pytest from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document from models.dataset import 
ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment from services.vector_service import VectorService @@ -151,7 +152,7 @@ class VectorServiceTestDataFactory: def create_dataset_mock( dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, indexing_technique: str = "high_quality", embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", @@ -493,7 +494,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="text_model", indexing_technique="high_quality" + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique="high_quality" ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -505,7 +506,7 @@ class TestVectorService: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor # Act - VectorService.create_segments_vector(keywords_list, [segment], dataset, "text_model") + VectorService.create_segments_vector(keywords_list, [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) # Assert mock_index_processor.load.assert_called_once() @@ -649,7 +650,7 @@ class TestVectorService: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor # Act - VectorService.create_segments_vector(None, [], dataset, "text_model") + VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) # Assert mock_index_processor.load.assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index 74ba7f9c34..c0a4d2f113 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest +from core.rag.index_processor.constant.index_type 
import IndexStructureType from models.enums import DataSourceType from tasks.clean_dataset_task import clean_dataset_task @@ -186,7 +187,7 @@ class TestErrorHandling: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -231,7 +232,7 @@ class TestPipelineAndWorkflowDeletion: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, pipeline_id=pipeline_id, ) @@ -267,7 +268,7 @@ class TestPipelineAndWorkflowDeletion: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, pipeline_id=None, ) @@ -323,7 +324,7 @@ class TestSegmentAttachmentCleanup: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -368,7 +369,7 @@ class TestSegmentAttachmentCleanup: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert - storage delete was attempted @@ -410,7 +411,7 @@ class TestEdgeCases: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -454,7 +455,7 @@ class TestIndexProcessorParameters: indexing_technique=indexing_technique, index_struct=index_struct, collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert diff --git 
a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 8a721124d6..6804ade5aa 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client @@ -222,7 +223,7 @@ def mock_documents(document_ids, dataset_id): doc.stopped_at = None doc.processing_started_at = None # optional attribute used in some code paths - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX documents.append(doc) return documents diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index 3668416e36..f49f4535af 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -62,7 +63,7 @@ def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id, document.tenant_id = str(uuid.uuid4()) document.data_source_type = "notion_import" document.indexing_status = "completed" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX document.data_source_info_dict = { "notion_workspace_id": notion_workspace_id, "notion_page_id": notion_page_id, From b0920ecd17743a55638bded693e3d93237ebfcc5 Mon Sep 17 
00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:02:52 +0800 Subject: [PATCH 18/32] refactor(web): migrate plugin toast usage to new UI toast API and update tests (#34001) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../plugins/plugin-install-flow.test.ts | 12 +++- .../install-plugin/__tests__/hooks.spec.ts | 20 ++++--- .../plugins/install-plugin/hooks.ts | 20 ++----- .../__tests__/index.spec.tsx | 14 +++-- .../install-from-github/index.tsx | 21 ++----- .../__tests__/detail-header.spec.tsx | 19 ++++++- .../__tests__/endpoint-card.spec.tsx | 19 ++++++- .../__tests__/endpoint-modal.spec.tsx | 20 ++++++- .../__tests__/use-plugin-operations.spec.ts | 21 ++++++- .../hooks/use-plugin-operations.ts | 14 ++--- .../plugin-detail-panel/endpoint-card.tsx | 20 +++---- .../plugin-detail-panel/endpoint-list.tsx | 12 ++-- .../plugin-detail-panel/endpoint-modal.tsx | 13 +++-- .../model-selector/__tests__/index.spec.tsx | 32 +++++++---- .../model-selector/index.tsx | 9 +-- .../__tests__/log-viewer.spec.tsx | 20 +++++-- .../__tests__/selector-entry.spec.tsx | 14 +++-- .../__tests__/selector-view.spec.tsx | 14 ++++- .../__tests__/subscription-card.spec.tsx | 14 ++++- .../create/__tests__/common-modal.spec.tsx | 14 +++-- .../create/__tests__/index.spec.tsx | 21 ++++--- .../create/__tests__/oauth-client.spec.tsx | 17 ++++-- .../__tests__/use-oauth-client-state.spec.ts | 17 ++++-- .../create/hooks/use-common-modal-state.ts | 42 +++----------- .../create/hooks/use-oauth-client-state.ts | 27 ++------- .../subscription-list/create/index.tsx | 12 +--- .../subscription-list/create/oauth-client.tsx | 7 +-- .../edit/__tests__/apikey-edit-modal.spec.tsx | 16 ++++-- .../edit/__tests__/index.spec.tsx | 12 +++- .../edit/__tests__/manual-edit-modal.spec.tsx | 16 ++++-- .../edit/__tests__/oauth-edit-modal.spec.tsx | 16 ++++-- .../edit/apikey-edit-modal.tsx | 24 ++------ .../edit/manual-edit-modal.tsx | 12 +--- 
.../edit/oauth-edit-modal.tsx | 12 +--- .../subscription-list/log-viewer.tsx | 7 +-- .../tool-selector/__tests__/index.spec.tsx | 14 ++++- .../__tests__/tool-credentials-form.spec.tsx | 16 ++++-- .../components/tool-credentials-form.tsx | 7 ++- .../plugin-item/__tests__/action.spec.tsx | 31 ++++++---- .../components/plugins/plugin-item/action.tsx | 4 +- web/eslint-suppressions.json | 57 +++---------------- 41 files changed, 390 insertions(+), 339 deletions(-) diff --git a/web/__tests__/plugins/plugin-install-flow.test.ts b/web/__tests__/plugins/plugin-install-flow.test.ts index 8edb6705d4..8fa2246198 100644 --- a/web/__tests__/plugins/plugin-install-flow.test.ts +++ b/web/__tests__/plugins/plugin-install-flow.test.ts @@ -12,8 +12,16 @@ vi.mock('@/config', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (...args: unknown[]) => mockToastNotify(...args) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockUploadGitHub = vi.fn() diff --git a/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts b/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts index 918a9b36e3..6b0fc27adf 100644 --- a/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts +++ b/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts @@ -3,8 +3,16 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { useGitHubReleases, useGitHubUpload } from '../hooks' const 
mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (...args: unknown[]) => mockNotify(...args) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((...args: unknown[]) => mockNotify(...args), { + success: (...args: unknown[]) => mockNotify(...args), + error: (...args: unknown[]) => mockNotify(...args), + warning: (...args: unknown[]) => mockNotify(...args), + info: (...args: unknown[]) => mockNotify(...args), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) vi.mock('@/config', () => ({ @@ -56,9 +64,7 @@ describe('install-plugin/hooks', () => { const releases = await result.current.fetchReleases('owner', 'repo') expect(releases).toEqual([]) - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(mockNotify).toHaveBeenCalledWith('Failed to fetch repository releases') }) }) @@ -130,9 +136,7 @@ describe('install-plugin/hooks', () => { await expect( result.current.handleUpload('url', 'v1', 'pkg'), ).rejects.toThrow('Upload failed') - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error', message: 'Error uploading package' }), - ) + expect(mockNotify).toHaveBeenCalledWith('Error uploading package') }) }) }) diff --git a/web/app/components/plugins/install-plugin/hooks.ts b/web/app/components/plugins/install-plugin/hooks.ts index 2addba4a04..cc7148cc17 100644 --- a/web/app/components/plugins/install-plugin/hooks.ts +++ b/web/app/components/plugins/install-plugin/hooks.ts @@ -1,6 +1,5 @@ import type { GitHubRepoReleaseResponse } from '../types' -import type { IToastProps } from '@/app/components/base/toast' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { GITHUB_ACCESS_TOKEN } from '@/config' import { uploadGitHub } from '@/service/plugins' import { compareVersion, getLatestVersion } from '@/utils/semver' @@ -37,16 +36,10 @@ export const 
useGitHubReleases = () => { } catch (error) { if (error instanceof Error) { - Toast.notify({ - type: 'error', - message: error.message, - }) + toast.error(error.message) } else { - Toast.notify({ - type: 'error', - message: 'Failed to fetch repository releases', - }) + toast.error('Failed to fetch repository releases') } return [] } @@ -54,7 +47,7 @@ export const useGitHubReleases = () => { const checkForUpdates = (fetchedReleases: GitHubRepoReleaseResponse[], currentVersion: string) => { let needUpdate = false - const toastProps: IToastProps = { + const toastProps: { type?: 'success' | 'error' | 'info' | 'warning', message: string } = { type: 'info', message: 'No new version available', } @@ -99,10 +92,7 @@ export const useGitHubUpload = () => { return GitHubPackage } catch (error) { - Toast.notify({ - type: 'error', - message: 'Error uploading package', - }) + toast.error('Error uploading package') throw error } } diff --git a/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx b/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx index 0fe6b88ed8..8abec7817b 100644 --- a/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx +++ b/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx @@ -57,10 +57,16 @@ const createUpdatePayload = (overrides: Partial = {}): // Mock external dependencies const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (props: { type: string, message: string }) => mockNotify(props), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((props: { type: string, message: string }) => mockNotify(props), { + success: (message: string) => mockNotify({ type: 'success', message }), + error: (message: string) => mockNotify({ type: 'error', message }), + warning: (message: string) => mockNotify({ type: 'warning', message }), + info: (message: 
string) => mockNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockGetIconUrl = vi.fn() diff --git a/web/app/components/plugins/install-plugin/install-from-github/index.tsx b/web/app/components/plugins/install-plugin/install-from-github/index.tsx index 91425031cf..ff51698478 100644 --- a/web/app/components/plugins/install-plugin/install-from-github/index.tsx +++ b/web/app/components/plugins/install-plugin/install-from-github/index.tsx @@ -7,7 +7,7 @@ import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import useGetIcon from '@/app/components/plugins/install-plugin/base/use-get-icon' import { cn } from '@/utils/classnames' import { InstallStepFromGitHub } from '../../types' @@ -81,10 +81,7 @@ const InstallFromGitHub: React.FC = ({ updatePayload, on const handleUrlSubmit = async () => { const { isValid, owner, repo } = parseGitHubUrl(state.repoUrl) if (!isValid || !owner || !repo) { - Toast.notify({ - type: 'error', - message: t('error.inValidGitHubUrl', { ns: 'plugin' }), - }) + toast.error(t('error.inValidGitHubUrl', { ns: 'plugin' })) return } try { @@ -97,17 +94,11 @@ const InstallFromGitHub: React.FC = ({ updatePayload, on })) } else { - Toast.notify({ - type: 'error', - message: t('error.noReleasesFound', { ns: 'plugin' }), - }) + toast.error(t('error.noReleasesFound', { ns: 'plugin' })) } } catch { - Toast.notify({ - type: 'error', - message: t('error.fetchReleasesError', { ns: 'plugin' }), - }) + toast.error(t('error.fetchReleasesError', { ns: 'plugin' })) } } @@ -175,10 +166,10 @@ const InstallFromGitHub: React.FC = ({ updatePayload, on >
-
+
{getTitle()}
-
+
{!([InstallStepFromGitHub.uploadFailed, InstallStepFromGitHub.installed, InstallStepFromGitHub.installFailed].includes(state.step)) && t('installFromGitHub.installNote', { ns: 'plugin' })}
diff --git a/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx b/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx index f0ec5b6c83..f8d6488128 100644 --- a/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx @@ -2,10 +2,25 @@ import type { PluginDetail } from '../../types' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import * as amplitude from '@/app/components/base/amplitude' -import Toast from '@/app/components/base/toast' import { PluginSource } from '../../types' import DetailHeader from '../detail-header' +const { mockToast } = vi.hoisted(() => ({ + mockToast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: mockToast, +})) + const { mockSetShowUpdatePluginModal, mockRefreshModelProviders, @@ -272,7 +287,7 @@ describe('DetailHeader', () => { vi.clearAllMocks() mockAutoUpgradeInfo = null mockEnableMarketplace = true - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) + vi.clearAllMocks() vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {}) }) diff --git a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx index 480f399c91..237c72adf0 100644 --- a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx @@ -1,7 +1,6 @@ import type { EndpointListItem, PluginDetail } from '../../types' import { act, fireEvent, render, screen } from '@testing-library/react' import { 
beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import EndpointCard from '../endpoint-card' const mockHandleChange = vi.fn() @@ -9,6 +8,22 @@ const mockEnableEndpoint = vi.fn() const mockDisableEndpoint = vi.fn() const mockDeleteEndpoint = vi.fn() const mockUpdateEndpoint = vi.fn() +const mockToastNotify = vi.fn() + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), +})) // Flags to control whether operations should fail const failureFlags = { @@ -127,8 +142,6 @@ describe('EndpointCard', () => { failureFlags.disable = false failureFlags.delete = false failureFlags.update = false - // Mock Toast.notify to prevent toast elements from accumulating in DOM - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) // Polyfill document.execCommand for copy-to-clipboard in jsdom if (typeof document.execCommand !== 'function') { document.execCommand = vi.fn().mockReturnValue(true) diff --git a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-modal.spec.tsx index 1dfe31c6b1..a467de7142 100644 --- a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-modal.spec.tsx @@ -2,9 +2,25 @@ import type { FormSchema } from '../../../base/form/types' import type { PluginDetail } from '../../types' import { fireEvent, 
render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import EndpointModal from '../endpoint-modal' +const mockToastNotify = vi.fn() + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), +})) + vi.mock('@/hooks/use-i18n', () => ({ useRenderI18nObject: () => (obj: Record | string) => typeof obj === 'string' ? obj : obj?.en_US || '', @@ -69,11 +85,9 @@ const mockPluginDetail: PluginDetail = { describe('EndpointModal', () => { const mockOnCancel = vi.fn() const mockOnSaved = vi.fn() - let mockToastNotify: ReturnType beforeEach(() => { vi.clearAllMocks() - mockToastNotify = vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) }) describe('Rendering', () => { diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts index 0fcec7f16b..77d41c5bce 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts @@ -3,7 +3,6 @@ import type { ModalStates, VersionTarget } from '../use-detail-header-state' import { act, renderHook } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' 
import * as amplitude from '@/app/components/base/amplitude' -import Toast from '@/app/components/base/toast' import { PluginSource } from '../../../../types' import { usePluginOperations } from '../use-plugin-operations' @@ -20,6 +19,7 @@ const { mockUninstallPlugin, mockFetchReleases, mockCheckForUpdates, + mockToastNotify, } = vi.hoisted(() => { return { mockSetShowUpdatePluginModal: vi.fn(), @@ -29,9 +29,25 @@ const { mockUninstallPlugin: vi.fn(() => Promise.resolve({ success: true })), mockFetchReleases: vi.fn(() => Promise.resolve([{ tag_name: 'v2.0.0' }])), mockCheckForUpdates: vi.fn(() => ({ needUpdate: true, toastProps: { type: 'success', message: 'Update available' } })), + mockToastNotify: vi.fn(), } }) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), +})) + vi.mock('@/context/modal-context', () => ({ useModalContext: () => ({ setShowUpdatePluginModal: mockSetShowUpdatePluginModal, @@ -124,7 +140,6 @@ describe('usePluginOperations', () => { modalStates = createModalStatesMock() versionPicker = createVersionPickerMock() mockOnUpdate = vi.fn() - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {}) }) @@ -233,7 +248,7 @@ describe('usePluginOperations', () => { }) expect(mockCheckForUpdates).toHaveBeenCalled() - expect(Toast.notify).toHaveBeenCalled() + expect(mockToastNotify).toHaveBeenCalledWith({ type: 'success', message: 'Update available' }) }) it('should show update 
plugin modal when update is needed', async () => { diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts index bf6bb4aae6..ade47cec5f 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts @@ -5,7 +5,7 @@ import type { ModalStates, VersionTarget } from './use-detail-header-state' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { trackEvent } from '@/app/components/base/amplitude' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import { uninstallPlugin } from '@/service/plugins' @@ -60,10 +60,7 @@ export const usePluginOperations = ({ } if (!meta?.repo || !meta?.version || !meta?.package) { - Toast.notify({ - type: 'error', - message: 'Missing plugin metadata for GitHub update', - }) + toast.error('Missing plugin metadata for GitHub update') return } @@ -74,7 +71,7 @@ export const usePluginOperations = ({ return const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta.version) - Toast.notify(toastProps) + toast(toastProps.message, { type: toastProps.type }) if (needUpdate) { setShowUpdatePluginModal({ @@ -122,10 +119,7 @@ export const usePluginOperations = ({ if (res.success) { modalStates.hideDeleteConfirm() - Toast.notify({ - type: 'success', - message: t('action.deleteSuccess', { ns: 'plugin' }), - }) + toast.success(t('action.deleteSuccess', { ns: 'plugin' })) handlePluginUpdated(true) if (PluginCategoryEnum.model.includes(category)) diff --git a/web/app/components/plugins/plugin-detail-panel/endpoint-card.tsx 
b/web/app/components/plugins/plugin-detail-panel/endpoint-card.tsx index 164bab0f04..9f95d9c7e1 100644 --- a/web/app/components/plugins/plugin-detail-panel/endpoint-card.tsx +++ b/web/app/components/plugins/plugin-detail-panel/endpoint-card.tsx @@ -9,8 +9,8 @@ import ActionButton from '@/app/components/base/action-button' import Confirm from '@/app/components/base/confirm' import { CopyCheck } from '@/app/components/base/icons/src/vender/line/files' import Switch from '@/app/components/base/switch' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import Indicator from '@/app/components/header/indicator' import { addDefaultValue, toolCredentialToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import { @@ -47,7 +47,7 @@ const EndpointCard = ({ await handleChange() }, onError: () => { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) setActive(false) }, }) @@ -57,7 +57,7 @@ const EndpointCard = ({ hideDisableConfirm() }, onError: () => { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) setActive(false) }, }) @@ -83,7 +83,7 @@ const EndpointCard = ({ hideDeleteConfirm() }, onError: () => { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) }, }) @@ -108,7 +108,7 @@ const EndpointCard = ({ hideEndpointModalConfirm() }, onError: () => { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) }, }) const handleUpdate = (state: Record) => updateEndpoint({ @@ -139,7 +139,7 @@ const 
EndpointCard = ({
-
+
{data.name}
@@ -154,8 +154,8 @@ const EndpointCard = ({
{data.declaration.endpoints.filter(endpoint => !endpoint.hidden).map((endpoint, index) => (
-
{endpoint.method}
-
+
{endpoint.method}
+
{`${data.url}${endpoint.path}`}
handleCopy(`${data.url}${endpoint.path}`)}> @@ -168,13 +168,13 @@ const EndpointCard = ({
{active && ( -
+
{t('detailPanel.serviceOk', { ns: 'plugin' })}
)} {!active && ( -
+
{t('detailPanel.disabled', { ns: 'plugin' })}
diff --git a/web/app/components/plugins/plugin-detail-panel/endpoint-list.tsx b/web/app/components/plugins/plugin-detail-panel/endpoint-list.tsx index 357e714ba2..366139d12d 100644 --- a/web/app/components/plugins/plugin-detail-panel/endpoint-list.tsx +++ b/web/app/components/plugins/plugin-detail-panel/endpoint-list.tsx @@ -9,8 +9,8 @@ import * as React from 'react' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import { toolCredentialToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import { useDocLink } from '@/context/i18n' import { @@ -50,7 +50,7 @@ const EndpointList = ({ detail }: Props) => { hideEndpointModal() }, onError: () => { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) }, }) @@ -64,7 +64,7 @@ const EndpointList = ({ detail }: Props) => { return (
-
+
{t('detailPanel.endpoints', { ns: 'plugin' })} {
-
{t('detailPanel.endpointsTip', { ns: 'plugin' })}
+
{t('detailPanel.endpointsTip', { ns: 'plugin' })}
-
+
{t('detailPanel.endpointsDocLink', { ns: 'plugin' })}
@@ -95,7 +95,7 @@ const EndpointList = ({ detail }: Props) => {
{data.endpoints.length === 0 && ( -
{t('detailPanel.endpointsEmpty', { ns: 'plugin' })}
+
{t('detailPanel.endpointsEmpty', { ns: 'plugin' })}
)}
{data.endpoints.map((item, index) => ( diff --git a/web/app/components/plugins/plugin-detail-panel/endpoint-modal.tsx b/web/app/components/plugins/plugin-detail-panel/endpoint-modal.tsx index 929e990f90..4d93e14c8b 100644 --- a/web/app/components/plugins/plugin-detail-panel/endpoint-modal.tsx +++ b/web/app/components/plugins/plugin-detail-panel/endpoint-modal.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' import Button from '@/app/components/base/button' import Drawer from '@/app/components/base/drawer' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' import { useRenderI18nObject } from '@/hooks/use-i18n' import { cn } from '@/utils/classnames' @@ -48,7 +48,10 @@ const EndpointModal: FC = ({ const handleSave = () => { for (const field of formSchemas) { if (field.required && !tempCredential[field.name]) { - Toast.notify({ type: 'error', message: t('errorMsg.fieldRequired', { ns: 'common', field: typeof field.label === 'string' ? field.label : getValueFromI18nObject(field.label as Record) }) }) + toast.error(t('errorMsg.fieldRequired', { + ns: 'common', + field: typeof field.label === 'string' ? field.label : getValueFromI18nObject(field.label as Record), + })) return } } @@ -83,12 +86,12 @@ const EndpointModal: FC = ({ <>
-
{t('detailPanel.endpointModalTitle', { ns: 'plugin' })}
+
{t('detailPanel.endpointModalTitle', { ns: 'plugin' })}
-
{t('detailPanel.endpointModalDesc', { ns: 'plugin' })}
+
{t('detailPanel.endpointModalDesc', { ns: 'plugin' })}
@@ -109,7 +112,7 @@ const EndpointModal: FC = ({ href={item.url} target="_blank" rel="noopener noreferrer" - className="body-xs-regular inline-flex items-center text-text-accent-secondary" + className="inline-flex items-center text-text-accent-secondary body-xs-regular" > {t('howToGet', { ns: 'tools' })} diff --git a/web/app/components/plugins/plugin-detail-panel/model-selector/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/model-selector/__tests__/index.spec.tsx index 9b04a710e0..107d42ada2 100644 --- a/web/app/components/plugins/plugin-detail-panel/model-selector/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/model-selector/__tests__/index.spec.tsx @@ -1,14 +1,29 @@ import type { Model, ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -// Import component after mocks -import Toast from '@/app/components/base/toast' - import { ConfigurationMethodEnum, ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' + +// Import component after mocks import ModelParameterModal from '../index' // ==================== Mock Setup ==================== +const mockToastNotify = vi.fn() +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), +})) + // Mock provider context const 
mockProviderContextValue = { isAPIKeySet: true, @@ -53,8 +68,6 @@ vi.mock('@/utils/completion-params', () => ({ fetchAndMergeValidCompletionParams: (...args: unknown[]) => mockFetchAndMergeValidCompletionParams(...args), })) -const mockToastNotify = vi.spyOn(Toast, 'notify') - // Mock child components vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({ default: ({ defaultModel, modelList, scopeFeatures, onSelect }: { @@ -244,7 +257,6 @@ const setupModelLists = (config: { describe('ModelParameterModal', () => { beforeEach(() => { vi.clearAllMocks() - mockToastNotify.mockReturnValue({}) mockProviderContextValue.isAPIKeySet = true mockProviderContextValue.modelProviders = [] setupModelLists() @@ -865,9 +877,7 @@ describe('ModelParameterModal', () => { // Assert await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'warning' }), - ) + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'warning' })) }) }) @@ -892,9 +902,7 @@ describe('ModelParameterModal', () => { // Assert await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' })) }) }) }) diff --git a/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx index 04b78f98b7..6838bcec43 100644 --- a/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx @@ -10,12 +10,12 @@ import type { import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' import { Popover, PopoverContent, 
PopoverTrigger, } from '@/app/components/base/ui/popover' +import { toast } from '@/app/components/base/ui/toast' import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelList, @@ -134,14 +134,11 @@ const ModelParameterModal: FC = ({ const keys = Object.keys(removedDetails || {}) if (keys.length) { - Toast.notify({ - type: 'warning', - message: `${t('modelProvider.parametersInvalidRemoved', { ns: 'common' })}: ${keys.map(k => `${k} (${removedDetails[k]})`).join(', ')}`, - }) + toast.warning(`${t('modelProvider.parametersInvalidRemoved', { ns: 'common' })}: ${keys.map(k => `${k} (${removedDetails[k]})`).join(', ')}`) } } catch { - Toast.notify({ type: 'error', message: t('error', { ns: 'common' }) }) + toast.error(t('error', { ns: 'common' })) } } diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/log-viewer.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/log-viewer.spec.tsx index c6fb42faab..351c1f9d2d 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/log-viewer.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/log-viewer.spec.tsx @@ -1,12 +1,26 @@ import type { TriggerLogEntity } from '@/app/components/workflow/block-selector/types' import { cleanup, fireEvent, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import LogViewer from '../log-viewer' const mockToastNotify = vi.fn() const mockWriteText = vi.fn() +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', 
message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), +})) + vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({ default: ({ value }: { value: unknown }) => (
{JSON.stringify(value)}
@@ -57,10 +71,6 @@ beforeEach(() => { }, configurable: true, }) - vi.spyOn(Toast, 'notify').mockImplementation((args) => { - mockToastNotify(args) - return { clear: vi.fn() } - }) }) describe('LogViewer', () => { diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-entry.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-entry.spec.tsx index d8d41ff9b2..3c4ff83fc8 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-entry.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-entry.spec.tsx @@ -26,10 +26,16 @@ vi.mock('@/service/use-triggers', () => ({ useDeleteTriggerSubscription: () => ({ mutate: vi.fn(), isPending: false }), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const createSubscription = (overrides: Partial = {}): TriggerSubscription => ({ diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-view.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-view.spec.tsx index 83d0cdd89d..44cec53e28 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-view.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-view.spec.tsx @@ -1,7 +1,6 @@ import type { TriggerSubscription } from '@/app/components/workflow/block-selector/types' import { fireEvent, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import { TriggerCredentialTypeEnum } from 
'@/app/components/workflow/block-selector/types' import { SubscriptionSelectorView } from '../selector-view' @@ -26,6 +25,18 @@ vi.mock('@/service/use-triggers', () => ({ useDeleteTriggerSubscription: () => ({ mutate: mockDelete, isPending: false }), })) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), +})) + const createSubscription = (overrides: Partial = {}): TriggerSubscription => ({ id: 'sub-1', name: 'Subscription One', @@ -42,7 +53,6 @@ const createSubscription = (overrides: Partial = {}): Trigg beforeEach(() => { vi.clearAllMocks() mockSubscriptions = [createSubscription()] - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) }) describe('SubscriptionSelectorView', () => { diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/subscription-card.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/subscription-card.spec.tsx index a51bc2954f..4665c921ca 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/subscription-card.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/subscription-card.spec.tsx @@ -1,7 +1,6 @@ import type { TriggerSubscription } from '@/app/components/workflow/block-selector/types' import { fireEvent, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' import SubscriptionCard from '../subscription-card' @@ -30,6 +29,18 @@ vi.mock('@/service/use-triggers', () => ({ useDeleteTriggerSubscription: () => ({ mutate: vi.fn(), isPending: false }), })) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: 
Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), +})) + const createSubscription = (overrides: Partial = {}): TriggerSubscription => ({ id: 'sub-1', name: 'Subscription One', @@ -45,7 +56,6 @@ const createSubscription = (overrides: Partial = {}): Trigg beforeEach(() => { vi.clearAllMocks() - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) }) describe('SubscriptionCard', () => { diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/common-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/common-modal.spec.tsx index 21a4c3defa..72532ea38d 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/common-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/common-modal.spec.tsx @@ -122,10 +122,16 @@ vi.mock('@/utils/urlValidation', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (params: unknown) => mockToastNotify(params), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((params: unknown) => mockToastNotify(params), { + success: (message: unknown) => mockToastNotify({ type: 'success', message }), + error: (message: unknown) => mockToastNotify({ type: 'error', message }), + warning: (message: unknown) => mockToastNotify({ type: 'warning', message }), + info: (message: unknown) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) vi.mock('@/app/components/base/modal/modal', () => ({ diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/index.spec.tsx index 
3fe9884b92..a36c108160 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/index.spec.tsx @@ -2,6 +2,7 @@ import type { SimpleDetail } from '../../../store' import type { TriggerOAuthConfig, TriggerProviderApiEntity, TriggerSubscription, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' +import { toast } from '@/app/components/base/ui/toast' import { SupportedCreationMethods } from '@/app/components/plugins/types' import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' import { CreateButtonType, CreateSubscriptionButton, DEFAULT_METHOD } from '../index' @@ -33,10 +34,16 @@ vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ }, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) let mockStoreDetail: SimpleDetail | undefined @@ -908,8 +915,6 @@ describe('CreateSubscriptionButton', () => { it('should handle OAuth initiation error', async () => { // Arrange - const Toast = await import('@/app/components/base/toast') - mockInitiateOAuth.mockImplementation((_provider: string, callbacks: { onError: () => void }) => { callbacks.onError() }) @@ -932,9 +937,7 @@ describe('CreateSubscriptionButton', () => { // Assert await waitFor(() => { - expect(Toast.default.notify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(toast.error).toHaveBeenCalled() }) }) }) diff --git 
a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx index 12419a9bf3..ce53bf5b9a 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx @@ -86,10 +86,19 @@ vi.mock('@/hooks/use-oauth', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (params: unknown) => mockToastNotify(params), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), })) const mockClipboardWriteText = vi.fn() diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-oauth-client-state.spec.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-oauth-client-state.spec.ts index 89566f3af7..68864b0b80 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-oauth-client-state.spec.ts +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-oauth-client-state.spec.ts @@ -77,10 +77,19 @@ vi.mock('@/hooks/use-oauth', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (params: 
unknown) => mockToastNotify(params), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), })) // ============================================================================ diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts index b01312d3d1..99c42f6fc5 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts @@ -7,7 +7,7 @@ import type { BuildTriggerSubscriptionPayload } from '@/service/use-triggers' import { debounce } from 'es-toolkit/compat' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { SupportedCreationMethods } from '@/app/components/plugins/types' import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' import { @@ -154,10 +154,7 @@ export const useCommonModalState = ({ onError: async (error: unknown) => { const errorMessage = await parsePluginErrorMessage(error) || t('modal.errors.updateFailed', { ns: 'pluginTrigger' }) console.error('Failed to update subscription builder:', error) - Toast.notify({ - 
type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) }, }, ) @@ -178,10 +175,7 @@ export const useCommonModalState = ({ } catch (error) { console.error('createBuilder error:', error) - Toast.notify({ - type: 'error', - message: t('modal.errors.createFailed', { ns: 'pluginTrigger' }), - }) + toast.error(t('modal.errors.createFailed', { ns: 'pluginTrigger' })) } } if (!isInitializedRef.current && !subscriptionBuilder && detail?.provider) @@ -239,10 +233,7 @@ export const useCommonModalState = ({ const handleVerify = useCallback(() => { // Guard against uninitialized state if (!detail?.provider || !subscriptionBuilder?.id) { - Toast.notify({ - type: 'error', - message: 'Subscription builder not initialized', - }) + toast.error('Subscription builder not initialized') return } @@ -250,10 +241,7 @@ export const useCommonModalState = ({ const credentials = apiKeyCredentialsFormValues.values if (!Object.keys(credentials).length) { - Toast.notify({ - type: 'error', - message: 'Please fill in all required credentials', - }) + toast.error('Please fill in all required credentials') return } @@ -270,10 +258,7 @@ export const useCommonModalState = ({ }, { onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('modal.apiKey.verify.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.apiKey.verify.success', { ns: 'pluginTrigger' })) setCurrentStep(ApiKeyStep.Configuration) }, onError: async (error: unknown) => { @@ -290,10 +275,7 @@ export const useCommonModalState = ({ // Handle create const handleCreate = useCallback(() => { if (!subscriptionBuilder) { - Toast.notify({ - type: 'error', - message: 'Subscription builder not found', - }) + toast.error('Subscription builder not found') return } @@ -327,19 +309,13 @@ export const useCommonModalState = ({ params, { onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('subscription.createSuccess', { ns: 'pluginTrigger' }), - }) + toast.success(t('subscription.createSuccess', 
{ ns: 'pluginTrigger' })) onClose() refetch?.() }, onError: async (error: unknown) => { const errorMessage = await parsePluginErrorMessage(error) || t('subscription.createFailed', { ns: 'pluginTrigger' }) - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) }, }, ) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts index 6a551051e2..e5a5ded9df 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts @@ -4,7 +4,7 @@ import type { TriggerOAuthClientParams, TriggerOAuthConfig, TriggerSubscriptionB import type { ConfigureTriggerOAuthPayload } from '@/service/use-triggers' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { openOAuthPopup } from '@/hooks/use-oauth' import { useConfigureTriggerOAuth, @@ -118,20 +118,14 @@ export const useOAuthClientState = ({ openOAuthPopup(response.authorization_url, (callbackData) => { if (!callbackData) return - Toast.notify({ - type: 'success', - message: t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' })) onClose() showOAuthCreateModal(response.subscription_builder) }) }, onError: () => { setAuthorizationStatus(AuthorizationStatusEnum.Failed) - Toast.notify({ - type: 'error', - message: t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' }), - }) + toast.error(t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' })) }, }) }, [providerName, 
initiateOAuth, onClose, showOAuthCreateModal, t]) @@ -141,16 +135,10 @@ export const useOAuthClientState = ({ deleteOAuth(providerName, { onSuccess: () => { onClose() - Toast.notify({ - type: 'success', - message: t('modal.oauth.remove.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.oauth.remove.success', { ns: 'pluginTrigger' })) }, onError: (error: unknown) => { - Toast.notify({ - type: 'error', - message: getErrorMessage(error, t('modal.oauth.remove.failed', { ns: 'pluginTrigger' })), - }) + toast.error(getErrorMessage(error, t('modal.oauth.remove.failed', { ns: 'pluginTrigger' }))) }, }) }, [providerName, deleteOAuth, onClose, t]) @@ -187,10 +175,7 @@ export const useOAuthClientState = ({ return } onClose() - Toast.notify({ - type: 'success', - message: t('modal.oauth.save.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.oauth.save.success', { ns: 'pluginTrigger' })) }, }) }, [clientType, providerName, oauthClientSchema, oauthConfig?.params, configureOAuth, handleAuthorization, onClose, t]) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx index eecaf165fb..bd0846c15e 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx @@ -8,8 +8,8 @@ import { ActionButton, ActionButtonState } from '@/app/components/base/action-bu import Badge from '@/app/components/base/badge' import { Button } from '@/app/components/base/button' import CustomSelect from '@/app/components/base/select/custom' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import { openOAuthPopup } from '@/hooks/use-oauth' import { useInitiateTriggerOAuth, useTriggerOAuthConfig, useTriggerProviderInfo } from 
'@/service/use-triggers' import { cn } from '@/utils/classnames' @@ -107,19 +107,13 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU onSuccess: (response) => { openOAuthPopup(response.authorization_url, (callbackData) => { if (callbackData) { - Toast.notify({ - type: 'success', - message: t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' })) setSelectedCreateInfo({ type: SupportedCreationMethods.OAUTH, builder: response.subscription_builder }) } }) }, onError: () => { - Toast.notify({ - type: 'error', - message: t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' }), - }) + toast.error(t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' })) }, }) } diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx index b7f9b8ebec..d4bc92169c 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { BaseForm } from '@/app/components/base/form/components/base' import Modal from '@/app/components/base/modal/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card' import { usePluginStore } from '../../store' import { ClientTypeEnum, useOAuthClientState } from './hooks/use-oauth-client-state' @@ -48,10 +48,7 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate const handleCopyRedirectUri = () => { navigator.clipboard.writeText(oauthConfig?.redirect_uri || '') 
- Toast.notify({ - type: 'success', - message: t('actionMsg.copySuccessfully', { ns: 'common' }), - }) + toast.success(t('actionMsg.copySuccessfully', { ns: 'common' })) } return ( diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx index e0fb7455ce..7f22a06295 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx @@ -47,13 +47,19 @@ vi.mock('@/service/use-triggers', () => ({ useTriggerPluginDynamicOptions: () => ({ data: [], isLoading: false }), })) -vi.mock('@/app/components/base/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, - default: { - notify: (args: { type: string, message: string }) => mockToast(args), - }, + toast: Object.assign((args: { type: string, message: string }) => mockToast(args), { + success: (message: string) => mockToast({ type: 'success', message }), + error: (message: string) => mockToast({ type: 'error', message }), + warning: (message: string) => mockToast({ type: 'warning', message }), + info: (message: string) => mockToast({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), } }) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/index.spec.tsx index 7d188a3f6d..126d8e366d 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/index.spec.tsx +++ 
b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/index.spec.tsx @@ -13,8 +13,16 @@ import { OAuthEditModal } from '../oauth-edit-modal' // ==================== Mock Setup ==================== const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (params: unknown) => mockToastNotify(params) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockParsePluginErrorMessage = vi.fn() diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx index 60a8428287..52572b3560 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx @@ -30,13 +30,19 @@ vi.mock('@/service/use-triggers', () => ({ useTriggerPluginDynamicOptions: () => ({ data: [], isLoading: false }), })) -vi.mock('@/app/components/base/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, - default: { - notify: (args: { type: string, message: string }) => mockToast(args), - }, + toast: Object.assign((args: { type: string, message: string }) => 
mockToast(args), { + success: (message: string) => mockToast({ type: 'success', message }), + error: (message: string) => mockToast({ type: 'error', message }), + warning: (message: string) => mockToast({ type: 'warning', message }), + info: (message: string) => mockToast({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), } }) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx index 8835b46695..95b2cca6af 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx @@ -30,13 +30,19 @@ vi.mock('@/service/use-triggers', () => ({ useTriggerPluginDynamicOptions: () => ({ data: [], isLoading: false }), })) -vi.mock('@/app/components/base/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, - default: { - notify: (args: { type: string, message: string }) => mockToast(args), - }, + toast: Object.assign((args: { type: string, message: string }) => mockToast(args), { + success: (message: string) => mockToast({ type: 'success', message }), + error: (message: string) => mockToast({ type: 'error', message }), + warning: (message: string) => mockToast({ type: 'warning', message }), + info: (message: string) => mockToast({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), } }) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx index a4093ed00b..247beaa626 100644 
--- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx @@ -9,7 +9,7 @@ import { EncryptedBottom } from '@/app/components/base/encrypted-bottom' import { BaseForm } from '@/app/components/base/form/components/base' import { FormTypeEnum } from '@/app/components/base/form/types' import Modal from '@/app/components/base/modal/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ReadmeEntrance } from '@/app/components/plugins/readme-panel/entrance' import { useUpdateTriggerSubscription, useVerifyTriggerSubscription } from '@/service/use-triggers' import { parsePluginErrorMessage } from '@/utils/error-parser' @@ -65,7 +65,7 @@ const StatusStep = ({ isActive, text, onClick, clickable }: { }) => { return (
{ - Toast.notify({ - type: 'success', - message: t('modal.apiKey.verify.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.apiKey.verify.success', { ns: 'pluginTrigger' })) // Only save credentials if any field was modified (not all hidden) setVerifiedCredentials(areAllCredentialsHidden(credentials) ? null : credentials) setCurrentStep(EditStep.EditConfiguration) }, onError: async (error: unknown) => { const errorMessage = await parsePluginErrorMessage(error) || t('modal.apiKey.verify.error', { ns: 'pluginTrigger' }) - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) }, }, ) @@ -192,19 +186,13 @@ export const ApiKeyEditModal = ({ onClose, subscription, pluginDetail }: Props) }, { onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' })) refetch?.() onClose() }, onError: async (error: unknown) => { const errorMessage = await parsePluginErrorMessage(error) || t('subscription.list.item.actions.edit.error', { ns: 'pluginTrigger' }) - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) }, }, ) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx index 262235e6ed..e1741da8e7 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next' import { BaseForm } from '@/app/components/base/form/components/base' import { FormTypeEnum } from '@/app/components/base/form/types' import Modal from '@/app/components/base/modal/modal' -import Toast from 
'@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ReadmeEntrance } from '@/app/components/plugins/readme-panel/entrance' import { useUpdateTriggerSubscription } from '@/service/use-triggers' import { ReadmeShowType } from '../../../readme-panel/store' @@ -94,18 +94,12 @@ export const ManualEditModal = ({ onClose, subscription, pluginDetail }: Props) }, { onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' })) refetch?.() onClose() }, onError: (error: unknown) => { - Toast.notify({ - type: 'error', - message: getErrorMessage(error, t('subscription.list.item.actions.edit.error', { ns: 'pluginTrigger' })), - }) + toast.error(getErrorMessage(error, t('subscription.list.item.actions.edit.error', { ns: 'pluginTrigger' }))) }, }, ) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx index e57b9c0151..c43a00a322 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next' import { BaseForm } from '@/app/components/base/form/components/base' import { FormTypeEnum } from '@/app/components/base/form/types' import Modal from '@/app/components/base/modal/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ReadmeEntrance } from '@/app/components/plugins/readme-panel/entrance' import { useUpdateTriggerSubscription } from '@/service/use-triggers' import { ReadmeShowType } from '../../../readme-panel/store' @@ -94,18 +94,12 @@ export const 
OAuthEditModal = ({ onClose, subscription, pluginDetail }: Props) = }, { onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' })) refetch?.() onClose() }, onError: (error: unknown) => { - Toast.notify({ - type: 'error', - message: getErrorMessage(error, t('subscription.list.item.actions.edit.error', { ns: 'pluginTrigger' })), - }) + toast.error(getErrorMessage(error, t('subscription.list.item.actions.edit.error', { ns: 'pluginTrigger' }))) }, }, ) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/log-viewer.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/log-viewer.tsx index 3b4edd1b85..6ccab000c4 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/log-viewer.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/log-viewer.tsx @@ -11,7 +11,7 @@ import dayjs from 'dayjs' import * as React from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import { cn } from '@/utils/classnames' @@ -89,10 +89,7 @@ const LogViewer = ({ logs, className }: Props) => { onClick={(e) => { e.stopPropagation() navigator.clipboard.writeText(String(parsedData)) - Toast.notify({ - type: 'success', - message: t('actionMsg.copySuccessfully', { ns: 'common' }), - }) + toast.success(t('actionMsg.copySuccessfully', { ns: 'common' })) }} className="rounded-md p-0.5 hover:bg-components-panel-border" > diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/__tests__/index.spec.tsx 
b/web/app/components/plugins/plugin-detail-panel/tool-selector/__tests__/index.spec.tsx index 26e4de0fd7..537e99d733 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/__tests__/index.spec.tsx @@ -298,8 +298,16 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-modal // Mock Toast - need to track notify calls for assertions const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (...args: unknown[]) => mockToastNotify(...args) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) // ==================== Test Utilities ==================== @@ -1943,7 +1951,7 @@ describe('ToolCredentialsForm Component', () => { const saveBtn = screen.getByText(/save/i) fireEvent.click(saveBtn) - // Toast.notify should have been called with error (lines 49-50) + // notifyToast should have been called with error (lines 49-50) expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' })) // onSaved should not be called because validation fails expect(onSaved).not.toHaveBeenCalled() diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx index cb5b929d29..e1b8ca86fe 100644 --- 
a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx @@ -10,12 +10,16 @@ vi.mock('@/utils/classnames', () => ({ cn: (...args: unknown[]) => args.filter(Boolean).join(' '), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: vi.fn() }, -})) - -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ notify: vi.fn() }), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockFormSchemas = [ diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-credentials-form.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-credentials-form.tsx index 0207f65336..f4d39c9ec1 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-credentials-form.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-credentials-form.tsx @@ -10,7 +10,7 @@ import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' import { addDefaultValue, toolCredentialToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import { useRenderI18nObject } from '@/hooks/use-i18n' @@ -49,7 +49,10 @@ const ToolCredentialForm: FC = ({ return for (const field of credentialSchema) { if (field.required && !tempCredential[field.name]) { - 
Toast.notify({ type: 'error', message: t('errorMsg.fieldRequired', { ns: 'common', field: getValueFromI18nObject(field.label) }) }) + toast.error(t('errorMsg.fieldRequired', { + ns: 'common', + field: getValueFromI18nObject(field.label), + })) return } } diff --git a/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx b/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx index 8467c983d8..b4d21c9403 100644 --- a/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx +++ b/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx @@ -1,7 +1,6 @@ import type { MetaData, PluginCategoryEnum } from '../../types' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' // ==================== Imports (after mocks) ==================== @@ -17,12 +16,29 @@ const { mockCheckForUpdates, mockSetShowUpdatePluginModal, mockInvalidateInstalledPluginList, + mockToastNotify, } = vi.hoisted(() => ({ mockUninstallPlugin: vi.fn(), mockFetchReleases: vi.fn(), mockCheckForUpdates: vi.fn(), mockSetShowUpdatePluginModal: vi.fn(), mockInvalidateInstalledPluginList: vi.fn(), + mockToastNotify: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), })) // Mock uninstall plugin service @@ -140,13 +156,8 @@ const getActionButtons = () => screen.getAllByRole('button') const queryActionButtons = () => 
screen.queryAllByRole('button') describe('Action Component', () => { - // Spy on Toast.notify - real component but we track calls - let toastNotifySpy: ReturnType - beforeEach(() => { vi.clearAllMocks() - // Spy on Toast.notify and mock implementation to avoid DOM side effects - toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) mockUninstallPlugin.mockResolvedValue({ success: true }) mockFetchReleases.mockResolvedValue([]) mockCheckForUpdates.mockReturnValue({ @@ -155,10 +166,6 @@ describe('Action Component', () => { }) }) - afterEach(() => { - toastNotifySpy.mockRestore() - }) - // ==================== Rendering Tests ==================== describe('Rendering', () => { it('should render delete button when isShowDelete is true', () => { @@ -563,9 +570,9 @@ describe('Action Component', () => { render() fireEvent.click(getActionButtons()[0]) - // Assert - Toast.notify is called with the toast props + // Assert - toast is called with the translated payload await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'Already up to date' }) + expect(mockToastNotify).toHaveBeenCalledWith({ type: 'success', message: 'Already up to date' }) }) }) diff --git a/web/app/components/plugins/plugin-item/action.tsx b/web/app/components/plugins/plugin-item/action.tsx index 171e54acab..413b41e895 100644 --- a/web/app/components/plugins/plugin-item/action.tsx +++ b/web/app/components/plugins/plugin-item/action.tsx @@ -7,7 +7,7 @@ import { useBoolean } from 'ahooks' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useModalContext } from '@/context/modal-context' import { uninstallPlugin } from '@/service/plugins' import { useInvalidateInstalledPluginList } from '@/service/use-plugins' @@ -65,7 +65,7 @@ const Action: FC = ({ if 
(fetchedReleases.length === 0) return const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta!.version) - Toast.notify(toastProps) + toast(toastProps.message, { type: toastProps.type }) if (needUpdate) { setShowUpdatePluginModal({ onSaveCallback: () => { diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 04e8be3afd..0f69e6ce33 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -5062,9 +5062,6 @@ } }, "app/components/plugins/install-plugin/hooks.ts": { - "no-restricted-imports": { - "count": 2 - }, "ts/no-explicit-any": { "count": 4 } @@ -5100,9 +5097,6 @@ }, "app/components/plugins/install-plugin/install-from-github/index.tsx": { "no-restricted-imports": { - "count": 3 - }, - "tailwindcss/enforce-consistent-class-order": { "count": 2 }, "ts/no-explicit-any": { @@ -5367,17 +5361,9 @@ "count": 1 } }, - "app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts": { - "no-restricted-imports": { - "count": 1 - } - }, "app/components/plugins/plugin-detail-panel/endpoint-card.tsx": { "no-restricted-imports": { - "count": 3 - }, - "tailwindcss/enforce-consistent-class-order": { - "count": 5 + "count": 2 }, "ts/no-explicit-any": { "count": 2 @@ -5385,22 +5371,13 @@ }, "app/components/plugins/plugin-detail-panel/endpoint-list.tsx": { "no-restricted-imports": { - "count": 2 - }, - "tailwindcss/enforce-consistent-class-order": { - "count": 4 + "count": 1 }, "ts/no-explicit-any": { "count": 2 } }, "app/components/plugins/plugin-detail-panel/endpoint-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, - "tailwindcss/enforce-consistent-class-order": { - "count": 3 - }, "ts/no-explicit-any": { "count": 7 } @@ -5414,9 +5391,6 @@ } }, "app/components/plugins/plugin-detail-panel/model-selector/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 3 } @@ -5471,27 +5445,21 @@ 
"app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts": { "erasable-syntax-only/enums": { "count": 1 - }, - "no-restricted-imports": { - "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts": { "erasable-syntax-only/enums": { "count": 2 - }, - "no-restricted-imports": { - "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": { "no-restricted-imports": { - "count": 4 + "count": 3 } }, "app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 }, "tailwindcss/enforce-consistent-class-order": { "count": 3 @@ -5507,20 +5475,17 @@ "count": 1 }, "no-restricted-imports": { - "count": 2 - }, - "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/index.tsx": { @@ -5540,9 +5505,6 @@ "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "tailwindcss/enforce-consistent-class-order": { "count": 5 }, @@ -5600,11 +5562,6 @@ "count": 2 } }, - "app/components/plugins/plugin-detail-panel/tool-selector/components/tool-credentials-form.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "app/components/plugins/plugin-detail-panel/tool-selector/components/tool-item.tsx": { "no-restricted-imports": { "count": 2 @@ -5643,7 +5600,7 @@ }, "app/components/plugins/plugin-item/action.tsx": { "no-restricted-imports": { - "count": 3 + "count": 2 } }, "app/components/plugins/plugin-item/index.tsx": { From 508350ec6ac89c7764a145ef61cad7e3d1f39ad6 Mon 
Sep 17 00:00:00 2001 From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:19:32 +0800 Subject: [PATCH 19/32] test: enhance useChat hook tests with additional scenarios (#33928) Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../__tests__/hooks/handle-resume.spec.ts | 1090 +++++++++++++++++ .../__tests__/hooks/handle-send.spec.ts | 194 +++ .../hooks/handle-stop-restart.spec.ts | 199 +++ .../__tests__/hooks/misc.spec.ts | 380 ++++++ .../opening-statement.spec.ts} | 69 +- .../__tests__/hooks/sse-callbacks.spec.ts | 914 ++++++++++++++ 6 files changed, 2831 insertions(+), 15 deletions(-) create mode 100644 web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts create mode 100644 web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-send.spec.ts create mode 100644 web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-stop-restart.spec.ts create mode 100644 web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/misc.spec.ts rename web/app/components/workflow/panel/debug-and-preview/__tests__/{hooks.spec.ts => hooks/opening-statement.spec.ts} (63%) create mode 100644 web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts new file mode 100644 index 0000000000..3367bd4766 --- /dev/null +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts @@ -0,0 +1,1090 @@ +/* eslint-disable ts/no-explicit-any */ +import type { ChatItemInTree } from '@/app/components/base/chat/types' +import { act, renderHook } from '@testing-library/react' +import { useChat } from '../../hooks' + +const mockHandleRun = 
vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null + +vi.mock('@/service/base', () => ({ + sseGet: (...args: any[]) => mockSseGet(...args), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +vi.mock('@/service/workflow', () => ({ + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), +})) + +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), +})) + +vi.mock('../../../../hooks-store', () => ({ + useHooksStore: () => null, +})) + +vi.mock('../../../../store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, + inputs: {}, + workflowRunningData: mockWorkflowRunningData, + }), + }), + useStore: () => vi.fn(), +})) + +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + +describe('useChat – handleResume', () => { + let capturedResumeOptions: any + + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + mockSseGet.mockReset() + }) + + async function setupResumeWithTree() { + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = 
options + }) + + const hook = renderHook(() => useChat({})) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-resume', + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + hook.result.current.handleResume('msg-resume', 'wfr-1', {}) + }) + + return hook + } + + it('should call sseGet with the correct URL', () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleResume('msg-1', 'wfr-1', {}) + }) + + expect(mockSseGet).toHaveBeenCalledWith( + '/workflow/wfr-1/events?include_state_snapshot=true', + {}, + expect.any(Object), + ) + }) + + it('should abort previous SSE connection when handleResume is called again', () => { + const mockAbortCtrl = new AbortController() + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + options.getAbortController(mockAbortCtrl) + }) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleResume('msg-1', 'wfr-1', {}) + }) + + const mockAbort2 = vi.fn() + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + options.getAbortController({ abort: mockAbort2 }) + }) + + act(() => { + result.current.handleResume('msg-1', 'wfr-2', {}) + }) + + expect(mockAbortCtrl.signal.aborted).toBe(true) + }) + + it('should abort previous workflowEventsAbortController before sseGet', () => { + const mockAbort = vi.fn() + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.getAbortController({ abort: mockAbort } as any) + }) + + mockSseGet.mockImplementation(() => {}) + + act(() => { + 
result.current.handleResume('msg-1', 'wfr-1', {}) + }) + + expect(mockAbort).toHaveBeenCalledTimes(1) + }) + + describe('onWorkflowStarted', () => { + it('should set isResponding and update workflow process', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + expect(result.current.isResponding).toBe(true) + }) + + it('should resume existing workflow when tracing exists', async () => { + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + + const hook = renderHook(() => useChat({})) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-resume', + }) + }) + + act(() => { + sendCallbacks.onNodeStarted({ + data: { node_id: 'n1', id: 'trace-1' }, + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + hook.result.current.handleResume('msg-resume', 'wfr-1', {}) + }) + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + const answer = hook.result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.status).toBe('running') + }) + }) + + describe('onWorkflowFinished', () => { + it('should update workflow status', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onWorkflowFinished({ + data: { status: 'succeeded' }, + }) + }) + + const answer = 
result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.status).toBe('succeeded') + }) + }) + + describe('onData', () => { + it('should append message content', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onData('resumed', false, { + conversationId: 'conv-2', + messageId: 'msg-resume', + taskId: 'task-2', + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.content).toContain('resumed') + }) + + it('should update conversationId when provided', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onData('msg', false, { + conversationId: 'new-conv-resume', + messageId: null, + taskId: 'task-2', + }) + }) + + expect(result.current.conversationId).toBe('new-conv-resume') + }) + }) + + describe('onCompleted', () => { + it('should set isResponding to false', async () => { + const { result } = await setupResumeWithTree() + await act(async () => { + await capturedResumeOptions.onCompleted(false) + }) + expect(result.current.isResponding).toBe(false) + }) + + it('should not call fetchInspectVars when paused', async () => { + mockWorkflowRunningData = { result: { status: 'paused' } } + await setupResumeWithTree() + mockFetchInspectVars.mockClear() + await act(async () => { + await capturedResumeOptions.onCompleted(false) + }) + expect(mockFetchInspectVars).not.toHaveBeenCalled() + }) + + it('should still call fetchInspectVars on error but skip suggested questions', async () => { + const mockGetSuggested = vi.fn().mockResolvedValue({ data: ['s1'] }) + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + + const hook = renderHook(() => + useChat({ suggested_questions_after_answer: { 
enabled: true } }), + ) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, {}) + }) + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-resume', + }) + }) + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + mockFetchInspectVars.mockClear() + mockInvalidAllLastRun.mockClear() + + act(() => { + hook.result.current.handleResume('msg-resume', 'wfr-1', { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + await act(async () => { + await capturedResumeOptions.onCompleted(true) + }) + + expect(mockFetchInspectVars).toHaveBeenCalledWith({}) + expect(mockInvalidAllLastRun).toHaveBeenCalled() + expect(mockGetSuggested).not.toHaveBeenCalled() + }) + + it('should fetch suggested questions when enabled', async () => { + const mockGetSuggested = vi.fn().mockImplementation((_id: string, getAbortCtrl: any) => { + getAbortCtrl(new AbortController()) + return Promise.resolve({ data: ['s1'] }) + }) + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + + const hook = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-resume', + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + hook.result.current.handleResume('msg-resume', 'wfr-1', { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + + await act(async () => { + await capturedResumeOptions.onCompleted(false) + }) + + expect(mockGetSuggested).toHaveBeenCalled() + }) + + it('should set suggestedQuestions to 
empty on fetch error', async () => { + const mockGetSuggested = vi.fn().mockRejectedValue(new Error('fail')) + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + + const hook = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-resume', + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + hook.result.current.handleResume('msg-resume', 'wfr-1', { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + + await act(async () => { + await capturedResumeOptions.onCompleted(false) + }) + + expect(hook.result.current.suggestedQuestions).toEqual([]) + }) + }) + + describe('onError', () => { + it('should set isResponding to false', async () => { + const { result } = await setupResumeWithTree() + act(() => { + capturedResumeOptions.onError() + }) + expect(result.current.isResponding).toBe(false) + }) + }) + + describe('onMessageEnd', () => { + it('should update citation and files', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onMessageEnd({ + metadata: { retriever_resources: [{ id: 'cite-1' }] }, + files: [], + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.citation).toEqual([{ id: 'cite-1' }]) + }) + }) + + describe('onMessageReplace', () => { + it('should replace content', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onMessageReplace({ answer: 'replaced' }) + }) + + const answer = 
result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.content).toBe('replaced') + }) + }) + + describe('onIterationStart / onIterationFinish', () => { + it('should push and update iteration tracing entries', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onIterationStart({ + data: { id: 'iter-r1', node_id: 'n-iter-r' }, + }) + }) + + let answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].id).toBe('iter-r1') + expect(answer!.workflowProcess!.tracing[0].status).toBe('running') + + act(() => { + capturedResumeOptions.onIterationFinish({ + data: { id: 'iter-r1', node_id: 'n-iter-r', execution_metadata: {} }, + }) + }) + + answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].status).toBe('succeeded') + }) + + it('should handle iteration finish when no match found', async () => { + await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onIterationFinish({ + data: { id: 'no-match', node_id: 'no-match', execution_metadata: {} }, + }) + }) + }) + }) + + describe('onLoopStart / onLoopFinish', () => { + it('should push and update loop tracing entries', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onLoopStart({ + data: { id: 'loop-r1', node_id: 'n-loop-r' }, + }) + }) + + let answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].id).toBe('loop-r1') + expect(answer!.workflowProcess!.tracing[0].status).toBe('running') + + act(() => { + capturedResumeOptions.onLoopFinish({ + data: { id: 'loop-r1', node_id: 'n-loop-r', execution_metadata: {} }, + }) + }) + + answer = 
result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].status).toBe('succeeded') + }) + + it('should handle loop finish when no match found', async () => { + await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onLoopFinish({ + data: { id: 'no-match', node_id: 'no-match', execution_metadata: {} }, + }) + }) + }) + }) + + describe('onNodeStarted / onNodeFinished', () => { + it('should add and update node tracing entries', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-1', id: 'rtrace-1' }, + }) + }) + + let answer = result.current.chatList.find(item => item.id === 'msg-resume') + const startedTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'rn-1') + expect(startedTrace).toBeDefined() + expect(startedTrace!.id).toBe('rtrace-1') + expect(startedTrace!.status).toBe('running') + + act(() => { + capturedResumeOptions.onNodeFinished({ + data: { node_id: 'rn-1', id: 'rtrace-1', status: 'succeeded' }, + }) + }) + + answer = result.current.chatList.find(item => item.id === 'msg-resume') + const finishedTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'rn-1') + expect(finishedTrace).toBeDefined() + expect((finishedTrace as any).status).toBe('succeeded') + }) + + it('should skip onNodeStarted when iteration_id is present', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-child', id: 'rtrace-child', iteration_id: 'iter-parent' }, + }) + }) + + const answer = 
result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.tracing.some((t: any) => t.node_id === 'rn-child')).toBe(false) + }) + + it('should skip onNodeFinished when iteration_id is present', async () => { + await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeFinished({ + data: { node_id: 'rn-1', id: 'rtrace-1', iteration_id: 'iter-parent' }, + }) + }) + }) + + it('should update existing node in tracing on onNodeStarted', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-1', id: 'rtrace-1' }, + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-1', id: 'rtrace-1-v2' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + const matchingTraces = answer!.workflowProcess!.tracing.filter((t: any) => t.node_id === 'rn-1') + expect(matchingTraces).toHaveLength(1) + expect(matchingTraces[0].id).toBe('rtrace-1-v2') + expect(matchingTraces[0].status).toBe('running') + }) + + it('should match nodeFinished with parallel_id', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-1', id: 'rtrace-1', execution_metadata: { parallel_id: 'p1' } }, + }) + }) + + act(() => { + capturedResumeOptions.onNodeFinished({ + data: { + node_id: 'rn-1', + id: 'rtrace-1', + status: 'succeeded', + execution_metadata: { parallel_id: 'p1' }, + }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 
'msg-resume') + const trace = answer!.workflowProcess!.tracing.find((t: any) => t.id === 'rtrace-1') + expect(trace).toBeDefined() + expect((trace as any).status).toBe('succeeded') + expect((trace as any).execution_metadata.parallel_id).toBe('p1') + }) + }) + + describe('onHumanInputRequired', () => { + it('should initialize humanInputFormDataList', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-1' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.humanInputFormDataList).toHaveLength(1) + }) + + it('should update existing form for same node and push for different node', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-1' }, + }) + }) + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-2' }, + }) + }) + + let answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.humanInputFormDataList).toHaveLength(1) + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human-2', form_token: 'rt-3' }, + }) + }) + + answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.humanInputFormDataList).toHaveLength(2) + }) + + it('should set tracing node to Paused when tracing match is found', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-human', id: 'trace-human' }, + }) + }) + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-1' }, + }) + }) + + const answer = 
result.current.chatList.find(item => item.id === 'msg-resume') + const trace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'rn-human') + expect(trace!.status).toBe('paused') + }) + }) + + describe('onHumanInputFormFilled', () => { + it('should move form from pending to filled list', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-1' }, + }) + }) + + act(() => { + capturedResumeOptions.onHumanInputFormFilled({ + data: { node_id: 'rn-human', form_data: { a: 1 } }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.humanInputFormDataList).toHaveLength(0) + expect(answer!.humanInputFilledFormDataList).toHaveLength(1) + }) + + it('should initialize humanInputFilledFormDataList when not present', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onHumanInputFormFilled({ + data: { node_id: 'rn-human', form_data: { b: 2 } }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.humanInputFilledFormDataList).toHaveLength(1) + }) + }) + + describe('onHumanInputFormTimeout', () => { + it('should set expiration_time on the form entry', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-1' }, + }) + }) + + act(() => { + capturedResumeOptions.onHumanInputFormTimeout({ + data: { node_id: 'rn-human', expiration_time: '2025-06-01' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + const form = answer!.humanInputFormDataList!.find((f: any) => f.node_id === 'rn-human') + expect(form!.expiration_time).toBe('2025-06-01') + }) + }) + + describe('onWorkflowPaused', () => { + it('should re-subscribe via sseGet and set 
status to Paused', async () => { + const { result } = await setupResumeWithTree() + const sseGetCallsBefore = mockSseGet.mock.calls.length + + act(() => { + capturedResumeOptions.onWorkflowPaused({ + data: { workflow_run_id: 'wfr-paused' }, + }) + }) + + expect(mockSseGet.mock.calls.length).toBeGreaterThan(sseGetCallsBefore) + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.status).toBe('paused') + }) + }) +}) + +describe('useChat – handleResume with bare prevChatTree (no humanInputFormDataList / no tracing)', () => { + let capturedResumeOptions: any + + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + mockSseGet.mockReset() + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + }) + + function setupWithBareTree() { + const prevChatTree: ChatItemInTree[] = [ + { + id: 'q1', + content: 'question', + isAnswer: false, + children: [ + { + id: 'bare-msg', + content: '', + isAnswer: true, + workflow_run_id: 'wfr-bare', + workflowProcess: { + status: 'running' as any, + tracing: [], + }, + children: [], + }, + ], + }, + ] + + const hook = renderHook(() => useChat({}, undefined, prevChatTree)) + + act(() => { + hook.result.current.handleResume('bare-msg', 'wfr-bare', {}) + }) + + return hook + } + + function setupWithBareTreeNoTracing() { + const prevChatTree: ChatItemInTree[] = [ + { + id: 'q1', + content: 'question', + isAnswer: false, + children: [ + { + id: 'bare-msg-nt', + content: '', + isAnswer: true, + workflow_run_id: 'wfr-bare-nt', + workflowProcess: { + status: 'running' as any, + tracing: undefined as any, + }, + children: [], + }, + ], + }, + ] + + const hook = renderHook(() => useChat({}, undefined, prevChatTree)) + + act(() => { + hook.result.current.handleResume('bare-msg-nt', 'wfr-bare-nt', {}) + }) + + return hook + } + + it('onHumanInputRequired should initialize humanInputFormDataList when null', () 
=> { + const { result } = setupWithBareTree() + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'hn-bare', form_token: 'ft-bare' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg') + expect(answer!.humanInputFormDataList).toHaveLength(1) + }) + + it('onHumanInputFormFilled should initialize humanInputFilledFormDataList when null', () => { + const { result } = setupWithBareTree() + + act(() => { + capturedResumeOptions.onHumanInputFormFilled({ + data: { node_id: 'hn-bare', form_data: { x: 1 } }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg') + expect(answer!.humanInputFilledFormDataList).toHaveLength(1) + }) + + it('onLoopStart should initialize tracing array when not present', () => { + const { result } = setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onLoopStart({ + data: { id: 'loop-bare', node_id: 'n-loop-bare' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg-nt') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].id).toBe('loop-bare') + expect(answer!.workflowProcess!.tracing[0].node_id).toBe('n-loop-bare') + expect(answer!.workflowProcess!.tracing[0].status).toBe('running') + }) + + it('onLoopFinish should return early when no tracing', () => { + setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onLoopFinish({ + data: { id: 'loop-bare', node_id: 'n-loop-bare', execution_metadata: {} }, + }) + }) + }) + + it('onIterationStart should initialize tracing when not present', () => { + const { result } = setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onIterationStart({ + data: { id: 'iter-bare', node_id: 'n-iter-bare' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg-nt') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + 
expect(answer!.workflowProcess!.tracing[0].id).toBe('iter-bare') + expect(answer!.workflowProcess!.tracing[0].node_id).toBe('n-iter-bare') + expect(answer!.workflowProcess!.tracing[0].status).toBe('running') + }) + + it('onIterationFinish should return early when no tracing', () => { + setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onIterationFinish({ + data: { id: 'iter-bare', node_id: 'n-iter-bare', execution_metadata: {} }, + }) + }) + }) + + it('onNodeStarted should initialize tracing when not present', () => { + const { result } = setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-bare', id: 'rtrace-bare' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg-nt') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].id).toBe('rtrace-bare') + expect(answer!.workflowProcess!.tracing[0].node_id).toBe('rn-bare') + expect(answer!.workflowProcess!.tracing[0].status).toBe('running') + }) + + it('onNodeFinished should return early when no tracing', () => { + setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onNodeFinished({ + data: { node_id: 'rn-bare', id: 'rtrace-bare', status: 'succeeded' }, + }) + }) + }) + + it('onIterationStart/onNodeStarted/onLoopStart should return early when no workflowProcess', () => { + const prevChatTreeNoWP: ChatItemInTree[] = [ + { + id: 'q-nowp', + content: 'question', + isAnswer: false, + children: [ + { + id: 'bare-nowp', + content: '', + isAnswer: true, + children: [], + }, + ], + }, + ] + + const hook = renderHook(() => useChat({}, undefined, prevChatTreeNoWP)) + let opts: any + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + opts = options + }) + + act(() => { + hook.result.current.handleResume('bare-nowp', 'wfr-x', {}) + }) + + act(() => { + opts.onIterationStart({ data: { id: 'i1', node_id: 'ni1' } }) + }) + + act(() => { 
+ opts.onNodeStarted({ data: { node_id: 'ns1', id: 'ts1' } }) + }) + + act(() => { + opts.onLoopStart({ data: { id: 'l1', node_id: 'nl1' } }) + }) + + const answer = hook.result.current.chatList.find(item => item.id === 'bare-nowp') + expect(answer!.workflowProcess).toBeUndefined() + }) + + it('onHumanInputRequired should set Paused on tracing node when found', () => { + const { result } = setupWithBareTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'hn-with-trace', id: 'trace-hn' }, + }) + }) + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'hn-with-trace', form_token: 'ft-tr' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg') + const trace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'hn-with-trace') + expect(trace!.status).toBe('paused') + }) +}) diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-send.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-send.spec.ts new file mode 100644 index 0000000000..0a12aac582 --- /dev/null +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-send.spec.ts @@ -0,0 +1,194 @@ +/* eslint-disable ts/no-explicit-any */ +import { act, renderHook } from '@testing-library/react' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null + +vi.mock('@/service/base', () => ({ + sseGet: (...args: any[]) => mockSseGet(...args), +})) + 
+vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +vi.mock('@/service/workflow', () => ({ + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), +})) + +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), +})) + +vi.mock('../../../../hooks-store', () => ({ + useHooksStore: () => null, +})) + +vi.mock('../../../../store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, + inputs: {}, + workflowRunningData: mockWorkflowRunningData, + }), + }), + useStore: () => vi.fn(), +})) + +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + +describe('useChat – handleSend', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + }) + + it('should call handleRun with processed params', () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'hello', inputs: {} }, {}) + }) + + expect(mockHandleRun).toHaveBeenCalledTimes(1) + const [bodyParams] = mockHandleRun.mock.calls[0] + expect(bodyParams.query).toBe('hello') + }) + + it('should show notification and return false when already responding', () => { + mockHandleRun.mockImplementation(() => {}) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'first' }, {}) + }) + + act(() => { + const returned = result.current.handleSend({ query: 'second' }, {}) + expect(returned).toBe(false) + }) + + 
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'info' })) + }) + + it('should set isResponding to true after sending', () => { + const { result } = renderHook(() => useChat({})) + act(() => { + result.current.handleSend({ query: 'hello' }, {}) + }) + expect(result.current.isResponding).toBe(true) + }) + + it('should add placeholder question and answer to chatList', () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'test question' }, {}) + }) + + const questionItem = result.current.chatList.find(item => item.content === 'test question') + expect(questionItem).toBeDefined() + expect(questionItem!.isAnswer).toBe(false) + + const answerPlaceholder = result.current.chatList.find( + item => item.isAnswer && !item.isOpeningStatement && item.content === '', + ) + expect(answerPlaceholder).toBeDefined() + }) + + it('should strip url from local_file transfer method files', () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend( + { + query: 'hello', + files: [ + { + id: 'f1', + name: 'test.png', + size: 1024, + type: 'image/png', + progress: 100, + transferMethod: 'local_file', + supportFileType: 'image', + url: 'blob://local', + uploadedId: 'up1', + }, + { + id: 'f2', + name: 'remote.png', + size: 2048, + type: 'image/png', + progress: 100, + transferMethod: 'remote_url', + supportFileType: 'image', + url: 'https://example.com/img.png', + uploadedId: '', + }, + ] as any, + }, + {}, + ) + }) + + expect(mockHandleRun).toHaveBeenCalledTimes(1) + const [bodyParams] = mockHandleRun.mock.calls[0] + const localFile = bodyParams.files.find((f: any) => f.transfer_method === 'local_file') + const remoteFile = bodyParams.files.find((f: any) => f.transfer_method === 'remote_url') + expect(localFile.url).toBe('') + expect(remoteFile.url).toBe('https://example.com/img.png') + }) + + it('should abort previous workflowEventsAbortController before sending', 
() => { + const mockAbort = vi.fn() + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + callbacks.getAbortController({ abort: mockAbort } as any) + callbacks.onCompleted(false) + }) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'first' }, {}) + }) + + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + callbacks.getAbortController({ abort: vi.fn() } as any) + }) + + act(() => { + result.current.handleSend({ query: 'second' }, {}) + }) + + expect(mockAbort).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-stop-restart.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-stop-restart.spec.ts new file mode 100644 index 0000000000..7127634a10 --- /dev/null +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-stop-restart.spec.ts @@ -0,0 +1,199 @@ +/* eslint-disable ts/no-explicit-any */ +import { act, renderHook } from '@testing-library/react' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockStopChat = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null + +vi.mock('@/service/base', () => ({ + sseGet: (...args: any[]) => mockSseGet(...args), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +vi.mock('@/service/workflow', () => ({ + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +vi.mock('reactflow', () => ({ + 
useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), +})) + +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), +})) + +vi.mock('../../../../hooks-store', () => ({ + useHooksStore: () => null, +})) + +vi.mock('../../../../store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, + inputs: {}, + workflowRunningData: mockWorkflowRunningData, + }), + }), + useStore: () => vi.fn(), +})) + +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + +describe('useChat – handleStop', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + }) + + it('should set isResponding to false', () => { + const { result } = renderHook(() => useChat({})) + act(() => { + result.current.handleStop() + }) + expect(result.current.isResponding).toBe(false) + }) + + it('should not call stopChat when taskId is empty even if stopChat is provided', () => { + const { result } = renderHook(() => useChat({}, undefined, undefined, mockStopChat)) + act(() => { + result.current.handleStop() + }) + expect(mockStopChat).not.toHaveBeenCalled() + }) + + it('should reset iter/loop times to defaults', () => { + const { result } = renderHook(() => useChat({})) + act(() => { + result.current.handleStop() + }) + expect(mockSetIterTimes).toHaveBeenCalledWith(1) + expect(mockSetLoopTimes).toHaveBeenCalledWith(1) + }) + + it('should abort workflowEventsAbortController when set', () => { + const mockWfAbort = vi.fn() + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + callbacks.getAbortController({ abort: mockWfAbort } as any) + }) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + result.current.handleStop() + }) + + 
expect(mockWfAbort).toHaveBeenCalledTimes(1) + }) + + it('should abort suggestedQuestionsAbortController when set', async () => { + const mockSqAbort = vi.fn() + let capturedCb: any + + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + capturedCb = callbacks + }) + + const mockGetSuggested = vi.fn().mockImplementation((_id: string, getAbortCtrl: any) => { + getAbortCtrl({ abort: mockSqAbort } as any) + return Promise.resolve({ data: ['s'] }) + }) + + const { result } = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + act(() => { + result.current.handleSend({ query: 'test' }, { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + + await act(async () => { + await capturedCb.onCompleted(false) + }) + + act(() => { + result.current.handleStop() + }) + + expect(mockSqAbort).toHaveBeenCalledTimes(1) + }) + + it('should call stopChat with taskId when both are available', () => { + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + callbacks.onData('msg', true, { + conversationId: 'c1', + messageId: 'msg-1', + taskId: 'task-stop', + }) + }) + + const { result } = renderHook(() => useChat({}, undefined, undefined, mockStopChat)) + + act(() => { + result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + result.current.handleStop() + }) + + expect(mockStopChat).toHaveBeenCalledWith('task-stop') + }) +}) + +describe('useChat – handleRestart', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + }) + + it('should clear suggestedQuestions and set isResponding to false', () => { + const config = { opening_statement: 'Hello' } + const { result } = renderHook(() => useChat(config)) + + act(() => { + result.current.handleRestart() + }) + + expect(result.current.suggestedQuestions).toEqual([]) + expect(result.current.isResponding).toBe(false) + }) + + it('should reset iter/loop times to defaults', () => { + const { result } = renderHook(() => useChat({})) + act(() => { + 
result.current.handleRestart() + }) + expect(mockSetIterTimes).toHaveBeenCalledWith(1) + expect(mockSetLoopTimes).toHaveBeenCalledWith(1) + }) +}) diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/misc.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/misc.spec.ts new file mode 100644 index 0000000000..fe2a800561 --- /dev/null +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/misc.spec.ts @@ -0,0 +1,380 @@ +/* eslint-disable ts/no-explicit-any */ +import type { ChatItemInTree } from '@/app/components/base/chat/types' +import { act, renderHook } from '@testing-library/react' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null + +vi.mock('@/service/base', () => ({ + sseGet: (...args: any[]) => mockSseGet(...args), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +vi.mock('@/service/workflow', () => ({ + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), +})) + +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), +})) + +vi.mock('../../../../hooks-store', () => ({ + useHooksStore: () => null, +})) + +vi.mock('../../../../store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setIterTimes: 
mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, + inputs: {}, + workflowRunningData: mockWorkflowRunningData, + }), + }), + useStore: () => vi.fn(), +})) + +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + +describe('useChat – handleSwitchSibling', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + mockSseGet.mockReset() + }) + + it('should call handleResume when target has workflow_run_id and pending humanInputFormData', async () => { + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation(() => {}) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-switch', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-switch', + }) + }) + + act(() => { + sendCallbacks.onHumanInputRequired({ + data: { node_id: 'human-n', form_token: 'ft-1' }, + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + result.current.handleSwitchSibling('msg-switch', {}) + }) + + expect(mockSseGet).toHaveBeenCalled() + }) + + it('should not call handleResume when target has no humanInputFormDataList', async () => { + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-switch', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-switch', + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + result.current.handleSwitchSibling('msg-switch', {}) + }) + + 
expect(mockSseGet).not.toHaveBeenCalled() + }) + + it('should return undefined from findMessageInTree when not found', () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSwitchSibling('nonexistent-id', {}) + }) + + expect(mockSseGet).not.toHaveBeenCalled() + }) + + it('should search children recursively in findMessageInTree', async () => { + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation(() => {}) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'parent' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-parent', + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + result.current.handleSend({ + query: 'child', + parent_message_id: 'msg-parent', + }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + conversation_id: null, + message_id: 'msg-child', + }) + }) + + act(() => { + sendCallbacks.onHumanInputRequired({ + data: { node_id: 'h-child', form_token: 'ft-c' }, + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + result.current.handleSwitchSibling('msg-child', {}) + }) + + expect(mockSseGet).toHaveBeenCalled() + }) +}) + +describe('useChat – handleSubmitHumanInputForm', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + mockSubmitHumanInputForm.mockResolvedValue({}) + }) + + it('should call submitHumanInputForm with token and data', async () => { + const { result } = renderHook(() => useChat({})) + + await act(async () => { + await result.current.handleSubmitHumanInputForm('token-123', { field: 'value' }) + }) + + expect(mockSubmitHumanInputForm).toHaveBeenCalledWith('token-123', 
{ field: 'value' }) + }) +}) + +describe('useChat – getHumanInputNodeData', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + mockGetNodes.mockReturnValue([]) + }) + + it('should return the custom node matching the given nodeID', () => { + const mockNode = { id: 'node-1', type: 'custom', data: { title: 'Human Input' } } + mockGetNodes.mockReturnValue([ + mockNode, + { id: 'node-2', type: 'custom', data: { title: 'Other' } }, + ]) + + const { result } = renderHook(() => useChat({})) + const node = result.current.getHumanInputNodeData('node-1') + expect(node).toEqual(mockNode) + }) + + it('should return undefined when no matching node', () => { + mockGetNodes.mockReturnValue([{ id: 'node-2', type: 'custom', data: {} }]) + + const { result } = renderHook(() => useChat({})) + const node = result.current.getHumanInputNodeData('nonexistent') + expect(node).toBeUndefined() + }) + + it('should filter out non-custom nodes', () => { + mockGetNodes.mockReturnValue([ + { id: 'node-1', type: 'default', data: {} }, + { id: 'node-1', type: 'custom', data: { found: true } }, + ]) + + const { result } = renderHook(() => useChat({})) + const node = result.current.getHumanInputNodeData('node-1') + expect(node).toEqual({ id: 'node-1', type: 'custom', data: { found: true } }) + }) +}) + +describe('useChat – conversationId and setTargetMessageId', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + }) + + it('should initially be an empty string', () => { + const { result } = renderHook(() => useChat({})) + expect(result.current.conversationId).toBe('') + }) + + it('setTargetMessageId should change chatList thread path', () => { + const prevChatTree: ChatItemInTree[] = [ + { + id: 'q1', + content: 'question 1', + isAnswer: false, + children: [ + { + id: 'a1', + content: 'answer 1', + isAnswer: true, + children: [ + { + id: 'q2-branch-a', + content: 'branch A question', + isAnswer: false, + children: [ + { id: 'a2-branch-a', content: 'branch A answer', isAnswer: 
true, children: [] }, + ], + }, + { + id: 'q2-branch-b', + content: 'branch B question', + isAnswer: false, + children: [ + { id: 'a2-branch-b', content: 'branch B answer', isAnswer: true, children: [] }, + ], + }, + ], + }, + ], + }, + ] + + const { result } = renderHook(() => useChat({}, undefined, prevChatTree)) + + const defaultList = result.current.chatList + expect(defaultList.some(item => item.id === 'a1')).toBe(true) + + act(() => { + result.current.setTargetMessageId('a2-branch-a') + }) + + const listA = result.current.chatList + expect(listA.some(item => item.id === 'a2-branch-a')).toBe(true) + expect(listA.some(item => item.id === 'a2-branch-b')).toBe(false) + + act(() => { + result.current.setTargetMessageId('a2-branch-b') + }) + + const listB = result.current.chatList + expect(listB.some(item => item.id === 'a2-branch-b')).toBe(true) + expect(listB.some(item => item.id === 'a2-branch-a')).toBe(false) + }) +}) + +describe('useChat – updateCurrentQAOnTree with parent_message_id', () => { + let capturedCallbacks: any + + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + capturedCallbacks = callbacks + }) + }) + + it('should handle follow-up message with parent_message_id', async () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'first' }, {}) + }) + + const firstCallbacks = capturedCallbacks + + act(() => { + firstCallbacks.onData('answer1', true, { + conversationId: 'c1', + messageId: 'msg-1', + taskId: 't1', + }) + }) + + await act(async () => { + await firstCallbacks.onCompleted(false) + }) + + act(() => { + result.current.handleSend({ + query: 'follow up', + parent_message_id: 'msg-1', + }, {}) + }) + + expect(mockHandleRun).toHaveBeenCalledTimes(2) + expect(result.current.chatList.length).toBeGreaterThan(0) + }) +}) diff --git 
a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/opening-statement.spec.ts similarity index 63% rename from web/app/components/workflow/panel/debug-and-preview/__tests__/hooks.spec.ts rename to web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/opening-statement.spec.ts index 397e1c22d6..e985d6790d 100644 --- a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks.spec.ts +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/opening-statement.spec.ts @@ -1,50 +1,73 @@ +/* eslint-disable ts/no-explicit-any */ import type { ChatItemInTree } from '@/app/components/base/chat/types' import { renderHook } from '@testing-library/react' -import { useChat } from '../hooks' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null vi.mock('@/service/base', () => ({ - sseGet: vi.fn(), + sseGet: (...args: any[]) => mockSseGet(...args), })) vi.mock('@/service/use-workflow', () => ({ - useInvalidAllLastRun: () => vi.fn(), + useInvalidAllLastRun: () => mockInvalidAllLastRun, })) vi.mock('@/service/workflow', () => ({ - submitHumanInputForm: vi.fn(), + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), })) vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ notify: vi.fn() }), + useToastContext: () => ({ notify: mockNotify }), })) vi.mock('reactflow', () => ({ - useStoreApi: () => ({ getState: () => ({}) }), + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), })) -vi.mock('../../../hooks', () => ({ - 
useWorkflowRun: () => ({ handleRun: vi.fn() }), - useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: vi.fn() }), +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), })) -vi.mock('../../../hooks-store', () => ({ +vi.mock('../../../../hooks-store', () => ({ useHooksStore: () => null, })) -vi.mock('../../../store', () => ({ +vi.mock('../../../../store', () => ({ useWorkflowStore: () => ({ getState: () => ({ - setIterTimes: vi.fn(), - setLoopTimes: vi.fn(), + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, inputs: {}, + workflowRunningData: mockWorkflowRunningData, }), }), useStore: () => vi.fn(), })) +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + describe('workflow debug useChat – opening statement stability', () => { beforeEach(() => { - vi.clearAllMocks() + resetMocksAndWorkflowState() }) it('should return empty chatList when config has no opening_statement', () => { @@ -59,7 +82,6 @@ describe('workflow debug useChat – opening statement stability', () => { it('should use stable id "opening-statement" instead of Date.now()', () => { const config = { opening_statement: 'Welcome!' 
} - const { result } = renderHook(() => useChat(config)) expect(result.current.chatList[0].id).toBe('opening-statement') }) @@ -132,4 +154,21 @@ describe('workflow debug useChat – opening statement stability', () => { const openerAfter = result.current.chatList[0] expect(openerAfter).toBe(openerBefore) }) + + it('should include suggestedQuestions in opening statement when config has them', () => { + const config = { + opening_statement: 'Welcome!', + suggested_questions: ['How are you?', 'What can you do?'], + } + const { result } = renderHook(() => useChat(config)) + const opener = result.current.chatList[0] + expect(opener.suggestedQuestions).toEqual(['How are you?', 'What can you do?']) + }) + + it('should not include suggestedQuestions when config has none', () => { + const config = { opening_statement: 'Welcome!' } + const { result } = renderHook(() => useChat(config)) + const opener = result.current.chatList[0] + expect(opener.suggestedQuestions).toBeUndefined() + }) }) diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts new file mode 100644 index 0000000000..073adc59de --- /dev/null +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts @@ -0,0 +1,914 @@ +/* eslint-disable ts/no-explicit-any */ +import { act, renderHook } from '@testing-library/react' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null + +vi.mock('@/service/base', () => ({ + sseGet: (...args: any[]) => mockSseGet(...args), +})) + +vi.mock('@/service/use-workflow', () 
=> ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +vi.mock('@/service/workflow', () => ({ + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), +})) + +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), +})) + +vi.mock('../../../../hooks-store', () => ({ + useHooksStore: () => null, +})) + +vi.mock('../../../../store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, + inputs: {}, + workflowRunningData: mockWorkflowRunningData, + }), + }), + useStore: () => vi.fn(), +})) + +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + +describe('useChat – handleSend SSE callbacks', () => { + let capturedCallbacks: any + + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + capturedCallbacks = callbacks + }) + }) + + function setupAndSend(config: any = {}) { + const hook = renderHook(() => useChat(config)) + act(() => { + hook.result.current.handleSend({ query: 'test' }, { + onGetSuggestedQuestions: vi.fn().mockResolvedValue({ data: ['q1'] }), + }) + }) + return hook + } + + function startWorkflow(overrides: Record = {}) { + act(() => { + capturedCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: null, + ...overrides, + }) + }) + } + + function startNode(nodeId: string, traceId: string, extra: Record = {}) { + act(() => { + capturedCallbacks.onNodeStarted({ + data: { node_id: nodeId, id: 
traceId, ...extra }, + }) + }) + } + + describe('onData', () => { + it('should append message content', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('Hello', true, { + conversationId: 'conv-1', + messageId: 'msg-1', + taskId: 'task-1', + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.content).toContain('Hello') + }) + + it('should set response id from messageId on first call', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('Hi', true, { + conversationId: 'conv-1', + messageId: 'msg-123', + taskId: 'task-1', + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-123') + expect(answer).toBeDefined() + }) + + it('should set conversationId on first message with newConversationId', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('Hi', true, { + conversationId: 'new-conv-id', + messageId: 'msg-1', + taskId: 'task-1', + }) + }) + + expect(result.current.conversationId).toBe('new-conv-id') + }) + + it('should not set conversationId when isFirstMessage is false', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('Hi', false, { + conversationId: 'conv-should-not-set', + messageId: 'msg-1', + taskId: 'task-1', + }) + }) + + expect(result.current.conversationId).toBe('') + }) + + it('should not update hasSetResponseId when messageId is empty', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('msg1', true, { + conversationId: '', + messageId: '', + taskId: 'task-1', + }) + }) + + act(() => { + capturedCallbacks.onData('msg2', false, { + conversationId: '', + messageId: 'late-id', + taskId: 'task-1', + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'late-id') + expect(answer).toBeDefined() + }) + + it('should only set hasSetResponseId once', () => { + 
const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('msg1', true, { + conversationId: 'c1', + messageId: 'msg-first', + taskId: 'task-1', + }) + }) + + act(() => { + capturedCallbacks.onData('msg2', false, { + conversationId: 'c1', + messageId: 'msg-second', + taskId: 'task-1', + }) + }) + + const question = result.current.chatList.find(item => !item.isAnswer) + expect(question!.id).toBe('question-msg-first') + }) + }) + + describe('onCompleted', () => { + it('should set isResponding to false', async () => { + const { result } = setupAndSend() + await act(async () => { + await capturedCallbacks.onCompleted(false) + }) + expect(result.current.isResponding).toBe(false) + }) + + it('should call fetchInspectVars and invalidAllLastRun when not paused', async () => { + setupAndSend() + await act(async () => { + await capturedCallbacks.onCompleted(false) + }) + expect(mockFetchInspectVars).toHaveBeenCalledWith({}) + expect(mockInvalidAllLastRun).toHaveBeenCalled() + }) + + it('should not call fetchInspectVars when workflow is paused', async () => { + mockWorkflowRunningData = { result: { status: 'paused' } } + setupAndSend() + await act(async () => { + await capturedCallbacks.onCompleted(false) + }) + expect(mockFetchInspectVars).not.toHaveBeenCalled() + }) + + it('should set error content on response item when hasError with errorMessage', async () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('partial', true, { + conversationId: 'c1', + messageId: 'msg-err', + taskId: 't1', + }) + }) + + await act(async () => { + await capturedCallbacks.onCompleted(true, 'Something went wrong') + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-err') + expect(answer!.content).toBe('Something went wrong') + expect(answer!.isError).toBe(true) + }) + + it('should not set error content when hasError is true but errorMessage is empty', async () => { + const { result } = setupAndSend() + await act(async 
() => { + await capturedCallbacks.onCompleted(true) + }) + expect(result.current.isResponding).toBe(false) + }) + + it('should fetch suggested questions when enabled and invoke abort controller callback', async () => { + const mockGetSuggested = vi.fn().mockImplementation((_id: string, getAbortCtrl: any) => { + getAbortCtrl(new AbortController()) + return Promise.resolve({ data: ['suggestion1'] }) + }) + const hook = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + capturedCallbacks = callbacks + }) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + + await act(async () => { + await capturedCallbacks.onCompleted(false) + }) + + expect(mockGetSuggested).toHaveBeenCalled() + }) + + it('should set suggestedQuestions to empty array when fetch fails', async () => { + const mockGetSuggested = vi.fn().mockRejectedValue(new Error('fail')) + const hook = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + capturedCallbacks = callbacks + }) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + + await act(async () => { + await capturedCallbacks.onCompleted(false) + }) + + expect(hook.result.current.suggestedQuestions).toEqual([]) + }) + }) + + describe('onError', () => { + it('should set isResponding to false', () => { + const { result } = setupAndSend() + act(() => { + capturedCallbacks.onError() + }) + expect(result.current.isResponding).toBe(false) + }) + }) + + describe('onMessageEnd', () => { + it('should update citation and files', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('response', true, { + conversationId: 'c1', + messageId: 'msg-1', + taskId: 't1', + 
}) + }) + + act(() => { + capturedCallbacks.onMessageEnd({ + metadata: { retriever_resources: [{ id: 'r1' }] }, + files: [], + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-1') + expect(answer!.citation).toEqual([{ id: 'r1' }]) + }) + + it('should default citation to empty array when no retriever_resources', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('response', true, { + conversationId: 'c1', + messageId: 'msg-1', + taskId: 't1', + }) + }) + + act(() => { + capturedCallbacks.onMessageEnd({ metadata: {}, files: [] }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-1') + expect(answer!.citation).toEqual([]) + }) + }) + + describe('onMessageReplace', () => { + it('should replace answer content on responseItem', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('old', true, { + conversationId: 'c1', + messageId: 'msg-1', + taskId: 't1', + }) + }) + + act(() => { + capturedCallbacks.onMessageReplace({ answer: 'replaced' }) + }) + + act(() => { + capturedCallbacks.onMessageEnd({ metadata: {}, files: [] }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-1') + expect(answer!.content).toBe('replaced') + }) + }) + + describe('onWorkflowStarted', () => { + it('should create workflow process with Running status', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: 'conv-1', + message_id: 'msg-1', + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.status).toBe('running') + expect(answer!.workflowProcess!.tracing).toEqual([]) + }) + + it('should set conversationId when provided', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onWorkflowStarted({ + workflow_run_id: 
'wfr-1', + task_id: 'task-1', + conversation_id: 'from-workflow', + message_id: null, + }) + }) + + expect(result.current.conversationId).toBe('from-workflow') + }) + + it('should not override existing conversationId when conversation_id is null', () => { + const { result } = setupAndSend() + startWorkflow() + expect(result.current.conversationId).toBe('') + }) + + it('should resume existing workflow process when tracing exists', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('n1', 'trace-1') + startWorkflow({ workflow_run_id: 'wfr-2', task_id: 'task-2' }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.status).toBe('running') + expect(answer!.workflowProcess!.tracing.length).toBe(1) + }) + + it('should replace placeholder answer id with real message_id from server', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'wf-msg-id', + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'wf-msg-id') + expect(answer).toBeDefined() + }) + }) + + describe('onWorkflowFinished', () => { + it('should update workflow process status', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onWorkflowFinished({ data: { status: 'succeeded' } }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.status).toBe('succeeded') + }) + }) + + describe('onIterationStart / onIterationFinish', () => { + it('should push tracing entry on start', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onIterationStart({ + data: { id: 'iter-1', node_id: 'n-iter' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && 
!item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('iter-1') + expect(trace.node_id).toBe('n-iter') + expect(trace.status).toBe('running') + }) + + it('should update matching tracing on finish', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onIterationStart({ + data: { id: 'iter-1', node_id: 'n-iter' }, + }) + }) + + act(() => { + capturedCallbacks.onIterationFinish({ + data: { id: 'iter-1', node_id: 'n-iter', output: 'done' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const trace = answer!.workflowProcess!.tracing.find((t: any) => t.id === 'iter-1') + expect(trace).toBeDefined() + expect(trace!.node_id).toBe('n-iter') + expect((trace as any).output).toBe('done') + }) + + it('should not update tracing on finish when id does not match', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onIterationStart({ + data: { id: 'iter-1', node_id: 'n-iter' }, + }) + }) + + act(() => { + capturedCallbacks.onIterationFinish({ + data: { id: 'iter-nonexistent', node_id: 'n-other' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect((answer!.workflowProcess!.tracing[0] as any).output).toBeUndefined() + }) + }) + + describe('onLoopStart / onLoopFinish', () => { + it('should push tracing entry on start', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onLoopStart({ + data: { id: 'loop-1', node_id: 'n-loop' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + 
expect(trace.id).toBe('loop-1') + expect(trace.node_id).toBe('n-loop') + expect(trace.status).toBe('running') + }) + + it('should update matching tracing on finish', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onLoopStart({ + data: { id: 'loop-1', node_id: 'n-loop' }, + }) + }) + + act(() => { + capturedCallbacks.onLoopFinish({ + data: { id: 'loop-1', node_id: 'n-loop', output: 'done' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('loop-1') + expect(trace.node_id).toBe('n-loop') + expect((trace as any).output).toBe('done') + }) + + it('should not update tracing on finish when id does not match', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onLoopStart({ + data: { id: 'loop-1', node_id: 'n-loop' }, + }) + }) + + act(() => { + capturedCallbacks.onLoopFinish({ + data: { id: 'loop-nonexistent', node_id: 'n-other' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect((answer!.workflowProcess!.tracing[0] as any).output).toBeUndefined() + }) + }) + + describe('onNodeStarted / onNodeRetry / onNodeFinished', () => { + it('should add new tracing entry', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('node-1', 'trace-1') + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('trace-1') + expect(trace.node_id).toBe('node-1') + expect(trace.status).toBe('running') + }) + + it('should update existing tracing entry with same node_id', () => { + const { 
result } = setupAndSend() + startWorkflow() + startNode('node-1', 'trace-1') + startNode('node-1', 'trace-1-v2') + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('trace-1-v2') + expect(trace.node_id).toBe('node-1') + expect(trace.status).toBe('running') + }) + + it('should push retry data to tracing', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onNodeRetry({ + data: { node_id: 'node-1', id: 'retry-1', retry_index: 1 }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('retry-1') + expect(trace.node_id).toBe('node-1') + expect((trace as any).retry_index).toBe(1) + }) + + it('should update tracing entry on finish by id', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('node-1', 'trace-1') + + act(() => { + capturedCallbacks.onNodeFinished({ + data: { node_id: 'node-1', id: 'trace-1', status: 'succeeded', outputs: { text: 'done' } }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('trace-1') + expect(trace.status).toBe('succeeded') + expect((trace as any).outputs).toEqual({ text: 'done' }) + }) + + it('should not update tracing on finish when id does not match', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('node-1', 'trace-1') + + act(() => { + capturedCallbacks.onNodeFinished({ + data: { node_id: 'node-x', id: 'trace-x', status: 'succeeded' }, + }) + }) + + const answer = 
result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('trace-1') + expect(trace.status).toBe('running') + }) + }) + + describe('onAgentLog', () => { + function setupWithNode() { + const hook = setupAndSend() + startWorkflow() + return hook + } + + it('should create execution_metadata.agent_log when no execution_metadata exists', () => { + const { result } = setupWithNode() + startNode('agent-node', 'trace-agent') + + act(() => { + capturedCallbacks.onAgentLog({ + data: { node_id: 'agent-node', message_id: 'log-1', content: 'init' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const agentTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'agent-node') + expect(agentTrace!.execution_metadata!.agent_log).toHaveLength(1) + }) + + it('should create agent_log array when execution_metadata exists but no agent_log', () => { + const { result } = setupWithNode() + startNode('agent-node', 'trace-agent', { execution_metadata: { parallel_id: 'p1' } }) + + act(() => { + capturedCallbacks.onAgentLog({ + data: { node_id: 'agent-node', message_id: 'log-1', content: 'step1' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const agentTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'agent-node') + expect(agentTrace!.execution_metadata!.agent_log).toHaveLength(1) + }) + + it('should update existing agent_log entry by message_id', () => { + const { result } = setupWithNode() + startNode('agent-node', 'trace-agent', { + execution_metadata: { agent_log: [{ message_id: 'log-1', content: 'v1' }] }, + }) + + act(() => { + capturedCallbacks.onAgentLog({ + data: { node_id: 'agent-node', message_id: 'log-1', content: 'v2' }, + }) + }) + + const answer = 
result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const agentTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'agent-node') + expect(agentTrace!.execution_metadata!.agent_log).toHaveLength(1) + expect((agentTrace!.execution_metadata!.agent_log as any[])[0].content).toBe('v2') + }) + + it('should push new agent_log entry when message_id does not match', () => { + const { result } = setupWithNode() + startNode('agent-node', 'trace-agent', { + execution_metadata: { agent_log: [{ message_id: 'log-1', content: 'v1' }] }, + }) + + act(() => { + capturedCallbacks.onAgentLog({ + data: { node_id: 'agent-node', message_id: 'log-2', content: 'new' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const agentTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'agent-node') + expect(agentTrace!.execution_metadata!.agent_log).toHaveLength(2) + }) + + it('should not crash when node_id is not found in tracing', () => { + setupWithNode() + + act(() => { + capturedCallbacks.onAgentLog({ + data: { node_id: 'nonexistent-node', message_id: 'log-1', content: 'noop' }, + }) + }) + }) + }) + + describe('onHumanInputRequired', () => { + it('should add form data to humanInputFormDataList', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('human-node', 'trace-human') + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-1' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.humanInputFormDataList).toHaveLength(1) + expect(answer!.humanInputFormDataList![0].node_id).toBe('human-node') + expect((answer!.humanInputFormDataList![0] as any).form_token).toBe('token-1') + }) + + it('should update existing form for same node_id', () => { + const { result } = setupAndSend() + startWorkflow() + 
startNode('human-node', 'trace-human') + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-1' }, + }) + }) + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-2' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.humanInputFormDataList).toHaveLength(1) + expect((answer!.humanInputFormDataList![0] as any).form_token).toBe('token-2') + }) + + it('should push new form data for different node_id', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node-1', form_token: 'token-1' }, + }) + }) + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node-2', form_token: 'token-2' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.humanInputFormDataList).toHaveLength(2) + expect(answer!.humanInputFormDataList![0].node_id).toBe('human-node-1') + expect(answer!.humanInputFormDataList![1].node_id).toBe('human-node-2') + }) + + it('should set tracing node status to Paused when tracing index found', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('human-node', 'trace-human') + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-1' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const trace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'human-node') + expect(trace!.status).toBe('paused') + }) + }) + + describe('onHumanInputFormFilled', () => { + it('should remove form and add to filled list', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { 
node_id: 'human-node', form_token: 'token-1' }, + }) + }) + + act(() => { + capturedCallbacks.onHumanInputFormFilled({ + data: { node_id: 'human-node', form_data: { answer: 'yes' } }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.humanInputFormDataList).toHaveLength(0) + expect(answer!.humanInputFilledFormDataList).toHaveLength(1) + expect(answer!.humanInputFilledFormDataList![0].node_id).toBe('human-node') + expect((answer!.humanInputFilledFormDataList![0] as any).form_data).toEqual({ answer: 'yes' }) + }) + }) + + describe('onHumanInputFormTimeout', () => { + it('should update expiration_time on form data', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-1' }, + }) + }) + + act(() => { + capturedCallbacks.onHumanInputFormTimeout({ + data: { node_id: 'human-node', expiration_time: '2025-01-01T00:00:00Z' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const form = answer!.humanInputFormDataList!.find((f: any) => f.node_id === 'human-node') + expect(form!.expiration_time).toBe('2025-01-01T00:00:00Z') + }) + }) + + describe('onWorkflowPaused', () => { + it('should set status to Paused', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onWorkflowPaused({ data: {} }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.status).toBe('paused') + }) + }) +}) From 7fe25f136569ec3796b9350d2f748c3e4e207396 Mon Sep 17 00:00:00 2001 From: Zhanyuan Guo <2364319479@qq.com> Date: Tue, 24 Mar 2026 15:08:55 +0800 Subject: [PATCH 20/32] fix(rate_limit): flush redis cache when __init__ is triggered by changing max_active_requests (#33830) Co-authored-by: gemini-code-assist[bot] 
<176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../app/features/rate_limiting/rate_limit.py | 12 +++++-- .../features/rate_limiting/test_rate_limit.py | 34 +++++++++++++++++-- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 2ca1275a8a..e0f1759e5e 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,6 +19,7 @@ class RateLimit: _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} + max_active_requests: int def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: @@ -27,7 +28,13 @@ class RateLimit: return cls._instance_dict[client_id] def __init__(self, client_id: str, max_active_requests: int): + flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests self.max_active_requests = max_active_requests + # Only flush here if this instance has already been fully initialized, + # i.e. the Redis key attributes exist. Otherwise, rely on the flush at + # the end of initialization below. 
+ if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"): + self.flush_cache(use_local_value=True) # must be called after max_active_requests is set if self.disabled(): return @@ -41,8 +48,6 @@ class RateLimit: self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): - if self.disabled(): - return self.last_recalculate_time = time.time() # flush max active requests if use_local_value or not redis_client.exists(self.max_active_requests_key): @@ -50,7 +55,8 @@ class RateLimit: else: self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) - + if self.disabled(): + return # flush max active requests (in-transit request list) if not redis_client.exists(self.active_requests_key): return diff --git a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py index 3db10c1c72..538b130cac 100644 --- a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py +++ b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py @@ -68,8 +68,8 @@ class TestRateLimit: assert rate_limit.disabled() assert not hasattr(rate_limit, "initialized") - def test_should_skip_reinitialization_of_existing_instance(self, redis_patch): - """Test that existing instance doesn't reinitialize.""" + def test_should_flush_cache_when_reinitializing_existing_instance(self, redis_patch): + """Test existing instance refreshes Redis cache on reinitialization.""" redis_patch.configure_mock( **{ "exists.return_value": False, @@ -82,7 +82,37 @@ class TestRateLimit: RateLimit("client1", 10) + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) + + def test_should_reinitialize_after_being_disabled(self, redis_patch): + """Test disabled instance can be 
reinitialized and writes max_active_requests to Redis.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + # First construct with max_active_requests = 0 (disabled), which should skip initialization. + RateLimit("client1", 0) + + # Redis should not have been written to during disabled initialization. redis_patch.setex.assert_not_called() + redis_patch.reset_mock() + + # Reinitialize with a positive max_active_requests value; this should not raise + # and must write the max_active_requests key to Redis. + RateLimit("client1", 10) + + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) def test_should_be_disabled_when_max_requests_is_zero_or_negative(self): """Test disabled state for zero or negative limits.""" From 1674f8c2fb5ac6bc6e2468dfebbf8ae9096d4221 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Tue, 24 Mar 2026 15:10:05 +0800 Subject: [PATCH 21/32] fix: fix omitted app icon_type updates (#33988) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/app.py | 8 +- api/services/app_service.py | 10 +- .../services/test_app_service.py | 105 +++++++++++++++++- .../controllers/console/app/test_app_apis.py | 49 +++++++- .../unit_tests/services/test_app_service.py | 76 ++++++++++++- 5 files changed, 239 insertions(+), 9 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 5ac0e342e6..7e41260eeb 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") - 
icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel): class UpdateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") @@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel): class CopyAppPayload(BaseModel): name: str | None = Field(default=None, description="Name for the copied app") description: str | None = Field(default=None, description="Description for the copied app", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -594,7 +594,7 @@ class AppApi(Resource): args_dict: AppService.ArgsDict = { "name": args.name, "description": args.description or "", - "icon_type": args.icon_type or "", + "icon_type": args.icon_type, "icon": args.icon or "", "icon_background": args.icon_background or "", "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False, diff --git a/api/services/app_service.py b/api/services/app_service.py index c5d1479a20..69c7c0c95a 100644 --- a/api/services/app_service.py +++ 
b/api/services/app_service.py @@ -241,7 +241,7 @@ class AppService: class ArgsDict(TypedDict): name: str description: str - icon_type: str + icon_type: IconType | str | None icon: str icon_background: str use_icon_as_answer_icon: bool @@ -257,7 +257,13 @@ class AppService: assert current_user is not None app.name = args["name"] app.description = args["description"] - app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None + icon_type = args.get("icon_type") + if icon_type is None: + resolved_icon_type = app.icon_type + else: + resolved_icon_type = IconType(icon_type) + + app.icon_type = resolved_icon_type app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index d79f80c009..9ca8729b77 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account -from models.model import App, Site +from models.model import App, IconType, Site from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -463,6 +463,109 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by + def test_update_app_should_preserve_icon_type_when_omitted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app keeps the persisted icon_type when the update payload omits it. 
+ """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + updated_app = app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": None, + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + + assert updated_app.icon_type == IconType.EMOJI + + def test_update_app_should_reject_empty_icon_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app rejects an explicit empty icon_type. 
+ """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + with pytest.raises(ValueError): + app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": "", + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. 
diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py index beb8ff55a5..1d1e119fd6 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -7,14 +7,19 @@ from __future__ import annotations import uuid from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound +from controllers.console import console_ns from controllers.console.app import ( annotation as annotation_module, ) +from controllers.console.app import ( + app as app_module, +) from controllers.console.app import ( completion as completion_module, ) @@ -203,6 +208,48 @@ class TestCompletionEndpoints: method(app_model=MagicMock(id="app-1")) +class TestAppEndpoints: + """Tests for app endpoints.""" + + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + api = app_module.AppApi() + method = _unwrap(api.put) + payload = { + "name": "Updated App", + "description": "Updated description", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetailWithSite, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1", method="PUT", json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI)) + + assert response == {"id": "app-1"} + assert app_service.update_app.call_args.args[1]["icon_type"] is None + + def 
test_update_app_payload_should_reject_empty_icon_type(self): + with pytest.raises(ValidationError): + app_module.UpdateAppPayload.model_validate( + { + "name": "Updated App", + "description": "Updated description", + "icon_type": "", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + ) + + # ========== OpsTrace Tests ========== class TestOpsTraceEndpoints: """Tests for ops_trace endpoint.""" diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py index bff8dc92c6..95fc28b1e7 100644 --- a/api/tests/unit_tests/services/test_app_service.py +++ b/api/tests/unit_tests/services/test_app_service.py @@ -9,7 +9,7 @@ import pytest from core.errors.error import ProviderTokenNotInitError from models import Account, Tenant -from models.model import App, AppMode +from models.model import App, AppMode, IconType from services.app_service import AppService @@ -411,6 +411,7 @@ class TestAppServiceGetAndUpdate: # Assert assert updated is app + assert updated.icon_type == IconType.IMAGE assert renamed is app assert iconed is app assert site_same is app @@ -419,6 +420,79 @@ class TestAppServiceGetAndUpdate: assert api_changed is app assert mock_db.session.commit.call_count >= 5 + def test_update_app_should_preserve_icon_type_when_not_provided(self, service: AppService) -> None: + """Test update_app keeps the existing icon_type when the payload omits it.""" + # Arrange + app = cast( + App, + SimpleNamespace( + name="old", + description="old", + icon_type=IconType.EMOJI, + icon="a", + icon_background="#111", + use_icon_as_answer_icon=False, + max_active_requests=1, + ), + ) + args = { + "name": "new", + "description": "new-desc", + "icon_type": None, + "icon": "new-icon", + "icon_background": "#222", + "use_icon_as_answer_icon": True, + "max_active_requests": 5, + } + user = SimpleNamespace(id="user-1") + + with ( + patch("services.app_service.current_user", user), + patch("services.app_service.db") as mock_db, + 
patch("services.app_service.naive_utc_now", return_value="now"), + ): + # Act + updated = service.update_app(app, args) + + # Assert + assert updated is app + assert updated.icon_type == IconType.EMOJI + mock_db.session.commit.assert_called_once() + + def test_update_app_should_reject_empty_icon_type(self, service: AppService) -> None: + """Test update_app rejects an explicit empty icon_type.""" + app = cast( + App, + SimpleNamespace( + name="old", + description="old", + icon_type=IconType.EMOJI, + icon="a", + icon_background="#111", + use_icon_as_answer_icon=False, + max_active_requests=1, + ), + ) + args = { + "name": "new", + "description": "new-desc", + "icon_type": "", + "icon": "new-icon", + "icon_background": "#222", + "use_icon_as_answer_icon": True, + "max_active_requests": 5, + } + user = SimpleNamespace(id="user-1") + + with ( + patch("services.app_service.current_user", user), + patch("services.app_service.db") as mock_db, + ): + with pytest.raises(ValueError): + service.update_app(app, args) + + mock_db.session.commit.assert_not_called() + class TestAppServiceDeleteAndMeta: """Test suite for delete and metadata methods.""" From 0c3d11f920000109c3eb54823d10a3feb493d61f Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Tue, 24 Mar 2026 15:29:42 +0800 Subject: [PATCH 22/32] refactor: lazy load large modules (#33888) --- web/app/components/apps/index.tsx | 8 +++-- web/app/components/apps/list.tsx | 14 ++++---- .../base/amplitude/AmplitudeProvider.tsx | 9 ++--- .../__tests__/AmplitudeProvider.spec.tsx | 30 +++++++--------- .../base/amplitude/__tests__/index.spec.ts | 32 ----------------- .../base/amplitude/__tests__/utils.spec.ts | 6 ++-- web/app/components/base/amplitude/index.ts | 2 +- .../amplitude/lazy-amplitude-provider.tsx | 11 ++++++ web/app/components/base/amplitude/utils.ts | 10 +++--- .../components/devtools/agentation-loader.tsx | 13 +++++++ .../account-dropdown/__tests__/index.spec.tsx | 3 ++ web/app/components/header/app-nav/index.tsx | 8 +++-- 
.../components/lazy-sentry-initializer.tsx | 16 +++++++++ web/app/components/sentry-initializer.tsx | 7 ++-- web/app/layout.tsx | 34 +++++++++---------- web/config/index.ts | 2 ++ web/eslint-suppressions.json | 5 --- 17 files changed, 104 insertions(+), 106 deletions(-) delete mode 100644 web/app/components/base/amplitude/__tests__/index.spec.ts create mode 100644 web/app/components/base/amplitude/lazy-amplitude-provider.tsx create mode 100644 web/app/components/devtools/agentation-loader.tsx create mode 100644 web/app/components/lazy-sentry-initializer.tsx diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index dce9de190d..b6ca60bd7b 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -8,12 +8,14 @@ import AppListContext from '@/context/app-list-context' import useDocumentTitle from '@/hooks/use-document-title' import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode } from '@/models/app' +import dynamic from '@/next/dynamic' import { fetchAppDetail } from '@/service/explore' -import DSLConfirmModal from '../app/create-from-dsl-modal/dsl-confirm-modal' -import CreateAppModal from '../explore/create-app-modal' -import TryApp from '../explore/try-app' import List from './list' +const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) +const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false }) +const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false }) + const Apps = () => { const { t } = useTranslation() diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 0d52bd468c..2ef344f816 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -5,11 +5,11 @@ import { useDebounceFn } from 'ahooks' import { parseAsStringLiteral, useQueryState } from 'nuqs' import { useCallback, useEffect, useRef, useState } from 'react' 
import { useTranslation } from 'react-i18next' +import Checkbox from '@/app/components/base/checkbox' import Input from '@/app/components/base/input' import TabSliderNew from '@/app/components/base/tab-slider-new' import TagFilter from '@/app/components/base/tag-management/filter' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' -import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' @@ -205,12 +205,12 @@ const List: FC = ({ options={options} />
- + { - return IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY -} - // Map URL pathname to English page name for consistent Amplitude tracking const getEnglishPageName = (pathname: string): string => { // Remove leading slash and get the first segment @@ -59,7 +54,7 @@ const AmplitudeProvider: FC = ({ }) => { useEffect(() => { // Only enable in Saas edition with valid API key - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return // Initialize Amplitude diff --git a/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx index b30da72091..5835634eb7 100644 --- a/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx +++ b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx @@ -2,14 +2,24 @@ import * as amplitude from '@amplitude/analytics-browser' import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' import { render } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from '../AmplitudeProvider' +import AmplitudeProvider from '../AmplitudeProvider' const mockConfig = vi.hoisted(() => ({ AMPLITUDE_API_KEY: 'test-api-key', IS_CLOUD_EDITION: true, })) -vi.mock('@/config', () => mockConfig) +vi.mock('@/config', () => ({ + get AMPLITUDE_API_KEY() { + return mockConfig.AMPLITUDE_API_KEY + }, + get IS_CLOUD_EDITION() { + return mockConfig.IS_CLOUD_EDITION + }, + get isAmplitudeEnabled() { + return mockConfig.IS_CLOUD_EDITION && !!mockConfig.AMPLITUDE_API_KEY + }, +})) vi.mock('@amplitude/analytics-browser', () => ({ init: vi.fn(), @@ -27,22 +37,6 @@ describe('AmplitudeProvider', () => { mockConfig.IS_CLOUD_EDITION = true }) - describe('isAmplitudeEnabled', () => { - it('returns true when cloud edition and api key present', () => { - expect(isAmplitudeEnabled()).toBe(true) - }) - - it('returns false when cloud edition but no api key', () 
=> { - mockConfig.AMPLITUDE_API_KEY = '' - expect(isAmplitudeEnabled()).toBe(false) - }) - - it('returns false when not cloud edition', () => { - mockConfig.IS_CLOUD_EDITION = false - expect(isAmplitudeEnabled()).toBe(false) - }) - }) - describe('Component', () => { it('initializes amplitude when enabled', () => { render() diff --git a/web/app/components/base/amplitude/__tests__/index.spec.ts b/web/app/components/base/amplitude/__tests__/index.spec.ts deleted file mode 100644 index 2d7ad6ab84..0000000000 --- a/web/app/components/base/amplitude/__tests__/index.spec.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { describe, expect, it } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from '../AmplitudeProvider' -import indexDefault, { - isAmplitudeEnabled as indexIsAmplitudeEnabled, - resetUser, - setUserId, - setUserProperties, - trackEvent, -} from '../index' -import { - resetUser as utilsResetUser, - setUserId as utilsSetUserId, - setUserProperties as utilsSetUserProperties, - trackEvent as utilsTrackEvent, -} from '../utils' - -describe('Amplitude index exports', () => { - it('exports AmplitudeProvider as default', () => { - expect(indexDefault).toBe(AmplitudeProvider) - }) - - it('exports isAmplitudeEnabled', () => { - expect(indexIsAmplitudeEnabled).toBe(isAmplitudeEnabled) - }) - - it('exports utils', () => { - expect(resetUser).toBe(utilsResetUser) - expect(setUserId).toBe(utilsSetUserId) - expect(setUserProperties).toBe(utilsSetUserProperties) - expect(trackEvent).toBe(utilsTrackEvent) - }) -}) diff --git a/web/app/components/base/amplitude/__tests__/utils.spec.ts b/web/app/components/base/amplitude/__tests__/utils.spec.ts index ecbc57e387..f1ff5db1e3 100644 --- a/web/app/components/base/amplitude/__tests__/utils.spec.ts +++ b/web/app/components/base/amplitude/__tests__/utils.spec.ts @@ -20,8 +20,10 @@ const MockIdentify = vi.hoisted(() => }, ) -vi.mock('../AmplitudeProvider', () => ({ - isAmplitudeEnabled: () => mockState.enabled, 
+vi.mock('@/config', () => ({ + get isAmplitudeEnabled() { + return mockState.enabled + }, })) vi.mock('@amplitude/analytics-browser', () => ({ diff --git a/web/app/components/base/amplitude/index.ts b/web/app/components/base/amplitude/index.ts index acc792339e..44cbf728e2 100644 --- a/web/app/components/base/amplitude/index.ts +++ b/web/app/components/base/amplitude/index.ts @@ -1,2 +1,2 @@ -export { default, isAmplitudeEnabled } from './AmplitudeProvider' +export { default } from './lazy-amplitude-provider' export { resetUser, setUserId, setUserProperties, trackEvent } from './utils' diff --git a/web/app/components/base/amplitude/lazy-amplitude-provider.tsx b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx new file mode 100644 index 0000000000..5dfa0e7b53 --- /dev/null +++ b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx @@ -0,0 +1,11 @@ +'use client' + +import type { FC } from 'react' +import type { IAmplitudeProps } from './AmplitudeProvider' +import dynamic from '@/next/dynamic' + +const AmplitudeProvider = dynamic(() => import('./AmplitudeProvider'), { ssr: false }) + +const LazyAmplitudeProvider: FC = props => + +export default LazyAmplitudeProvider diff --git a/web/app/components/base/amplitude/utils.ts b/web/app/components/base/amplitude/utils.ts index 57b96243ec..8faa8e852e 100644 --- a/web/app/components/base/amplitude/utils.ts +++ b/web/app/components/base/amplitude/utils.ts @@ -1,5 +1,5 @@ import * as amplitude from '@amplitude/analytics-browser' -import { isAmplitudeEnabled } from './AmplitudeProvider' +import { isAmplitudeEnabled } from '@/config' /** * Track custom event @@ -7,7 +7,7 @@ import { isAmplitudeEnabled } from './AmplitudeProvider' * @param eventProperties Event properties (optional) */ export const trackEvent = (eventName: string, eventProperties?: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.track(eventName, eventProperties) } @@ -17,7 +17,7 @@ export const trackEvent 
= (eventName: string, eventProperties?: Record { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.setUserId(userId) } @@ -27,7 +27,7 @@ export const setUserId = (userId: string) => { * @param properties User properties */ export const setUserProperties = (properties: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return const identifyEvent = new amplitude.Identify() Object.entries(properties).forEach(([key, value]) => { @@ -40,7 +40,7 @@ export const setUserProperties = (properties: Record) => { * Reset user (e.g., when user logs out) */ export const resetUser = () => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.reset() } diff --git a/web/app/components/devtools/agentation-loader.tsx b/web/app/components/devtools/agentation-loader.tsx new file mode 100644 index 0000000000..87e1b44c87 --- /dev/null +++ b/web/app/components/devtools/agentation-loader.tsx @@ -0,0 +1,13 @@ +'use client' + +import { IS_DEV } from '@/config' +import dynamic from '@/next/dynamic' + +const Agentation = dynamic(() => import('agentation').then(module => module.Agentation), { ssr: false }) + +export function AgentationLoader() { + if (!IS_DEV) + return null + + return +} diff --git a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx index eb4d543e66..9d4226c33a 100644 --- a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx +++ b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx @@ -69,6 +69,7 @@ vi.mock('@/context/i18n', () => ({ const { mockConfig, mockEnv } = vi.hoisted(() => ({ mockConfig: { IS_CLOUD_EDITION: false, + AMPLITUDE_API_KEY: '', ZENDESK_WIDGET_KEY: '', SUPPORT_EMAIL_ADDRESS: '', }, @@ -80,6 +81,8 @@ const { mockConfig, mockEnv } = vi.hoisted(() => ({ })) vi.mock('@/config', () => ({ get IS_CLOUD_EDITION() { return mockConfig.IS_CLOUD_EDITION }, + get AMPLITUDE_API_KEY() { return 
mockConfig.AMPLITUDE_API_KEY }, + get isAmplitudeEnabled() { return mockConfig.IS_CLOUD_EDITION && !!mockConfig.AMPLITUDE_API_KEY }, get ZENDESK_WIDGET_KEY() { return mockConfig.ZENDESK_WIDGET_KEY }, get SUPPORT_EMAIL_ADDRESS() { return mockConfig.SUPPORT_EMAIL_ADDRESS }, IS_DEV: false, diff --git a/web/app/components/header/app-nav/index.tsx b/web/app/components/header/app-nav/index.tsx index 214b7612bb..54ddf75711 100644 --- a/web/app/components/header/app-nav/index.tsx +++ b/web/app/components/header/app-nav/index.tsx @@ -9,16 +9,18 @@ import { flatten } from 'es-toolkit/compat' import { produce } from 'immer' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import CreateAppTemplateDialog from '@/app/components/app/create-app-dialog' -import CreateAppModal from '@/app/components/app/create-app-modal' -import CreateFromDSLModal from '@/app/components/app/create-from-dsl-modal' import { useStore as useAppStore } from '@/app/components/app/store' import { useAppContext } from '@/context/app-context' +import dynamic from '@/next/dynamic' import { useParams } from '@/next/navigation' import { useInfiniteAppList } from '@/service/use-apps' import { AppModeEnum } from '@/types/app' import Nav from '../nav' +const CreateAppTemplateDialog = dynamic(() => import('@/app/components/app/create-app-dialog'), { ssr: false }) +const CreateAppModal = dynamic(() => import('@/app/components/app/create-app-modal'), { ssr: false }) +const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-from-dsl-modal'), { ssr: false }) + const AppNav = () => { const { t } = useTranslation() const { appId } = useParams() diff --git a/web/app/components/lazy-sentry-initializer.tsx b/web/app/components/lazy-sentry-initializer.tsx new file mode 100644 index 0000000000..8c29ca4f9a --- /dev/null +++ b/web/app/components/lazy-sentry-initializer.tsx @@ -0,0 +1,16 @@ +'use client' + +import { IS_DEV } from '@/config' +import { env 
} from '@/env' +import dynamic from '@/next/dynamic' + +const SentryInitializer = dynamic(() => import('./sentry-initializer'), { ssr: false }) + +const LazySentryInitializer = () => { + if (IS_DEV || !env.NEXT_PUBLIC_SENTRY_DSN) + return null + + return +} + +export default LazySentryInitializer diff --git a/web/app/components/sentry-initializer.tsx b/web/app/components/sentry-initializer.tsx index 8a7286f908..00c3af37c1 100644 --- a/web/app/components/sentry-initializer.tsx +++ b/web/app/components/sentry-initializer.tsx @@ -2,13 +2,10 @@ import * as Sentry from '@sentry/react' import { useEffect } from 'react' - import { IS_DEV } from '@/config' import { env } from '@/env' -const SentryInitializer = ({ - children, -}: { children: React.ReactElement }) => { +const SentryInitializer = () => { useEffect(() => { const SENTRY_DSN = env.NEXT_PUBLIC_SENTRY_DSN if (!IS_DEV && SENTRY_DSN) { @@ -24,7 +21,7 @@ const SentryInitializer = ({ }) } }, []) - return children + return null } export default SentryInitializer diff --git a/web/app/layout.tsx b/web/app/layout.tsx index be51c76f2e..1cf1bb0d94 100644 --- a/web/app/layout.tsx +++ b/web/app/layout.tsx @@ -1,9 +1,7 @@ import type { Viewport } from '@/next' -import { Agentation } from 'agentation' import { Provider as JotaiProvider } from 'jotai/react' import { ThemeProvider } from 'next-themes' import { NuqsAdapter } from 'nuqs/adapters/next/app' -import { IS_DEV } from '@/config' import GlobalPublicStoreProvider from '@/context/global-public-context' import { TanstackQueryInitializer } from '@/context/query-client' import { getDatasetMap } from '@/env' @@ -12,9 +10,10 @@ import { ToastProvider } from './components/base/toast' import { ToastHost } from './components/base/ui/toast' import { TooltipProvider } from './components/base/ui/tooltip' import BrowserInitializer from './components/browser-initializer' +import { AgentationLoader } from './components/devtools/agentation-loader' import { ReactScanLoader } from 
'./components/devtools/react-scan/loader' +import LazySentryInitializer from './components/lazy-sentry-initializer' import { I18nServerProvider } from './components/provider/i18n-server' -import SentryInitializer from './components/sentry-initializer' import RoutePrefixHandle from './routePrefixHandle' import './styles/globals.css' import './styles/markdown.scss' @@ -57,6 +56,7 @@ const LocaleLayout = async ({ className="h-full select-auto" {...datasetMap} > +
- - - - - - - - {children} - - - - - - + + + + + + + {children} + + + + + - {IS_DEV && } +
diff --git a/web/config/index.ts b/web/config/index.ts index 3f7d26c623..eed914726c 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -42,6 +42,8 @@ export const AMPLITUDE_API_KEY = getStringConfig( '', ) +export const isAmplitudeEnabled = IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY + export const IS_DEV = process.env.NODE_ENV === 'development' export const IS_PROD = process.env.NODE_ENV === 'production' diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 0f69e6ce33..c02d302f09 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -1501,11 +1501,6 @@ "count": 2 } }, - "app/components/base/amplitude/AmplitudeProvider.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "app/components/base/amplitude/utils.ts": { "ts/no-explicit-any": { "count": 2 From d14635625c88d07cb83690dfefb025c8a6d7675f Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:18:36 +0800 Subject: [PATCH 23/32] feat(web): refactor pricing modal scrolling and accessibility (#34011) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../billing/pricing/__tests__/header.spec.tsx | 18 +++-- .../billing/pricing/__tests__/index.spec.tsx | 1 + web/app/components/billing/pricing/footer.tsx | 3 +- web/app/components/billing/pricing/header.tsx | 12 +-- web/app/components/billing/pricing/index.tsx | 79 +++++++++++++------ 5 files changed, 76 insertions(+), 37 deletions(-) diff --git a/web/app/components/billing/pricing/__tests__/header.spec.tsx b/web/app/components/billing/pricing/__tests__/header.spec.tsx index 0aadc3b0ce..cb8991ff42 100644 --- a/web/app/components/billing/pricing/__tests__/header.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/header.spec.tsx @@ -1,12 +1,14 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' -import { Dialog } from 
'@/app/components/base/ui/dialog' +import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' import Header from '../header' function renderHeader(onClose: () => void) { return render( -
+ +
+
, ) } @@ -24,7 +26,7 @@ describe('Header', () => { expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() }) }) @@ -33,7 +35,7 @@ describe('Header', () => { const handleClose = vi.fn() renderHeader(handleClose) - fireEvent.click(screen.getByRole('button')) + fireEvent.click(screen.getByRole('button', { name: 'common.operation.close' })) expect(handleClose).toHaveBeenCalledTimes(1) }) @@ -41,11 +43,11 @@ describe('Header', () => { describe('Edge Cases', () => { it('should render structural elements with translation keys', () => { - const { container } = renderHeader(vi.fn()) + renderHeader(vi.fn()) - expect(container.querySelector('span')).toBeInTheDocument() - expect(container.querySelector('p')).toBeInTheDocument() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() }) }) }) diff --git a/web/app/components/billing/pricing/__tests__/index.spec.tsx b/web/app/components/billing/pricing/__tests__/index.spec.tsx index 36848cd463..a8d0a4329e 100644 --- a/web/app/components/billing/pricing/__tests__/index.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/index.spec.tsx @@ -68,6 +68,7 @@ describe('Pricing', () => { it('should render pricing header and localized footer link', () => { render() + expect(screen.getByRole('dialog', { name: 'billing.plansCommon.title.plans' })).toBeInTheDocument() expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByTestId('pricing-link')).toHaveAttribute('href', 
'https://dify.ai/en/pricing#plans-and-features') }) diff --git a/web/app/components/billing/pricing/footer.tsx b/web/app/components/billing/pricing/footer.tsx index 0d3fd965b0..1422ec1cb1 100644 --- a/web/app/components/billing/pricing/footer.tsx +++ b/web/app/components/billing/pricing/footer.tsx @@ -28,8 +28,9 @@ const Footer = ({ {t('plansCommon.comparePlanAndFeatures', { ns: 'billing' })} diff --git a/web/app/components/billing/pricing/header.tsx b/web/app/components/billing/pricing/header.tsx index d0ffe100db..5ab1895667 100644 --- a/web/app/components/billing/pricing/header.tsx +++ b/web/app/components/billing/pricing/header.tsx @@ -1,5 +1,6 @@ import * as React from 'react' import { useTranslation } from 'react-i18next' +import { DialogDescription, DialogTitle } from '@/app/components/base/ui/dialog' import { cn } from '@/utils/classnames' import Button from '../../base/button' import DifyLogo from '../../base/logo/dify-logo' @@ -18,24 +19,25 @@ const Header = ({
-
+ - {t('plansCommon.title.plans', { ns: 'billing' })} - +
-

+ {t('plansCommon.title.description', { ns: 'billing' })} -

+ @@ -119,7 +125,6 @@ describe('ModelParameterModal', () => { beforeEach(() => { vi.clearAllMocks() - isAPIKeySet = true isRulesLoading = false parameterRules = [ { @@ -233,6 +238,26 @@ describe('ModelParameterModal', () => { expect(screen.getByTestId('model-selector')).toBeInTheDocument() }) + it('should pass nodesOutputVars and availableNodes to ParameterItem', () => { + const mockNodesOutputVars = [{ nodeId: 'n1', title: 'Node', vars: [] }] + const mockAvailableNodes = [{ id: 'n1', data: { title: 'Node', type: 'llm' } }] + + render( + , + ) + + fireEvent.click(screen.getByText('Open Settings')) + + const paramEl = screen.getByTestId('param-temperature') + expect(paramEl).toHaveAttribute('data-has-nodes-output-vars', 'true') + expect(paramEl).toHaveAttribute('data-has-available-nodes', 'true') + }) + it('should support custom triggers, workflow mode, and missing default model values', async () => { render( ({ @@ -18,6 +23,29 @@ vi.mock('@/app/components/base/tag-input', () => ({ ), })) +let promptEditorOnChange: ((text: string) => void) | undefined +let capturedWorkflowNodesMap: Record | undefined + +vi.mock('@/app/components/base/prompt-editor', () => ({ + default: ({ value, onChange, workflowVariableBlock }: { + value: string + onChange: (text: string) => void + workflowVariableBlock?: { + show: boolean + variables: NodeOutPutVar[] + workflowNodesMap?: Record + } + }) => { + promptEditorOnChange = onChange + capturedWorkflowNodesMap = workflowVariableBlock?.workflowNodesMap + return ( +
+ {value} +
+ ) + }, +})) + describe('ParameterItem', () => { const createRule = (overrides: Partial = {}): ModelParameterRule => ({ name: 'temp', @@ -30,9 +58,10 @@ describe('ParameterItem', () => { beforeEach(() => { vi.clearAllMocks() + promptEditorOnChange = undefined + capturedWorkflowNodesMap = undefined }) - // Float tests it('should render float controls and clamp numeric input to max', () => { const onChange = vi.fn() render() @@ -50,7 +79,6 @@ describe('ParameterItem', () => { expect(onChange).toHaveBeenCalledWith(0.1) }) - // Int tests it('should render int controls and clamp numeric input', () => { const onChange = vi.fn() render() @@ -75,22 +103,17 @@ describe('ParameterItem', () => { it('should render int input without slider if min or max is missing', () => { render() expect(screen.queryByRole('slider')).not.toBeInTheDocument() - // No max -> precision step expect(screen.getByRole('spinbutton')).toHaveAttribute('step', '0') }) - // Slider events (uses generic value mock for slider) it('should handle slide change and clamp values', () => { const onChange = vi.fn() render() - // Test that the actual slider triggers the onChange logic correctly - // The implementation of Slider uses onChange(val) directly via the mock fireEvent.click(screen.getByTestId('slider-btn')) expect(onChange).toHaveBeenCalledWith(2) }) - // Text & String tests it('should render exact string input and propagate text changes', () => { const onChange = vi.fn() render() @@ -109,21 +132,17 @@ describe('ParameterItem', () => { it('should render select for string with options', () => { render() - // Select renders the selected value in the trigger expect(screen.getByText('a')).toBeInTheDocument() }) - // Tag Tests it('should render tag input for tag type', () => { const onChange = vi.fn() render() expect(screen.getByText('placeholder')).toBeInTheDocument() - // Trigger mock tag input fireEvent.click(screen.getByTestId('tag-input')) expect(onChange).toHaveBeenCalledWith(['tag1', 'tag2']) }) - // 
Boolean tests it('should render boolean radios and update value on click', () => { const onChange = vi.fn() render() @@ -131,7 +150,6 @@ describe('ParameterItem', () => { expect(onChange).toHaveBeenCalledWith(false) }) - // Switch tests it('should call onSwitch with current value when optional switch is toggled off', () => { const onSwitch = vi.fn() render() @@ -146,7 +164,6 @@ describe('ParameterItem', () => { expect(screen.queryByRole('switch')).not.toBeInTheDocument() }) - // Default Value Fallbacks (rendering without value) it('should use default values if value is undefined', () => { const { rerender } = render() expect(screen.getByRole('spinbutton')).toHaveValue(0.5) @@ -158,26 +175,102 @@ describe('ParameterItem', () => { expect(screen.getByText('True')).toBeInTheDocument() expect(screen.getByText('False')).toBeInTheDocument() - // Without default - rerender() // min is 0 by default in createRule + rerender() expect(screen.getByRole('spinbutton')).toHaveValue(0) }) - // Input Blur it('should reset input to actual bound value on blur', () => { render() const input = screen.getByRole('spinbutton') - // change local state (which triggers clamp internally to let's say 1.4 -> 1 but leaves input text, though handleInputChange updates local state) - // Actually our test fires a change so localValue = 1, then blur sets it fireEvent.change(input, { target: { value: '5' } }) fireEvent.blur(input) expect(input).toHaveValue(1) }) - // Unsupported it('should render no input for unsupported parameter type', () => { render() expect(screen.queryByRole('textbox')).not.toBeInTheDocument() expect(screen.queryByRole('spinbutton')).not.toBeInTheDocument() }) + + describe('workflow variable reference', () => { + const mockNodesOutputVars: NodeOutPutVar[] = [ + { nodeId: 'node1', title: 'LLM Node', vars: [] }, + ] + const mockAvailableNodes: Node[] = [ + { id: 'node1', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'LLM Node', type: BlockEnum.LLM } } as Node, + { id: 
'start', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'Start', type: BlockEnum.Start } } as Node, + ] + + it('should build workflowNodesMap and render PromptEditor for string type', () => { + const onChange = vi.fn() + render( + , + ) + + const editor = screen.getByTestId('prompt-editor') + expect(editor).toBeInTheDocument() + expect(editor).toHaveAttribute('data-has-workflow-vars', 'true') + expect(capturedWorkflowNodesMap).toBeDefined() + expect(capturedWorkflowNodesMap!.node1.title).toBe('LLM Node') + expect(capturedWorkflowNodesMap!.sys.title).toBe('workflow.blocks.start') + expect(capturedWorkflowNodesMap!.sys.type).toBe(BlockEnum.Start) + + promptEditorOnChange?.('updated text') + expect(onChange).toHaveBeenCalledWith('updated text') + }) + + it('should build workflowNodesMap and render PromptEditor for text type', () => { + const onChange = vi.fn() + render( + , + ) + + const editor = screen.getByTestId('prompt-editor') + expect(editor).toBeInTheDocument() + expect(editor).toHaveAttribute('data-has-workflow-vars', 'true') + expect(capturedWorkflowNodesMap).toBeDefined() + + promptEditorOnChange?.('new long text') + expect(onChange).toHaveBeenCalledWith('new long text') + }) + + it('should fall back to plain input when not in workflow mode for string type', () => { + render( + , + ) + + expect(screen.queryByTestId('prompt-editor')).not.toBeInTheDocument() + expect(screen.getByRole('textbox')).toBeInTheDocument() + }) + + it('should return undefined workflowNodesMap when not in workflow mode', () => { + render( + , + ) + + expect(capturedWorkflowNodesMap).toBeUndefined() + }) + }) }) diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx index 6b4018e2aa..ccb2c67a0d 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx +++ 
b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx @@ -9,6 +9,10 @@ import type { } from '../declarations' import type { ParameterValue } from './parameter-item' import type { TriggerProps } from './trigger' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' @@ -45,6 +49,8 @@ export type ModelParameterModalProps = { readonly?: boolean isInWorkflow?: boolean scope?: string + nodesOutputVars?: NodeOutPutVar[] + availableNodes?: Node[] } const ModelParameterModal: FC = ({ @@ -61,11 +67,18 @@ const ModelParameterModal: FC = ({ renderTrigger, readonly, isInWorkflow, + nodesOutputVars, + availableNodes, }) => { const { t } = useTranslation() const [open, setOpen] = useState(false) const settingsIconRef = useRef(null) - const { data: parameterRulesData, isLoading } = useModelParameterRules(provider, modelId) + const { + data: parameterRulesData, + isPending, + isLoading, + } = useModelParameterRules(provider, modelId) + const isRulesLoading = isPending || isLoading const { currentProvider, currentModel, @@ -191,7 +204,7 @@ const ModelParameterModal: FC = ({ }
{ - isLoading + isRulesLoading ?
: ( [ @@ -205,6 +218,8 @@ const ModelParameterModal: FC = ({ onChange={v => handleParamChange(parameter.name, v)} onSwitch={(checked, assignValue) => handleSwitch(parameter.name, checked, assignValue)} isInWorkflow={isInWorkflow} + nodesOutputVars={nodesOutputVars} + availableNodes={availableNodes} /> )) ) @@ -213,7 +228,7 @@ const ModelParameterModal: FC = ({ ) } { - !parameterRules.length && isLoading && ( + !parameterRules.length && isRulesLoading && (
) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index 86fb6d81d0..01e3f45371 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -1,11 +1,18 @@ import type { ModelParameterRule } from '../declarations' -import { useEffect, useRef, useState } from 'react' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' +import { useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import PromptEditor from '@/app/components/base/prompt-editor' import Radio from '@/app/components/base/radio' import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' import TagInput from '@/app/components/base/tag-input' import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/app/components/base/ui/select' import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' +import { BlockEnum } from '@/app/components/workflow/types' import { cn } from '@/utils/classnames' import { useLanguage } from '../hooks' import { isNullOrUndefined } from '../utils' @@ -18,18 +25,43 @@ type ParameterItemProps = { onChange?: (value: ParameterValue) => void onSwitch?: (checked: boolean, assignValue: ParameterValue) => void isInWorkflow?: boolean + nodesOutputVars?: NodeOutPutVar[] + availableNodes?: Node[] } + function ParameterItem({ parameterRule, value, onChange, onSwitch, isInWorkflow, + nodesOutputVars, + availableNodes = [], }: ParameterItemProps) { + const { t } = useTranslation() const language = useLanguage() const [localValue, setLocalValue] = useState(value) const numberInputRef = 
useRef(null) + const workflowNodesMap = useMemo(() => { + if (!isInWorkflow || !availableNodes.length) + return undefined + + return availableNodes.reduce>>((acc, node) => { + acc[node.id] = { + title: node.data.title, + type: node.data.type, + } + if (node.data.type === BlockEnum.Start) { + acc.sys = { + title: t('blocks.start', { ns: 'workflow' }), + type: BlockEnum.Start, + } + } + return acc + }, {}) + }, [availableNodes, isInWorkflow, t]) + const getDefaultValue = () => { let defaultValue: ParameterValue @@ -196,6 +228,25 @@ function ParameterItem({ } if (parameterRule.type === 'string' && !parameterRule.options?.length) { + if (isInWorkflow && nodesOutputVars) { + return ( +
+ { handleInputChange(text) }} + workflowVariableBlock={{ + show: true, + variables: nodesOutputVars, + workflowNodesMap, + }} + editable + /> +
+ ) + } + return ( + { handleInputChange(text) }} + workflowVariableBlock={{ + show: true, + variables: nodesOutputVars, + workflowNodesMap, + }} + editable + /> +
+ ) + } + return (